Lightning Module API

PyTorch Lightning integration for training orchestration and distributed computing.

Overview

The Lightning module provides three main components:

  1. ConnectomicsModule: Lightning wrapper for models

  2. ConnectomicsDataModule: Lightning data handling

  3. create_trainer: Convenience function for trainer creation

Quick Example

from connectomics.config import load_config
from connectomics.training.lightning import (
    ConnectomicsModule,
    create_datamodule,
    create_trainer
)
from pytorch_lightning import seed_everything

# Load config
cfg = load_config("tutorials/minimal.yaml")

# Set seed
seed_everything(cfg.system.seed)

# Create components
datamodule = create_datamodule(cfg)
model = ConnectomicsModule(cfg)
trainer = create_trainer(cfg)

# Train
trainer.fit(model, datamodule=datamodule)

# Test
trainer.test(model, datamodule=datamodule)

Module Reference

ConnectomicsModule

class connectomics.training.lightning.ConnectomicsModule(cfg, model=None, skip_loss=False)[source]

Bases: LightningModule

PyTorch Lightning module for connectomics tasks.

This module provides automatic training features including distributed training, mixed precision, gradient accumulation, checkpointing, logging, and learning rate scheduling.

Parameters
  • cfg (Union[Config, DictConfig]) – Hydra Config object or OmegaConf DictConfig

  • model (Optional[nn.Module]) – Optional pre-built model (if None, builds from config)

  • skip_loss (bool) – If True, skip building the loss function (e.g., for inference-only use)

Lightning module wrapper for connectomics models.

This class wraps segmentation models with automatic training features:

  • Distributed training (DDP)

  • Mixed precision (AMP)

  • Gradient accumulation

  • Learning rate scheduling

  • Checkpointing

  • Multi-loss support

  • Deep supervision

Example:

from connectomics.config import load_config
from connectomics.training.lightning import ConnectomicsModule

cfg = load_config("tutorials/minimal.yaml")
model = ConnectomicsModule(cfg)

# Access underlying model
print(model.model)

# Get model info
print(model.get_model_info())

With a custom model:

import torch.nn as nn

from connectomics.config import load_config
from connectomics.training.lightning import ConnectomicsModule

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 2, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

cfg = load_config("tutorials/minimal.yaml")
custom_model = MyModel()
lit_model = ConnectomicsModule(cfg, model=custom_model)

configure_optimizers()[source]

Configure optimizers and learning rate schedulers.

Return type

Dict[str, Any]
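
The returned dictionary follows Lightning's standard format. A minimal sketch of what this looks like (the actual optimizer and scheduler are built from cfg.optimization; the choices below are illustrative only):

import torch

def configure_optimizers(self):
    # Illustrative optimizer/scheduler; the real ones come from the config
    optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
    }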

forward(x)[source]

Lightning forward pass that delegates to the underlying model.

This is required so Lightning can execute the module during training/inference.

Parameters

x (Tensor) – Input batch tensor.

Return type

Tensor

property inference_manager: connectomics.inference.manager.InferenceManager

Lazily build the inference manager on first access.

Train-only runs never trigger this and therefore never validate inference-only knobs (e.g., inference.window.blending).

load_state_dict(state_dict, strict=True)[source]

Load checkpoint state with compatibility filtering for stale loss-function buffers.

Parameters
  • state_dict (Dict[str, Any]) – Checkpoint state dict to load.

  • strict (bool) – Whether to strictly enforce that the keys match.

on_save_checkpoint(checkpoint)[source]

Persist primitive PyTC metadata without embedding config objects.

Parameters

checkpoint (Dict[str, Any]) – The checkpoint dictionary being saved.

Return type

None

on_test_epoch_end()[source]

Log aggregated test metrics after all ranks finish their assigned volumes.

Return type

None

on_test_start()[source]

Called at the beginning of testing to initialize metrics and inferer.

on_train_epoch_end()[source]

Called at the end of the training epoch.

Return type

None

on_validation_start()[source]

Called before validation starts.

Return type

None

test_step(batch, batch_idx)[source]

Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.

Parameters
  • batch (Dict[str, Tensor]) – The output of your data iterable, normally a DataLoader.

  • batch_idx (int) – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns

  • Tensor - The loss tensor

  • dict - A dictionary. Can include any keys, but must include the key 'loss'.

  • None - Skip to the next batch.

Return type

Optional[Union[Tensor, Mapping[str, Any]]]

# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...


# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...

Examples:

# CASE 1: A single test dataset
import torch
import torchvision
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})

If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)

    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})

Note

If you don’t need to test, you don’t need to implement this method.

Note

When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.

training_step(batch, batch_idx)[source]

Training step with deep supervision support.

Parameters
  • batch (Dict[str, Tensor]) – The output of the training dataloader.

  • batch_idx (int) – The index of this batch.

Return type

Optional[Union[Tensor, Mapping[str, Any]]]

transfer_batch_to_device(batch, device, dataloader_idx)[source]

Keep large test/predict input volumes on CPU for MONAI sliding-window inference.

Parameters
  • batch (Any) – The batch produced by the dataloader.

  • device (device) – The target device.

  • dataloader_idx (int) – The index of the dataloader that produced the batch.

Return type

Any
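
A hedged sketch of the pattern this hook implements (the actual stage checks and batch handling are internal to the library):

from pytorch_lightning.utilities import move_data_to_device

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    # During test/predict, leave the full volume on CPU; the sliding-window
    # inferer streams individual patches to the GPU instead.
    if self.trainer is not None and (self.trainer.testing or self.trainer.predicting):
        return batch
    return move_data_to_device(batch, device)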

validation_step(batch, batch_idx)[source]

Validation step with deep supervision support.

Parameters
  • batch (Dict[str, Tensor]) – The output of the validation dataloader.

  • batch_idx (int) – The index of this batch.

Return type

Optional[Union[Tensor, Mapping[str, Any]]]

ConnectomicsDataModule

class connectomics.training.lightning.ConnectomicsDataModule(train_data_dicts, val_data_dicts=None, test_data_dicts=None, transforms=None, dataset_type='standard', batch_size=1, num_workers=0, pin_memory=True, persistent_workers=False, cache_rate=1.0, val_steps_per_epoch=None, seed=0, distributed_tta_sharding=False, distributed_window_sharding=False, distributed_chunked_raw_sharding=False, **dataset_kwargs)[source]

Bases: LightningDataModule

Lightning DataModule using MONAI Dataset/CacheDataset.

Used as a fallback when pre-loaded cache is not enabled. Transforms (including loading and cropping) are applied on-the-fly.

Parameters
  • train_data_dicts (List[Dict[str, Any]]) – Training data dictionaries.

  • val_data_dicts (Optional[List[Dict[str, Any]]]) – Validation data dictionaries.

  • test_data_dicts (Optional[List[Dict[str, Any]]]) – Test data dictionaries.

  • transforms (Optional[Dict[str, Compose]]) – Dict of Compose for ‘train’/’val’/’test’.

  • dataset_type (str) – ‘standard’ or ‘cached’.

  • batch_size (int) – Batch size for dataloaders.

  • num_workers (int) – Number of dataloader workers.

  • pin_memory (bool) – Pin memory for GPU transfer.

  • persistent_workers (bool) – Keep workers alive between epochs.

  • cache_rate (float) – Cache rate for CacheDataset.

  • val_steps_per_epoch (Optional[int]) – Override validation dataset length.

  • seed (int) – Random seed for validation reseeding.

  • distributed_tta_sharding (bool) – Keep test samples replicated on all ranks so TTA passes can be partitioned inside inference rather than by sampler.

  • distributed_window_sharding (bool) – Keep the single test sample replicated on all ranks so lazy sliding-window patches can be partitioned inside inference.

  • distributed_chunked_raw_sharding (bool) – Keep the single test sample replicated on all ranks so raw prediction chunks can be partitioned inside inference.

  • **dataset_kwargs – Extra args (iter_num, sample_size, etc.).

prepare_data_per_node

If True, each LOCAL_RANK=0 will call prepare_data(). Otherwise only NODE_RANK=0, LOCAL_RANK=0 will call it.

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader with zero length within a local rank is allowed. Defaults to False.

Lightning data module for connectomics datasets.

Handles data loading with MONAI transforms:

  • Train/val/test splits

  • MONAI CacheDataset for fast loading

  • Automatic augmentation pipeline

  • Persistent workers for efficiency

Example:

from connectomics.config import load_config
from connectomics.training.lightning import create_datamodule

cfg = load_config("tutorials/minimal.yaml")
datamodule = create_datamodule(cfg)

# Setup for training
datamodule.setup('fit')

# Access dataloaders
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()

# Get dataset info
print(f"Train samples: {len(datamodule.train_dataset)}")
print(f"Val samples: {len(datamodule.val_dataset)}")

setup(stage=None)[source]

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Parameters

stage (Optional[str]) – either 'fit', 'validate', 'test', or 'predict'

Example:

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: state assigned here won't be shared across processes
        self.something = some_state

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

test_dataloader()[source]

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

train_dataloader()[source]

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

do not assign state in prepare_data

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

val_dataloader()[source]

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

This dataloader is requested during fit() and validate(), after prepare_data() and setup() have run.

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

create_trainer

connectomics.training.lightning.create_trainer(cfg, run_dir=None, fast_dev_run=False, ckpt_path=None, mode='train')[source]

Create a PyTorch Lightning Trainer with appropriate callbacks.

Parameters
  • cfg (Config) – Hydra Config object

  • run_dir (Optional[Path]) – Directory for this training run (required for mode='train')

  • fast_dev_run (bool) – Whether to run quick debug mode

  • ckpt_path (Optional[str]) – Path to checkpoint for resuming (used to extract best_score)

  • mode (str) – ‘train’ or ‘test’ - determines which system config to use

Returns

Configured Trainer instance

Return type

Trainer

Example:

from connectomics.config import load_config
from connectomics.training.lightning import create_trainer

cfg = load_config("tutorials/minimal.yaml")
trainer = create_trainer(cfg)

# Access trainer properties
print(f"Max epochs: {trainer.max_epochs}")
print(f"Precision: {trainer.precision}")
print(f"Devices: {trainer.num_devices}")

Custom trainer:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Create custom trainer
trainer = Trainer(
    max_epochs=100,
    accelerator='gpu',
    devices=2,
    callbacks=[EarlyStopping(monitor='val/loss', patience=10)]
)

Training Features

Distributed Training

Automatically uses DistributedDataParallel (DDP) with multiple GPUs:

system:
  num_gpus: 4  # Uses DDP automatically

trainer = create_trainer(cfg)  # DDP enabled automatically
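
Under the hood this maps onto standard Lightning Trainer flags; a rough hand-rolled equivalent (the exact flags are derived from the config):

from pytorch_lightning import Trainer

trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp")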

Mixed Precision

Enable mixed precision for faster training:

optimization:
  precision: "16-mixed"  # FP16
  # or
  precision: "bf16-mixed"  # BFloat16 (Ampere+ GPUs)

Gradient Accumulation

Simulate larger batch sizes:

optimization:
  accumulate_grad_batches: 4
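
Gradients are accumulated over N batches before each optimizer step, so the effective batch size is batch_size × accumulate_grad_batches × num_gpus. The corresponding raw Trainer flag:

from pytorch_lightning import Trainer

# e.g., batch_size=2 on 2 GPUs with accumulation of 4 -> effective batch size 16
trainer = Trainer(accumulate_grad_batches=4)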

Gradient Clipping

Prevent exploding gradients:

optimization:
  gradient_clip_val: 1.0
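
This maps onto Lightning's built-in clipping, which clips by gradient norm by default:

from pytorch_lightning import Trainer

trainer = Trainer(gradient_clip_val=1.0)  # clip gradient norm to 1.0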

Learning Rate Scheduling

Automatic LR scheduling with warmup:

optimization:
  scheduler:
    name: CosineAnnealingLR
    warmup_epochs: 5
    min_lr: 1e-6
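
One common way to realize warmup followed by cosine annealing in plain PyTorch (a sketch under assumed epoch counts; the library's actual scheduler construction may differ):

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

model = torch.nn.Linear(4, 2)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Linear warmup over the first 5 epochs, then cosine decay to min_lr
warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5)
cosine = CosineAnnealingLR(optimizer, T_max=95, eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[5])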

Deep Supervision

Multi-scale loss computation:

model:
  loss:
    deep_supervision: true
    losses:
      - function: DiceLoss
        weight: 1.0

The module automatically (see the sketch after this list):

  • Computes losses at multiple scales

  • Resizes ground truth to match each scale

  • Averages losses across scales
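
A hedged sketch of this multi-scale averaging (the function name and resizing mode are illustrative, not the library's actual implementation):

import torch
import torch.nn.functional as F

def deep_supervision_loss(outputs, target, loss_fn):
    # outputs: list of predictions at different scales; the ground truth is
    # resized to each scale and the per-scale losses are averaged
    losses = []
    for out in outputs:
        tgt = F.interpolate(target, size=out.shape[2:], mode="nearest")
        losses.append(loss_fn(out, tgt))
    return torch.stack(losses).mean()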

Callbacks

The trainer includes several useful callbacks:

Model Checkpointing

monitor:
  checkpoint:
    monitor: "val/loss"
    mode: "min"
    save_top_k: 3
    save_last: true
    filename: "epoch{epoch:02d}-loss{val/loss:.2f}"

Early Stopping

monitor:
  early_stopping:
    enabled: true
    monitor: "val/loss"
    patience: 10
    mode: "min"
    min_delta: 0.0
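
The equivalent raw callback:

from pytorch_lightning.callbacks import EarlyStopping

early_stop_cb = EarlyStopping(
    monitor="val/loss",
    patience=10,
    mode="min",
    min_delta=0.0,
)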

Learning Rate Monitoring

Automatically logs learning rate to TensorBoard/Wandb.
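
If you build a trainer by hand, Lightning's LearningRateMonitor callback provides the same behavior:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor

trainer = Trainer(callbacks=[LearningRateMonitor(logging_interval="step")])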

Logging

TensorBoard (Default)

monitor:
  logging:
    scalar:
      loss_every_n_steps: 10

Logs are saved to outputs/lightning_logs/.

View with:

tensorboard --logdir outputs/lightning_logs

Weights & Biases (Optional)

monitor:
  wandb:
    use_wandb: true
    project: "connectomics"
    entity: "your_team"
    name: "lucchi_exp"

Advanced Usage

Custom Callbacks

from pytorch_lightning.callbacks import Callback

class MyCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} finished!")

# Add to trainer
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=100,
    callbacks=[MyCallback()]
)

Custom Training Step

from connectomics.training.lightning import ConnectomicsModule

class CustomModule(ConnectomicsModule):
    def training_step(self, batch, batch_idx):
        # Custom training logic; batches are dicts of tensors
        # (assuming MONAI-style "image"/"label" keys)
        images, labels = batch["image"], batch["label"]
        outputs = self.model(images)

        # Custom loss computation (compute_loss is your own method)
        loss = self.compute_loss(outputs, labels)

        # Log metrics
        self.log('train/loss', loss)

        return loss

Inference

Single Batch Prediction

import torch

# Load trained model
model = ConnectomicsModule.load_from_checkpoint(
    "outputs/epoch=99.ckpt",
    cfg=cfg
)

model.eval()
model.cuda()

# Predict on a single preprocessed batch (input_batch: your input tensor)
with torch.no_grad():
    output = model(input_batch)

Full Dataset Inference

# Load model
model = ConnectomicsModule.load_from_checkpoint(
    "outputs/epoch=99.ckpt",
    cfg=cfg
)

# Create datamodule
datamodule = create_datamodule(cfg)

# Create trainer
trainer = create_trainer(cfg)

# Run inference
predictions = trainer.predict(model, datamodule=datamodule)

Resuming Training

# Resume from checkpoint
trainer = create_trainer(cfg)
trainer.fit(
    model,
    datamodule=datamodule,
    ckpt_path="outputs/last.ckpt"
)

Or from command line:

python scripts/main.py \
    --config tutorials/minimal.yaml \
    --resume outputs/last.ckpt

See Also