"""
PyTorch Lightning trainer utilities for PyTorch Connectomics.
This module provides Lightning trainer factory functions with:
- Hydra/OmegaConf configuration
- Modern callbacks (checkpointing, early stopping, logging)
- Distributed training support
- Mixed precision training
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Optional
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import (
EarlyStopping,
ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy
from ...config import Config
from ...runtime.torch_safe_globals import register_torch_safe_globals
from .callbacks import EMAWeightsCallback, ValidationReseedingCallback, VisualizationCallback
_log = logging.getLogger(__name__)
register_torch_safe_globals()
def create_trainer(
cfg: Config,
run_dir: Optional[Path] = None,
fast_dev_run: bool = False,
ckpt_path: Optional[str] = None,
mode: str = "train",
) -> pl.Trainer:
"""
Create a PyTorch Lightning Trainer.
Args:
cfg: Hydra Config object
run_dir: Directory for this training run (required for mode='train')
fast_dev_run: Whether to run quick debug mode
ckpt_path: Path to checkpoint for resuming (used to extract best_score)
mode: 'train', 'test', or 'tune-test' - controls which callbacks, logger, and distributed sampler settings are used
Returns:
Configured Trainer instance
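Example:
A minimal usage sketch (assumes ``cfg`` is a loaded Config and that a
LightningModule ``model`` and DataModule ``datamodule`` are built elsewhere;
the run directory path is illustrative)::

    trainer = create_trainer(cfg, run_dir=Path("outputs/run_001"))
    trainer.fit(model, datamodule=datamodule)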
"""
_log.info(f"Creating Lightning trainer (mode={mode})...")
# Setup callbacks (only for training mode)
callbacks = []
if mode == "train":
if run_dir is None:
raise ValueError("run_dir is required when mode='train'")
# Setup checkpoint directory
checkpoint_dir = run_dir / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
# Model checkpoint (in run_dir/checkpoints/)
checkpoint_callback = ModelCheckpoint(
dirpath=str(checkpoint_dir),
filename=cfg.monitor.checkpoint.checkpoint_filename,
monitor=cfg.monitor.checkpoint.monitor,
mode=cfg.monitor.checkpoint.mode,
save_top_k=cfg.monitor.checkpoint.save_top_k,
save_last=cfg.monitor.checkpoint.save_last,
every_n_epochs=cfg.monitor.checkpoint.save_every_n_epochs,
verbose=False,
save_on_train_epoch_end=cfg.monitor.checkpoint.save_on_train_epoch_end,
)
callbacks.append(checkpoint_callback)
save_every_n_steps = getattr(cfg.monitor.checkpoint, "save_every_n_steps", None)
if save_every_n_steps is not None and int(save_every_n_steps) > 0:
step_checkpoint_callback = ModelCheckpoint(
dirpath=str(checkpoint_dir),
filename=getattr(
cfg.monitor.checkpoint,
"step_checkpoint_filename",
"{step:08d}",
),
monitor=None,
save_top_k=-1,
save_last=False,
every_n_train_steps=int(save_every_n_steps),
every_n_epochs=0,
verbose=False,
save_on_train_epoch_end=False,
)
callbacks.append(step_checkpoint_callback)
_log.info(" Step checkpoints: every %d train steps", int(save_every_n_steps))
# Early stopping (training only)
if cfg.monitor.early_stopping.enabled:
# Import here to avoid circular dependency
from .utils import extract_best_score_from_checkpoint
# Extract best_score from checkpoint filename if resuming
best_score = None
if ckpt_path:
best_score = extract_best_score_from_checkpoint(
ckpt_path, cfg.monitor.early_stopping.monitor
)
if best_score is not None:
_log.info(
f" Early stopping: Extracted best_score={best_score:.6f} from checkpoint"
)
early_stop_callback = EarlyStopping(
monitor=cfg.monitor.early_stopping.monitor,
patience=cfg.monitor.early_stopping.patience,
mode=cfg.monitor.early_stopping.mode,
min_delta=cfg.monitor.early_stopping.min_delta,
verbose=False,
check_on_train_epoch_end=True, # Check at end of train epoch (not validation)
check_finite=cfg.monitor.early_stopping.check_finite, # Stop on NaN/inf
stopping_threshold=cfg.monitor.early_stopping.threshold,
divergence_threshold=cfg.monitor.early_stopping.divergence_threshold,
strict=False, # Don't crash if metric not available (wait for it)
)
# Manually set best_score if extracted from checkpoint
if best_score is not None:
early_stop_callback.best_score = torch.tensor(best_score)
callbacks.append(early_stop_callback)
# Visualization callback (training only, end-of-epoch only)
if cfg.monitor.logging.images.enabled:
vis_callback = VisualizationCallback(
cfg=cfg,
max_images=cfg.monitor.logging.images.max_images,
num_slices=cfg.monitor.logging.images.num_slices,
slice_sampling=cfg.monitor.logging.images.slice_sampling,
log_every_n_epochs=cfg.monitor.logging.images.log_every_n_epochs,
)
callbacks.append(vis_callback)
log_freq = cfg.monitor.logging.images.log_every_n_epochs
_log.info(f" Visualization: Enabled (every {log_freq} epoch(s))")
else:
_log.info(" Visualization: Disabled")
# EMA weights for stabler validation
ema_cfg = getattr(cfg.optimization, "ema", None)
if ema_cfg and getattr(ema_cfg, "enabled", False):
ema_callback = EMAWeightsCallback(
decay=getattr(ema_cfg, "decay", 0.999),
warmup_steps=getattr(ema_cfg, "warmup_steps", 0),
validate_with_ema=getattr(ema_cfg, "validate_with_ema", True),
device=getattr(ema_cfg, "device", None),
copy_buffers=getattr(ema_cfg, "copy_buffers", True),
)
callbacks.append(ema_callback)
_log.info(
f" EMA: Enabled (decay={ema_cfg.decay}, warmup_steps={ema_cfg.warmup_steps}, "
f"validate_with_ema={ema_cfg.validate_with_ema})"
)
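# A hypothetical config snippet that would enable EMA (keys mirror the
# getattr() defaults above; the values are illustrative):
#   optimization:
#     ema:
#       enabled: true
#       decay: 0.999
#       warmup_steps: 1000
#       validate_with_ema: true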
# Validation reseeding callback: reseed validation datasets at the start of
# each validation epoch. Reseeding only in val_dataloader() is not enough,
# because that hook runs once during setup rather than on every epoch.
validation_reseeding_callback = ValidationReseedingCallback(
base_seed=cfg.system.seed,
log_fingerprint=False,
log_all_ranks=False,
verbose=False,
)
callbacks.append(validation_reseeding_callback)
_log.info(f" Validation Reseeding: Enabled (base_seed={cfg.system.seed})")
# Setup logger.
# Train: keep TensorBoard logs in run_dir/logs.
# Test/tune: disable logger so no logs/ folder is created in results output directories.
logger = False
if mode == "train":
if run_dir is None:
raise ValueError("run_dir is required when mode='train'")
logger = TensorBoardLogger(
save_dir=str(run_dir),
name="", # No name subdirectory
version="logs", # Logs go directly to run_dir/logs/
)
_log.info(f" Logger: TensorBoard (logs saved to {run_dir}/logs/)")
# Create trainer
system_cfg = cfg.system
# Check if GPU is actually available
use_gpu = system_cfg.num_gpus > 0 and torch.cuda.is_available()
# Check if anomaly detection is enabled (useful for debugging NaN)
detect_anomaly = getattr(cfg.monitor, "detect_anomaly", False)
if detect_anomaly:
_log.warning("PyTorch anomaly detection ENABLED (training will be slower)")
_log.warning(" This helps pinpoint the exact operation causing NaN in backward pass")
# Configure DDP strategy for multi-GPU training with deep supervision
strategy = "auto" # Default strategy
if system_cfg.num_gpus > 1:
# Multi-GPU training: configure DDP
loss_cfg = getattr(cfg.model, "loss", None)
deep_supervision_enabled = getattr(loss_cfg, "deep_supervision", False)
ddp_find_unused_params = getattr(cfg.model, "ddp_find_unused_parameters", False)
architecture = getattr(getattr(cfg.model, "arch", None), "type", "")
is_mednext = architecture.startswith("mednext")
# MedNeXt always creates deep supervision layers internally (even when deep
# supervision is disabled), so it always needs find_unused_parameters=True.
if is_mednext or deep_supervision_enabled or ddp_find_unused_params:
strategy = DDPStrategy(find_unused_parameters=True)
# Determine reason for using find_unused_parameters
if is_mednext and not deep_supervision_enabled:
reason = "MedNeXt (has unused DS layers)"
elif deep_supervision_enabled:
reason = "deep supervision enabled"
else:
reason = "explicit config"
_log.info(f" Strategy: DDP with find_unused_parameters=True ({reason})")
else:
strategy = DDPStrategy(find_unused_parameters=False)
_log.info(" Strategy: DDP (standard)")
# Step-based training: Lightning stops as soon as EITHER max_epochs OR max_steps
# is reached, so the epoch limit must be disabled (-1) when training by steps.
max_steps_cfg = getattr(cfg.optimization, "max_steps", None)
if max_steps_cfg is not None and max_steps_cfg > 0:
# Step-based training: disable epoch limit
max_epochs = -1 # -1 means unlimited epochs
max_steps = max_steps_cfg
training_mode = f"step-based ({max_steps:,} steps)"
else:
# Epoch-based training: disable step limit
max_epochs = cfg.optimization.max_epochs
max_steps = -1 # -1 means unlimited steps
training_mode = f"epoch-based ({max_epochs} epochs)"
# Treat optimization.val_check_interval as an epoch interval by default.
# BANIS-style reproductions can opt into true step-based validation via
# optimization.val_check_interval_unit below.
val_check_cfg = cfg.optimization.val_check_interval
if isinstance(val_check_cfg, float) and not val_check_cfg.is_integer():
raise ValueError(
"optimization.val_check_interval must be an integer number of epochs/steps "
f"(got {val_check_cfg})."
)
val_check_interval = int(val_check_cfg)
if val_check_interval < 1:
raise ValueError(
f"optimization.val_check_interval must be >= 1 (got {val_check_interval})."
)
val_check_unit = str(getattr(cfg.optimization, "val_check_interval_unit", "epoch")).lower()
if val_check_unit not in {"epoch", "step"}:
raise ValueError(
"optimization.val_check_interval_unit must be 'epoch' or 'step' "
f"(got {val_check_unit!r})."
)
if val_check_unit == "step":
check_val_every_n_epoch = None
trainer_val_check_interval = val_check_interval
_log.info(f" Validation: every {val_check_interval} train step(s)")
else:
check_val_every_n_epoch = val_check_interval
trainer_val_check_interval = 1.0
_log.info(f" Validation: every {check_val_every_n_epoch} epoch(s)")
# In Slurm jobs launched with ntasks=1, force local process spawning for multi-GPU
# so Lightning uses world_size=devices instead of treating Slurm as externally launched DDP.
plugins = None
slurm_ntasks = os.environ.get("SLURM_NTASKS")
if use_gpu and system_cfg.num_gpus > 1 and slurm_ntasks == "1":
plugins = [LightningEnvironment()]
_log.info(" Launch mode: local multi-GPU spawn (SLURM_NTASKS=1)")
trainer = pl.Trainer(
max_epochs=max_epochs,
max_steps=max_steps,
accelerator="gpu" if use_gpu else "cpu",
devices=system_cfg.num_gpus if use_gpu else 1,
strategy=strategy,
precision=cfg.optimization.precision,
gradient_clip_val=cfg.optimization.gradient_clip_val,
accumulate_grad_batches=cfg.optimization.accumulate_grad_batches,
check_val_every_n_epoch=check_val_every_n_epoch,
val_check_interval=trainer_val_check_interval,
num_sanity_val_steps=cfg.optimization.num_sanity_val_steps,
log_every_n_steps=cfg.optimization.log_every_n_steps,
callbacks=callbacks,
logger=logger,
fast_dev_run=bool(fast_dev_run),
detect_anomaly=detect_anomaly,
enable_progress_bar=True,
plugins=plugins,
use_distributed_sampler=mode not in ("test", "tune-test"),
)
_log.info(f" Training mode: {training_mode}")
_log.info(f" Devices: {system_cfg.num_gpus if system_cfg.num_gpus > 0 else 1} ({mode} mode)")
_log.info(f" Precision: {cfg.optimization.precision}")
return trainer
__all__ = [
"create_trainer",
]