Lightning Module API¶
PyTorch Lightning integration for training orchestration and distributed computing.
Overview¶
The Lightning module provides three main components:
ConnectomicsModule: Lightning wrapper for models
ConnectomicsDataModule: Lightning data handling
create_trainer: Convenience function for trainer creation
Quick Example¶
from connectomics.config import load_config
from connectomics.training.lightning import (
    ConnectomicsModule,
    create_datamodule,
    create_trainer,
)
from pytorch_lightning import seed_everything
# Load config
cfg = load_config("tutorials/minimal.yaml")
# Set seed
seed_everything(cfg.system.seed)
# Create components
datamodule = create_datamodule(cfg)
model = ConnectomicsModule(cfg)
trainer = create_trainer(cfg)
# Train
trainer.fit(model, datamodule=datamodule)
# Test
trainer.test(model, datamodule=datamodule)
Module Reference¶
ConnectomicsModule¶
- class connectomics.training.lightning.ConnectomicsModule(cfg, model=None, skip_loss=False)[source]¶
Bases: LightningModule
PyTorch Lightning module for connectomics tasks.
This module provides automatic training features including:
Distributed training
Mixed precision
Gradient accumulation
Checkpointing
Logging
Learning rate scheduling
- Parameters
cfg (Union[Config, DictConfig]) – Hydra Config object or OmegaConf DictConfig
model (Optional[nn.Module]) – Optional pre-built model (if None, builds from config)
skip_loss (bool) – If True, skip building the training loss (e.g., for inference-only use)
Lightning module wrapper for connectomics models.
This class wraps segmentation models with automatic training features:
Distributed training (DDP)
Mixed precision (AMP)
Gradient accumulation
Learning rate scheduling
Checkpointing
Multi-loss support
Deep supervision
Example:
from connectomics.config import load_config
from connectomics.training.lightning import ConnectomicsModule

cfg = load_config("tutorials/minimal.yaml")
model = ConnectomicsModule(cfg)

# Access underlying model
print(model.model)

# Get model info
print(model.get_model_info())
With custom model:
import torch.nn as nn
from connectomics.training.lightning import ConnectomicsModule

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(1, 2, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

custom_model = MyModel()
lit_model = ConnectomicsModule(cfg, model=custom_model)
- forward(x)[source]¶
Lightning forward pass that delegates to the underlying model.
This is required so Lightning can execute the module during training/inference.
- property inference_manager: connectomics.inference.manager.InferenceManager¶
Lazily build the inference manager on first access.
Train-only runs never trigger this and therefore never validate inference-only knobs (e.g.,
inference.window.blending).
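A sketch of the lazy pattern (class and attribute names here are illustrative, not the actual implementation):
class InferenceManagerSketch:
    # Stand-in for connectomics.inference.manager.InferenceManager;
    # the real class validates inference config at construction time.
    def __init__(self, cfg):
        self.cfg = cfg

class LazyModuleSketch:
    def __init__(self, cfg):
        self.cfg = cfg
        self._inference_manager = None  # nothing built yet

    @property
    def inference_manager(self):
        # Built only on first access, so train-only runs never construct it
        if self._inference_manager is None:
            self._inference_manager = InferenceManagerSketch(self.cfg)
        return self._inference_manager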
- load_state_dict(state_dict, strict=True)[source]¶
Load checkpoint state with compatibility filtering for stale loss-function buffers.
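A minimal sketch of such filtering; the filtered key prefix is an assumption, not the module's actual rule:
def load_state_dict(self, state_dict, strict=True):
    # Drop loss-function buffers saved by older runs so stale
    # checkpoints still load ("criterion." prefix assumed here).
    filtered = {
        k: v for k, v in state_dict.items()
        if not k.startswith("criterion.")
    }
    return super().load_state_dict(filtered, strict=strict)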
- on_save_checkpoint(checkpoint)[source]¶
Persist primitive PyTC metadata without embedding config objects.
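The idea, sketched with an assumed metadata key: store plain Python values rather than pickled config objects, so checkpoints stay loadable when the Config class changes:
def on_save_checkpoint(self, checkpoint):
    checkpoint["pytc_meta"] = {  # key name assumed for illustration
        "seed": int(self.cfg.system.seed),
    }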
- on_test_epoch_end()[source]¶
Log aggregated test metrics after all ranks finish their assigned volumes.
- Return type
None
- test_step(batch, batch_idx)[source]¶
Operates on a single batch of data from the test set. In this step you’d normally generate examples or calculate anything of interest such as accuracy.
- Parameters
batch (Dict[str, Tensor]) – The output of your data iterable, normally a DataLoader.
batch_idx (int) – The index of this batch.
dataloader_idx (int) – The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns
Tensor – the loss tensor
dict – a dictionary; can include any keys, but must include the key 'loss'
None – skip to the next batch
# if you have one test dataloader:
def test_step(self, batch, batch_idx): ...

# if you have multiple test dataloaders:
def test_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
# CASE 1: A single test dataset
def test_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs!
    self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.

# CASE 2: multiple test dataloaders
def test_step(self, batch, batch_idx, dataloader_idx=0):
    # dataloader_idx tells you which dataset this is.
    x, y = batch

    # implement your own
    out = self(x)
    if dataloader_idx == 0:
        loss = self.loss0(out, y)
    else:
        loss = self.loss1(out, y)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # log the outputs separately for each dataloader
    self.log_dict({
        f"test_loss_{dataloader_idx}": loss,
        f"test_acc_{dataloader_idx}": acc,
    })
Note
If you don’t need to test you don’t need to implement this method.
Note
When the test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
ConnectomicsDataModule¶
- class connectomics.training.lightning.ConnectomicsDataModule(train_data_dicts, val_data_dicts=None, test_data_dicts=None, transforms=None, dataset_type='standard', batch_size=1, num_workers=0, pin_memory=True, persistent_workers=False, cache_rate=1.0, val_steps_per_epoch=None, seed=0, distributed_tta_sharding=False, distributed_window_sharding=False, distributed_chunked_raw_sharding=False, **dataset_kwargs)[source]¶
Bases: LightningDataModule
Lightning DataModule using MONAI Dataset/CacheDataset.
Used as a fallback when pre-loaded cache is not enabled. Transforms (including loading and cropping) are applied on-the-fly.
- Parameters
train_data_dicts (List[Dict[str, Any]]) – Training data dictionaries.
val_data_dicts (Optional[List[Dict[str, Any]]]) – Validation data dictionaries.
test_data_dicts (Optional[List[Dict[str, Any]]]) – Test data dictionaries.
transforms (Optional[Dict[str, Compose]]) – Dict of Compose for ‘train’/’val’/’test’.
dataset_type (str) – ‘standard’ or ‘cached’.
batch_size (int) – Batch size for dataloaders.
num_workers (int) – Number of dataloader workers.
pin_memory (bool) – Pin memory for GPU transfer.
persistent_workers (bool) – Keep workers alive between epochs.
cache_rate (float) – Cache rate for CacheDataset.
val_steps_per_epoch (Optional[int]) – Override validation dataset length.
seed (int) – Random seed for validation reseeding.
distributed_tta_sharding (bool) – Keep test samples replicated on all ranks so TTA passes can be partitioned inside inference rather than by sampler.
distributed_window_sharding (bool) – Keep the single test sample replicated on all ranks so lazy sliding-window patches can be partitioned inside inference.
distributed_chunked_raw_sharding (bool) – Keep the single test sample replicated on all ranks so raw prediction chunks can be partitioned inside inference.
**dataset_kwargs – Extra args (iter_num, sample_size, etc.).
- prepare_data_per_node¶
If True, each LOCAL_RANK=0 will call prepare_data(). Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
- allow_zero_length_dataloader_with_multiple_devices¶
If True, dataloader with zero length within local rank is allowed. Default value is False.
Lightning data module for connectomics datasets.
Handles data loading with MONAI transforms:
Train/val/test splits
MONAI CacheDataset for fast loading
Automatic augmentation pipeline
Persistent workers for efficiency
Example:
from connectomics.config import load_config
from connectomics.training.lightning import create_datamodule

cfg = load_config("tutorials/minimal.yaml")
datamodule = create_datamodule(cfg)

# Setup for training
datamodule.setup('fit')

# Access dataloaders
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()

# Get dataset info
print(f"Train samples: {len(datamodule.train_dataset)}")
print(f"Val samples: {len(datamodule.val_dataset)}")
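The dataset_type parameter selects between MONAI's two dataset classes. A rough sketch of the mapping (data dicts and transforms below are placeholders):
from monai.data import CacheDataset, Dataset

train_data_dicts = [{"image": "im.h5", "label": "lb.h5"}]  # placeholder
train_transforms = None  # normally a monai.transforms.Compose

# dataset_type='standard': transforms run on every access
ds = Dataset(data=train_data_dicts, transform=train_transforms)

# dataset_type='cached': deterministic transforms are precomputed once
ds = CacheDataset(
    data=train_data_dicts,
    transform=train_transforms,
    cache_rate=1.0,  # fraction of items held in memory
)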
- setup(stage=None)[source]¶
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- test_dataloader()[source]¶
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
Within Trainer.test(), the relevant hooks run in this order: prepare_data(), setup(), test_dataloader().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a
test_step(), you don’t need to implement this method.
- train_dataloader()[source]¶
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs on the Trainer to a positive integer.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
Within Trainer.fit(), the relevant hooks run in this order: prepare_data(), setup(), train_dataloader().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- val_dataloader()[source]¶
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see the PyTorch Lightning documentation.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs on the Trainer to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Within Trainer.fit() and Trainer.validate(), the relevant hooks run in this order: prepare_data(), setup(), val_dataloader().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a
validation_step(), you don’t need to implement this method.
create_trainer¶
- connectomics.training.lightning.create_trainer(cfg, run_dir=None, fast_dev_run=False, ckpt_path=None, mode='train')[source]¶
Create PyTorch Lightning Trainer.
- Parameters
cfg (Config) – Hydra Config object
run_dir (Optional[Path]) – Directory for this training run (required for mode=’train’)
fast_dev_run (bool) – Whether to run quick debug mode
ckpt_path (Optional[str]) – Path to checkpoint for resuming (used to extract best_score)
mode (str) – ‘train’ or ‘test’ - determines which system config to use
- Returns
Configured Trainer instance
- Return type
Trainer
Create PyTorch Lightning Trainer with appropriate callbacks.
Example:
from connectomics.config import load_config
from connectomics.training.lightning import create_trainer

cfg = load_config("tutorials/minimal.yaml")
trainer = create_trainer(cfg)

# Access trainer properties
print(f"Max epochs: {trainer.max_epochs}")
print(f"Precision: {trainer.precision}")
print(f"Devices: {trainer.num_devices}")
Custom trainer:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Create custom trainer
trainer = Trainer(
    max_epochs=100,
    accelerator='gpu',
    devices=2,
    callbacks=[EarlyStopping(monitor='val/loss', patience=10)]
)
Training Features¶
Distributed Training¶
Automatically uses DistributedDataParallel (DDP) with multiple GPUs:
system:
  num_gpus: 4  # Uses DDP automatically
trainer = create_trainer(cfg) # DDP enabled automatically
Mixed Precision¶
Enable mixed precision for faster training:
optimization:
  precision: "16-mixed"    # FP16
  # or
  precision: "bf16-mixed"  # BFloat16 (Ampere+ GPUs)
Gradient Accumulation¶
Simulate larger batch sizes:
optimization:
  accumulate_grad_batches: 4
Gradient Clipping¶
Prevent exploding gradients:
optimization:
  gradient_clip_val: 1.0
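For reference, these four options map onto standard PyTorch Lightning Trainer arguments. A minimal sketch of the equivalents (the exact mapping performed by create_trainer is assumed):
from pytorch_lightning import Trainer

trainer = Trainer(
    accelerator="gpu",
    devices=4,                  # system.num_gpus; Lightning picks DDP automatically
    precision="bf16-mixed",     # optimization.precision
    accumulate_grad_batches=4,  # optimization.accumulate_grad_batches
    gradient_clip_val=1.0,      # optimization.gradient_clip_val
)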
Learning Rate Scheduling¶
Automatic LR scheduling with warmup:
optimization:
  scheduler:
    name: CosineAnnealingLR
    warmup_epochs: 5
    min_lr: 1e-6
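Warmup followed by cosine decay can be expressed with standard torch schedulers. A sketch (the module's actual configure_optimizers wiring is assumed; warmup_epochs=5 and min_lr=1e-6 mirror the config above):
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

model = nn.Linear(4, 2)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Linear ramp-up over the first 5 epochs, then cosine decay toward min_lr
warmup = LinearLR(optimizer, start_factor=0.01, total_iters=5)
cosine = CosineAnnealingLR(optimizer, T_max=95, eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[5])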
Deep Supervision¶
Multi-scale loss computation:
model:
  loss:
    deep_supervision: true
    losses:
      - function: DiceLoss
        weight: 1.0
The module automatically:
Computes losses at multiple scales
Resizes ground truth to match each scale
Averages losses across scales
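A sketch of that procedure, assuming the model returns a list of predictions at decreasing resolutions (function and variable names are illustrative):
import torch
import torch.nn.functional as F
from monai.losses import DiceLoss

loss_fn = DiceLoss(sigmoid=True)

def deep_supervision_loss(outputs, target):
    # outputs: list of [B, C, D, H, W] tensors at different scales
    losses = []
    for out in outputs:
        # Resize ground truth to match this scale before computing the loss
        scaled = F.interpolate(target, size=out.shape[2:], mode="nearest")
        losses.append(loss_fn(out, scaled))
    return torch.stack(losses).mean()  # average across scales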
Callbacks¶
The trainer includes several useful callbacks:
Model Checkpointing¶
monitor:
  checkpoint:
    monitor: "val/loss"
    mode: "min"
    save_top_k: 3
    save_last: true
    filename: "epoch{epoch:02d}-loss{val/loss:.2f}"
Early Stopping¶
monitor:
  early_stopping:
    enabled: true
    monitor: "val/loss"
    patience: 10
    mode: "min"
    min_delta: 0.0
Learning Rate Monitoring¶
Automatically logs learning rate to TensorBoard/Wandb.
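In plain Lightning, these three callbacks look roughly as follows; create_trainer is assumed to assemble something similar from the monitor config:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)

callbacks = [
    ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=3, save_last=True),
    EarlyStopping(monitor="val/loss", patience=10, mode="min", min_delta=0.0),
    LearningRateMonitor(logging_interval="epoch"),  # logs LR each epoch
]
trainer = Trainer(callbacks=callbacks)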
Logging¶
TensorBoard (Default)¶
monitor:
  logging:
    scalar:
      loss_every_n_steps: 10
Logs are saved to outputs/lightning_logs/.
View with:
tensorboard --logdir outputs/lightning_logs
Weights & Biases (Optional)¶
monitor:
  wandb:
    use_wandb: true
    project: "connectomics"
    entity: "your_team"
    name: "lucchi_exp"
Advanced Usage¶
Custom Callbacks¶
from pytorch_lightning.callbacks import Callback
class MyCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} finished!")
# Add to trainer
from pytorch_lightning import Trainer
trainer = Trainer(
    max_epochs=100,
    callbacks=[MyCallback()]
)
Custom Training Step¶
from connectomics.training.lightning import ConnectomicsModule
class CustomModule(ConnectomicsModule):
    def training_step(self, batch, batch_idx):
        # Custom training logic
        images, labels = batch
        outputs = self.model(images)

        # Custom loss computation
        loss = self.compute_loss(outputs, labels)

        # Log metrics
        self.log('train/loss', loss)
        return loss
Inference¶
Single Batch Prediction¶
import torch

# Load trained model
model = ConnectomicsModule.load_from_checkpoint(
    "outputs/epoch=99.ckpt",
    cfg=cfg
)
model.eval()
model.cuda()

# Predict
with torch.no_grad():
    output = model(input_batch)
Full Dataset Inference¶
# Load model
model = ConnectomicsModule.load_from_checkpoint(
    "outputs/epoch=99.ckpt",
    cfg=cfg
)
# Create datamodule
datamodule = create_datamodule(cfg)
# Create trainer
trainer = create_trainer(cfg)
# Run inference
predictions = trainer.predict(model, datamodule=datamodule)
Resuming Training¶
# Resume from checkpoint
trainer = create_trainer(cfg)
trainer.fit(
    model,
    datamodule=datamodule,
    ckpt_path="outputs/last.ckpt"
)
Or from command line:
python scripts/main.py \
    --config tutorials/minimal.yaml \
    --resume outputs/last.ckpt