Configuration System¶
Note
PyTorch Connectomics v2.0 uses Hydra/OmegaConf as the configuration system.
PyTorch Connectomics uses a flexible, type-safe configuration system built on Hydra and OmegaConf. Configuration files are written in YAML and support CLI overrides, composition, and type checking.
Quick Start¶
Basic training:
# Train with a config file
python scripts/main.py --config tutorials/minimal.yaml
# Override config from CLI
python scripts/main.py --config tutorials/minimal.yaml \
default.data.dataloader.batch_size=4 \
train.optimization.max_epochs=200
Python API:
from connectomics.config import load_config
from omegaconf import OmegaConf
# Load config
cfg = load_config("tutorials/minimal.yaml")
# Access values
print(cfg.model.arch.type) # 'monai_basic_unet3d'
print(cfg.data.dataloader.batch_size) # 1
# Modify values
cfg.data.dataloader.batch_size = 4
# Print entire config
print(OmegaConf.to_yaml(cfg, resolve=True))
Configuration Structure¶
A typical v2.0 config file has a default section plus stage-specific
overrides such as train and test:
experiment_name: example
default:
system:
num_gpus: 1
num_workers: 4
seed: 42
model:
arch:
type: monai_basic_unet3d
in_channels: 1
out_channels: 1
input_size: [64, 128, 128]
output_size: [64, 128, 128]
loss:
losses:
- function: DiceLoss
weight: 1.0
data:
dataloader:
batch_size: 2
patch_size: [64, 128, 128]
train:
data:
train:
image: datasets/example/train_image.h5
label: datasets/example/train_label.h5
val:
image: datasets/example/val_image.h5
label: datasets/example/val_label.h5
optimization:
max_epochs: 100
precision: "16-mixed"
optimizer:
name: AdamW
lr: 1e-4
monitor:
checkpoint:
monitor: train_loss_total_epoch
save_top_k: 3
save_last: true
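The default section holds settings shared by all stages, and each stage section (train, test) overrides or extends it. The merge itself is performed by PyTC's config loader; conceptually it behaves like an OmegaConf merge, sketched here on the raw YAML (illustrative only, not the loader's actual code):
from omegaconf import OmegaConf

raw = OmegaConf.load("config.yaml")  # the example config above, saved to disk

# Stage-specific keys win over the shared defaults.
train_cfg = OmegaConf.merge(raw.default, raw.get("train", {}))

print(train_cfg.optimization.max_epochs)     # from the train section
print(train_cfg.data.dataloader.batch_size)  # inherited from default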
Configuration Sections¶
System Configuration¶
Controls hardware and reproducibility:
system:
num_gpus: 1 # Number of GPUs (0 for CPU)
num_cpus: 4 # Number of CPU workers
seed: 42 # Random seed for reproducibility
deterministic: false # Use deterministic algorithms (slower)
Model Configuration¶
Specifies model architecture and loss functions:
model:
arch:
type: monai_basic_unet3d # Model architecture
in_channels: 1 # Input channels
out_channels: 2 # Output channels
monai:
filters: [32, 64, 128, 256] # Filter sizes per level
dropout: 0.1 # Dropout rate
# Loss functions
loss:
deep_supervision: true
losses:
- function: DiceLoss
weight: 1.0
- function: BCEWithLogitsLoss
weight: 1.0
# Optional: architecture-specific nested blocks
mednext:
size: S
Available architectures:
monai_basic_unet3d: Simple and fast 3D U-Net
monai_unet: U-Net with residual units
monai_unetr: Transformer-based UNETR
monai_swin_unetr: Swin Transformer U-Net
mednext: MedNeXt with predefined sizes (S/B/M/L)
mednext_custom: MedNeXt with custom parameters
Available loss functions:
DiceLoss: Soft Dice loss
FocalLoss: Focal loss for class imbalance
TverskyLoss: Tversky loss
DiceCELoss: Combined Dice + Cross-Entropy
BCEWithLogitsLoss: Binary cross-entropy
CrossEntropyLoss: Multi-class cross-entropy
Data Configuration¶
Specifies data paths and loading parameters:
data:
# Data paths
train:
image: "path/to/train_image.h5"
label: "path/to/train_label.h5"
val:
image: "path/to/val_image.h5"
label: "path/to/val_label.h5"
test:
image: "path/to/test_image.h5" # Optional
dataloader:
patch_size: [128, 128, 128]
batch_size: 2
persistent_workers: true
pin_memory: true
# Augmentation
augmentation:
profile: aug_standard
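Before a long run it can help to confirm that the HDF5 volumes exist and that image and label shapes match. A small sketch using h5py (the dataset key names inside your files may differ):
import h5py

for path in ["path/to/train_image.h5", "path/to/train_label.h5"]:
    with h5py.File(path, "r") as f:
        for key in f.keys():
            print(path, key, f[key].shape, f[key].dtype)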
Optimizer Configuration¶
Specifies optimizer type and hyperparameters:
optimization:
optimizer:
name: AdamW # Optimizer type
lr: 1e-4 # Learning rate
weight_decay: 1e-4 # Weight decay (L2 regularization)
# Optimizer-specific params
betas: [0.9, 0.999] # For Adam/AdamW
momentum: 0.9 # For SGD
Supported optimizers:
Adam, AdamW, SGD, RMSprop, Adagrad
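All of these map onto the standard torch.optim classes, so the optimizer block can be read as keyword arguments for the named class. A sketch of that mapping (the actual construction happens inside PyTC's Lightning module):
import torch

model = torch.nn.Conv3d(1, 1, kernel_size=3)  # placeholder model

opt_cfg = {"name": "AdamW", "lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}
opt_cls = getattr(torch.optim, opt_cfg["name"])  # torch.optim.AdamW
optimizer = opt_cls(model.parameters(), lr=opt_cfg["lr"],
                    weight_decay=opt_cfg["weight_decay"],
                    betas=opt_cfg["betas"])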
Scheduler Configuration¶
Specifies learning rate scheduling:
optimization:
scheduler:
name: CosineAnnealingLR
warmup_epochs: 5
min_lr: 1e-6
# Scheduler-specific params
params:
T_max: 100
Supported schedulers:
CosineAnnealingLR, StepLR, ExponentialLR, ReduceLROnPlateau
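The warmup_epochs and min_lr fields describe a warmup phase followed by the named scheduler. One common way to compose this in plain PyTorch, shown for CosineAnnealingLR (PyTC's own scheduler construction may differ in detail):
import torch

model = torch.nn.Conv3d(1, 1, kernel_size=3)                # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 5 warmup epochs ramping up to the base LR, then cosine decay to min_lr.
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=5)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[5])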
Training Configuration¶
Controls training loop parameters:
optimization:
max_epochs: 100
precision: "16-mixed" # "32", "16-mixed", "bf16-mixed"
gradient_clip_val: 1.0
accumulate_grad_batches: 1 # Gradient accumulation
val_check_interval: 1.0 # Validation frequency
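These fields map directly onto PyTorch Lightning Trainer arguments. A rough sketch of the correspondence (the actual Trainer is assembled by PyTC's training entry point and may set additional options):
import lightning.pytorch as pl

trainer = pl.Trainer(
    max_epochs=100,
    precision="16-mixed",
    gradient_clip_val=1.0,
    accumulate_grad_batches=1,
    val_check_interval=1.0,
)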
Command Line Overrides¶
Override any config value from the command line:
# Override single values
python scripts/main.py --config tutorials/minimal.yaml \
default.data.dataloader.batch_size=4
# Override multiple values
python scripts/main.py --config tutorials/minimal.yaml \
default.data.dataloader.batch_size=4 \
train.optimization.max_epochs=200 \
train.optimization.optimizer.lr=1e-3
# Override nested values
python scripts/main.py --config tutorials/minimal.yaml \
default.model.monai.filters=[64,128,256,512]
# Add new values
python scripts/main.py --config tutorials/minimal.yaml \
+description="debug run"
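Dotted overrides are merged on top of the loaded YAML; in OmegaConf terms they behave roughly like a dotlist merge (a sketch of the idea, not PyTC's actual CLI parsing):
from omegaconf import OmegaConf

cfg = OmegaConf.load("tutorials/minimal.yaml")
overrides = OmegaConf.from_dotlist([
    "default.data.dataloader.batch_size=4",
    "train.optimization.max_epochs=200",
])
cfg = OmegaConf.merge(cfg, overrides)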
Multiple Loss Functions¶
Combine multiple loss functions with different weights:
model:
loss:
losses:
- function: DiceLoss
weight: 1.0
- function: BCEWithLogitsLoss
weight: 1.0
- function: FocalLoss
weight: 0.5
The total loss is computed as:
total_loss = (1.0 * dice_loss +
1.0 * bce_loss +
0.5 * focal_loss)
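Conceptually this is just a weighted sum over independent criteria. A sketch using MONAI and PyTorch loss classes (PyTC builds the criteria from the losses list internally; shapes here are only illustrative):
import torch
from monai.losses import DiceLoss, FocalLoss

criteria = [
    (1.0, DiceLoss(sigmoid=True)),
    (1.0, torch.nn.BCEWithLogitsLoss()),
    (0.5, FocalLoss()),
]

logits = torch.randn(2, 1, 64, 128, 128)                    # raw model output
target = torch.randint(0, 2, (2, 1, 64, 128, 128)).float()  # binary mask

total_loss = sum(weight * fn(logits, target) for weight, fn in criteria)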
Deep Supervision¶
Enable multi-scale loss computation for improved training:
model:
arch:
type: mednext
loss:
deep_supervision: true
losses:
- function: DiceLoss
weight: 1.0
Deep supervision automatically:
Computes losses at multiple scales (5 scales for MedNeXt)
Resizes ground truth to match each scale
Averages losses across scales
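A sketch of that pattern, resizing the target to each output scale and averaging the per-scale losses (PyTC handles this internally for the MedNeXt deep-supervision heads):
import torch
import torch.nn.functional as F
from monai.losses import DiceLoss

loss_fn = DiceLoss(sigmoid=True)
target = torch.randint(0, 2, (2, 1, 64, 128, 128)).float()

# Multi-scale predictions: full resolution plus progressively downsampled heads.
preds = [torch.randn(2, 1, 64 // s, 128 // s, 128 // s) for s in (1, 2, 4, 8, 16)]

scale_losses = []
for p in preds:
    t = F.interpolate(target, size=p.shape[2:], mode="nearest")  # resize GT to this scale
    scale_losses.append(loss_fn(p, t))
ds_loss = torch.stack(scale_losses).mean()  # average across scales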
MedNeXt Configuration¶
Predefined sizes:
model:
arch:
type: mednext
mednext:
size: S # S, B, M, or L
kernel_size: 3 # 3, 5, or 7
in_channels: 1
out_channels: 2
loss:
deep_supervision: true
Custom configuration:
model:
arch:
type: mednext_custom
mednext:
base_channels: 32
exp_r: [2, 3, 4, 4, 4, 4, 4, 3, 2]
block_counts: [3, 4, 8, 8, 8, 8, 8, 4, 3]
kernel_size: 7
grn: true
loss:
deep_supervision: true
2D Configuration¶
For 2D segmentation tasks:
data:
train:
do_2d: true
dataloader:
patch_size: [1, 256, 256] # [D, H, W] - D=1 for 2D
Mixed Precision Training¶
Use mixed precision for faster training and reduced memory:
optimization:
precision: "16-mixed" # FP16 mixed precision
# Or for BFloat16 (requires Ampere+ GPUs)
optimization:
precision: "bf16-mixed"
Distributed Training¶
Distributed training (DDP) is used automatically when multiple GPUs are requested:
system:
num_gpus: 4 # Uses DDP automatically
data:
dataloader:
batch_size: 2 # Per-GPU batch size
Effective batch size = num_gpus * batch_size = 4 * 2 = 8
Gradient Accumulation¶
Simulate larger batch sizes:
data:
dataloader:
batch_size: 2
optimization:
accumulate_grad_batches: 4
Effective batch size = batch_size * accumulate_grad_batches = 2 * 4 = 8
Checkpointing and Logging¶
Model checkpointing:
monitor:
checkpoint:
monitor: "val/loss"
mode: "min" # "min" or "max"
save_top_k: 3 # Keep best 3 checkpoints
save_last: true # Also save last checkpoint
filename: "epoch{epoch:02d}-loss{val/loss:.2f}"
Early stopping:
monitor:
early_stopping:
enabled: true
monitor: "val/loss"
patience: 10
mode: "min"
min_delta: 0.0
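These options correspond to Lightning's ModelCheckpoint and EarlyStopping callbacks. A sketch of the equivalent callback objects (PyTC builds them from the monitor section; only the fields shown above are mirrored here):
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

checkpoint = ModelCheckpoint(
    monitor="val/loss",
    mode="min",
    save_top_k=3,
    save_last=True,
)
early_stopping = EarlyStopping(
    monitor="val/loss",
    patience=10,
    mode="min",
    min_delta=0.0,
)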
Logging:
monitor:
logging:
scalar:
loss_every_n_steps: 10
wandb:
use_wandb: false
project: "connectomics"
entity: "your_team"
Configuration in Python¶
Load and modify configs:
from connectomics.config import load_config, save_config
from omegaconf import OmegaConf
# Load config
cfg = load_config("tutorials/minimal.yaml")
# Access values
print(cfg.model.arch.type)
print(cfg.data.dataloader.batch_size)
# Modify values
cfg.data.dataloader.batch_size = 4
cfg.optimization.max_epochs = 200
# Merge configs
overrides = OmegaConf.create({
"data": {"dataloader": {"batch_size": 8}},
"optimization": {"optimizer": {"lr": 1e-3}}
})
cfg = OmegaConf.merge(cfg, overrides)
# Save config
save_config(cfg, "modified_config.yaml")
# Print config
print(OmegaConf.to_yaml(cfg, resolve=True))
Create configs programmatically:
from omegaconf import OmegaConf
cfg = OmegaConf.create({
"system": {"num_gpus": 1, "seed": 42},
"model": {
"arch": {"type": "monai_unet"},
"in_channels": 1,
"out_channels": 2
},
"data": {
"dataloader": {
"batch_size": 2,
"patch_size": [128, 128, 128]
}
}
})
Inference Configuration¶
Most of a training config can be reused for inference; the key differences are the test data paths and the added inference section:
# inference_config.yaml
model:
arch:
type: monai_unet
# ... same as training
data:
test:
image: "path/to/test.h5"
dataloader:
patch_size: [128, 128, 128]
batch_size: 4 # Can use larger batch size
inference:
output_path: "predictions/"
sliding_window:
overlap: 0.5
blend_mode: gaussian
test_time_augmentation:
enabled: false
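The sliding_window options mirror MONAI's sliding-window inference. A sketch with SlidingWindowInferer and a placeholder network (PyTC wires this up itself when running with --mode test):
import torch
from monai.inferers import SlidingWindowInferer

inferer = SlidingWindowInferer(
    roi_size=(128, 128, 128),  # matches dataloader.patch_size
    sw_batch_size=4,           # matches dataloader.batch_size
    overlap=0.5,
    mode="gaussian",           # blend_mode
)

net = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)  # placeholder for the trained model
volume = torch.randn(1, 1, 192, 192, 192)              # (B, C, D, H, W) test volume
prediction = inferer(volume, net)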
Run inference:
python scripts/main.py \
--config inference_config.yaml \
--mode test \
--checkpoint outputs/best.ckpt
Configuration Examples¶
See the tutorials/ directory for complete examples:
tutorials/minimal.yaml: minimal MONAI smoke config
tutorials/mito_lucchi++.yaml: mitochondria segmentation
tutorials/neuron_snemi/neuron_snemi_sdt.yaml: MedNeXt SNEMI config
Best Practices¶
Use version control for config files
Document non-obvious parameter choices
Start simple with basic configs, then customize
Save configs with experiment outputs for reproducibility
Use meaningful names for experiments
Validate configs before long training runs (a quick sanity check is sketched below)
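A minimal sanity check for that last point, assuming connectomics.config.load_config and a hand-picked list of keys your runs require (adapt the key list to your configs):
from omegaconf import OmegaConf
from connectomics.config import load_config

cfg = load_config("tutorials/minimal.yaml")

required = ["model.arch.type", "data.dataloader.patch_size", "optimization.max_epochs"]
missing = [key for key in required if OmegaConf.select(cfg, key) is None]
if missing:
    raise ValueError(f"Config is missing required keys: {missing}")

print(OmegaConf.to_yaml(cfg, resolve=True))  # inspect the fully resolved config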
For more information: