
Source code for connectomics.training.lightning.trainer

"""
PyTorch Lightning trainer utilities for PyTorch Connectomics.

This module provides Lightning trainer factory functions with:
- Hydra/OmegaConf configuration
- Modern callbacks (checkpointing, early stopping, logging)
- Distributed training support
- Mixed precision training
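
Example:
    A minimal usage sketch. It assumes ``cfg`` is an already-composed
    ``Config`` and that ``model`` / ``datamodule`` are a pre-built
    LightningModule / LightningDataModule; the run directory path is
    hypothetical::

        from pathlib import Path

        trainer = create_trainer(cfg, run_dir=Path("runs/exp01"), mode="train")
        trainer.fit(model, datamodule=datamodule)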
"""

from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy

from ...config import Config
from ...runtime.torch_safe_globals import register_torch_safe_globals
from .callbacks import EMAWeightsCallback, ValidationReseedingCallback, VisualizationCallback

_log = logging.getLogger(__name__)

register_torch_safe_globals()


def create_trainer(
    cfg: Config,
    run_dir: Optional[Path] = None,
    fast_dev_run: bool = False,
    ckpt_path: Optional[str] = None,
    mode: str = "train",
) -> pl.Trainer:
    """
    Create PyTorch Lightning Trainer.

    Args:
        cfg: Hydra Config object
        run_dir: Directory for this training run (required for mode='train')
        fast_dev_run: Whether to run quick debug mode
        ckpt_path: Path to checkpoint for resuming (used to extract best_score)
        mode: 'train' or 'test' - determines which system config to use

    Returns:
        Configured Trainer instance
    """
    _log.info(f"Creating Lightning trainer (mode={mode})...")

    # Setup callbacks (only for training mode)
    callbacks = []
    if mode == "train":
        if run_dir is None:
            raise ValueError("run_dir is required when mode='train'")

        # Setup checkpoint directory
        checkpoint_dir = run_dir / "checkpoints"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Model checkpoint (in run_dir/checkpoints/)
        checkpoint_callback = ModelCheckpoint(
            dirpath=str(checkpoint_dir),
            filename=cfg.monitor.checkpoint.checkpoint_filename,
            monitor=cfg.monitor.checkpoint.monitor,
            mode=cfg.monitor.checkpoint.mode,
            save_top_k=cfg.monitor.checkpoint.save_top_k,
            save_last=cfg.monitor.checkpoint.save_last,
            every_n_epochs=cfg.monitor.checkpoint.save_every_n_epochs,
            verbose=False,
            save_on_train_epoch_end=cfg.monitor.checkpoint.save_on_train_epoch_end,
        )
        callbacks.append(checkpoint_callback)
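        # Illustration with a hypothetical template: Lightning expands metric
        # placeholders in the checkpoint filename, e.g. "{epoch:02d}-{val_loss:.2f}"
        # is saved as "epoch=02-val_loss=0.32.ckpt" under ModelCheckpoint's
        # default auto_insert_metric_name=True.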
        save_every_n_steps = getattr(cfg.monitor.checkpoint, "save_every_n_steps", None)
        if save_every_n_steps is not None and int(save_every_n_steps) > 0:
            step_checkpoint_callback = ModelCheckpoint(
                dirpath=str(checkpoint_dir),
                filename=getattr(
                    cfg.monitor.checkpoint,
                    "step_checkpoint_filename",
                    "{step:08d}",
                ),
                monitor=None,
                save_top_k=-1,
                save_last=False,
                every_n_train_steps=int(save_every_n_steps),
                every_n_epochs=0,
                verbose=False,
                save_on_train_epoch_end=False,
            )
            callbacks.append(step_checkpoint_callback)
            _log.info(" Step checkpoints: every %d train steps", int(save_every_n_steps))

        # Early stopping (training only)
        if cfg.monitor.early_stopping.enabled:
            # Import here to avoid circular dependency
            from .utils import extract_best_score_from_checkpoint

            # Extract best_score from checkpoint filename if resuming
            best_score = None
            if ckpt_path:
                best_score = extract_best_score_from_checkpoint(
                    ckpt_path, cfg.monitor.early_stopping.monitor
                )
                if best_score is not None:
                    _log.info(
                        f" Early stopping: Extracted best_score={best_score:.6f} from checkpoint"
                    )

            early_stop_callback = EarlyStopping(
                monitor=cfg.monitor.early_stopping.monitor,
                patience=cfg.monitor.early_stopping.patience,
                mode=cfg.monitor.early_stopping.mode,
                min_delta=cfg.monitor.early_stopping.min_delta,
                verbose=False,
                check_on_train_epoch_end=True,  # Check at end of train epoch (not validation)
                check_finite=cfg.monitor.early_stopping.check_finite,  # Stop on NaN/inf
                stopping_threshold=cfg.monitor.early_stopping.threshold,
                divergence_threshold=cfg.monitor.early_stopping.divergence_threshold,
                strict=False,  # Don't crash if metric not available (wait for it)
            )

            # Manually set best_score if extracted from checkpoint
            if best_score is not None:
                early_stop_callback.best_score = torch.tensor(best_score)

            callbacks.append(early_stop_callback)

        # Visualization callback (training only, end-of-epoch only)
        if cfg.monitor.logging.images.enabled:
            vis_callback = VisualizationCallback(
                cfg=cfg,
                max_images=cfg.monitor.logging.images.max_images,
                num_slices=cfg.monitor.logging.images.num_slices,
                slice_sampling=cfg.monitor.logging.images.slice_sampling,
                log_every_n_epochs=cfg.monitor.logging.images.log_every_n_epochs,
            )
            callbacks.append(vis_callback)
            log_freq = cfg.monitor.logging.images.log_every_n_epochs
            _log.info(f" Visualization: Enabled (every {log_freq} epoch(s))")
        else:
            _log.info(" Visualization: Disabled")

        # EMA weights for stabler validation
        ema_cfg = getattr(cfg.optimization, "ema", None)
        if ema_cfg and getattr(ema_cfg, "enabled", False):
            ema_callback = EMAWeightsCallback(
                decay=getattr(ema_cfg, "decay", 0.999),
                warmup_steps=getattr(ema_cfg, "warmup_steps", 0),
                validate_with_ema=getattr(ema_cfg, "validate_with_ema", True),
                device=getattr(ema_cfg, "device", None),
                copy_buffers=getattr(ema_cfg, "copy_buffers", True),
            )
            callbacks.append(ema_callback)
            _log.info(
                f" EMA: Enabled (decay={ema_cfg.decay}, warmup_steps={ema_cfg.warmup_steps}, "
                f"validate_with_ema={ema_cfg.validate_with_ema})"
            )

        # [FIX 1 - PROPER IMPLEMENTATION] Validation reseeding callback
        # This ensures validation datasets are reseeded at the start of EACH validation epoch
        # Previous fix in val_dataloader() only ran once during setup
        validation_reseeding_callback = ValidationReseedingCallback(
            base_seed=cfg.system.seed,
            log_fingerprint=False,
            log_all_ranks=False,
            verbose=False,
        )
        callbacks.append(validation_reseeding_callback)
        _log.info(f" Validation Reseeding: Enabled (base_seed={cfg.system.seed})")

    # Setup logger.
    # Train: keep TensorBoard logs in run_dir/logs.
    # Test/tune: disable logger so no logs/ folder is created in results output directories.
    logger = False
    if mode == "train":
        if run_dir is None:
            raise ValueError("run_dir is required when mode='train'")
        logger = TensorBoardLogger(
            save_dir=str(run_dir),
            name="",  # No name subdirectory
            version="logs",  # Logs go directly to run_dir/logs/
        )
        _log.info(f" Logger: TensorBoard (logs saved to {run_dir}/logs/)")

    # Create trainer
    system_cfg = cfg.system

    # Check if GPU is actually available
    use_gpu = system_cfg.num_gpus > 0 and torch.cuda.is_available()

    # Check if anomaly detection is enabled (useful for debugging NaN)
    detect_anomaly = getattr(cfg.monitor, "detect_anomaly", False)
    if detect_anomaly:
        _log.warning("PyTorch anomaly detection ENABLED (training will be slower)")
        _log.warning(" This helps pinpoint the exact operation causing NaN in backward pass")

    # Configure DDP strategy for multi-GPU training with deep supervision
    strategy = "auto"  # Default strategy
    if system_cfg.num_gpus > 1:
        # Multi-GPU training: configure DDP
        loss_cfg = getattr(cfg.model, "loss", None)
        deep_supervision_enabled = getattr(loss_cfg, "deep_supervision", False)
        ddp_find_unused_params = getattr(cfg.model, "ddp_find_unused_parameters", False)
        architecture = getattr(getattr(cfg.model, "arch", None), "type", "")
        is_mednext = architecture.startswith("mednext")

        # MedNeXt always creates deep supervision layers internally (even when disabled)
        # so it always needs find_unused_parameters=True
        if is_mednext or deep_supervision_enabled or ddp_find_unused_params:
            strategy = DDPStrategy(find_unused_parameters=True)
            # Determine reason for using find_unused_parameters
            if is_mednext and not deep_supervision_enabled:
                reason = "MedNeXt (has unused DS layers)"
            elif deep_supervision_enabled:
                reason = "deep supervision enabled"
            else:
                reason = "explicit config"
            _log.info(f" Strategy: DDP with find_unused_parameters=True ({reason})")
        else:
            strategy = DDPStrategy(find_unused_parameters=False)
            _log.info(" Strategy: DDP (standard)")
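    # Illustration with hypothetical values: system.num_gpus=4 and
    # model.arch.type="mednext_base" would take the first branch and use
    # DDPStrategy(find_unused_parameters=True) even when
    # model.loss.deep_supervision is False, because the MedNeXt deep
    # supervision heads stay registered as (unused) parameters.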
    # [FIX 2] Implement TRUE step-based training
    # PyTorch Lightning stops when EITHER max_epochs OR max_steps is reached.
    # To ensure step-based training works correctly, we must disable epochs
    # when using steps.
    max_steps_cfg = getattr(cfg.optimization, "max_steps", None)
    if max_steps_cfg is not None and max_steps_cfg > 0:
        # Step-based training: disable epoch limit
        max_epochs = -1  # -1 means unlimited epochs
        max_steps = max_steps_cfg
        training_mode = f"step-based ({max_steps:,} steps)"
    else:
        # Epoch-based training: disable step limit
        max_epochs = cfg.optimization.max_epochs
        max_steps = -1  # -1 means unlimited steps
        training_mode = f"epoch-based ({max_epochs} epochs)"

    # Treat optimization.val_check_interval as epoch interval by default.
    # BANIS-style reproduction can opt into true step-based validation.
    val_check_cfg = cfg.optimization.val_check_interval
    if isinstance(val_check_cfg, float) and not val_check_cfg.is_integer():
        raise ValueError(
            "optimization.val_check_interval must be an integer number of epochs/steps "
            f"(got {val_check_cfg})."
        )
    val_check_interval = int(val_check_cfg)
    if val_check_interval < 1:
        raise ValueError(
            f"optimization.val_check_interval must be >= 1 (got {val_check_interval})."
        )

    val_check_unit = str(getattr(cfg.optimization, "val_check_interval_unit", "epoch")).lower()
    if val_check_unit not in {"epoch", "step"}:
        raise ValueError(
            "optimization.val_check_interval_unit must be 'epoch' or 'step' "
            f"(got {val_check_unit!r})."
        )

    if val_check_unit == "step":
        check_val_every_n_epoch = None
        trainer_val_check_interval = val_check_interval
        _log.info(f" Validation: every {val_check_interval} train step(s)")
    else:
        check_val_every_n_epoch = val_check_interval
        trainer_val_check_interval = 1.0
        _log.info(f" Validation: every {check_val_every_n_epoch} epoch(s)")

    # In Slurm jobs launched with ntasks=1, force local process spawning for multi-GPU
    # so Lightning uses world_size=devices instead of treating Slurm as externally launched DDP.
    plugins = None
    slurm_ntasks = os.environ.get("SLURM_NTASKS")
    if use_gpu and system_cfg.num_gpus > 1 and slurm_ntasks == "1":
        plugins = [LightningEnvironment()]
        _log.info(" Launch mode: local multi-GPU spawn (SLURM_NTASKS=1)")

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        max_steps=max_steps,
        accelerator="gpu" if use_gpu else "cpu",
        devices=system_cfg.num_gpus if use_gpu else 1,
        strategy=strategy,
        precision=cfg.optimization.precision,
        gradient_clip_val=cfg.optimization.gradient_clip_val,
        accumulate_grad_batches=cfg.optimization.accumulate_grad_batches,
        check_val_every_n_epoch=check_val_every_n_epoch,
        val_check_interval=trainer_val_check_interval,
        num_sanity_val_steps=cfg.optimization.num_sanity_val_steps,
        log_every_n_steps=cfg.optimization.log_every_n_steps,
        callbacks=callbacks,
        logger=logger,
        fast_dev_run=bool(fast_dev_run),
        detect_anomaly=detect_anomaly,
        enable_progress_bar=True,
        plugins=plugins,
        use_distributed_sampler=mode not in ("test", "tune-test"),
    )

    _log.info(f" Training mode: {training_mode}")
    _log.info(f" Devices: {system_cfg.num_gpus if system_cfg.num_gpus > 0 else 1} ({mode} mode)")
    _log.info(f" Precision: {cfg.optimization.precision}")

    return trainer
__all__ = [
    "create_trainer",
]
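
A minimal sketch of the ``optimization`` fields this factory reads when choosing
between step-based and epoch-based training and when scheduling validation. The
field names mirror the attribute accesses in ``create_trainer``; the values and
the standalone ``OmegaConf.create`` call are illustrative assumptions, not part
of this module:

    from omegaconf import OmegaConf

    optimization_fragment = OmegaConf.create(
        {
            "optimization": {
                "max_steps": 100_000,     # > 0 selects step-based training (max_epochs=-1)
                "max_epochs": 100,        # used only when max_steps is unset or <= 0
                "val_check_interval": 1,  # integer count of epochs or steps; must be >= 1
                "val_check_interval_unit": "epoch",  # "epoch" (default) or "step"
            }
        }
    )

With ``max_steps`` set as above, ``create_trainer`` passes ``max_epochs=-1`` and
``max_steps=100000`` to ``pl.Trainer``; removing ``max_steps`` (or setting it to
0) switches the run to epoch-based training driven by ``max_epochs``.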