connectomics.models¶
Model Builder¶
Modern model builder using architecture registry.
Uses MONAI and MedNeXt native models with automatic configuration. All models are registered in the architecture registry.
This module is intentionally model-only:
- selects an architecture builder
- instantiates the model
- logs model structure info
It does not load checkpoints or move the model to a device.
- connectomics.models.build.build_model(cfg)[source]¶
Build model from configuration using architecture registry.
- Parameters
cfg – Hydra config object with model configuration
- Returns
Instantiated model (left on default device)
- Available architectures:
MONAI models: monai_basic_unet3d, monai_unet, monai_unetr, monai_swin_unetr
MedNeXt models: mednext, mednext_custom
Example
cfg = OmegaConf.create({
    'model': {
        'arch': {'type': 'mednext'},
        'in_channels': 1,
        'out_channels': 2,
        'mednext': {'size': 'S', 'kernel_size': 3},
        'deep_supervision': True,
    }
})
model = build_model(cfg)
- To see all available architectures:
from connectomics.models.architectures import print_available_architectures
print_available_architectures()
Architecture Registry¶
Architecture module for connectomics models.
Provides:
- Registry system for architecture management
- Base model interface with deep supervision support
- MONAI model wrappers (BasicUNet, UNet, UNETR, SwinUNETR)
- Future: MedNeXt model wrappers
- Usage:
from connectomics.models.architectures import (
    register_architecture, list_architectures, ConnectomicsModel,
)

# List available models
print(list_architectures())

# Register custom model
@register_architecture('my_model')
def build_my_model(cfg):
    return MyModel(cfg)
- class connectomics.models.architectures.ConnectomicsModel[source]¶
Base class for all connectomics models.
Provides a common interface for:
- Forward pass (single or multi-scale outputs)
- Deep supervision support
- Model information
All models in the architectures module should inherit from this class or at least implement the same interface.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- abstract forward(x)[source]¶
Forward pass through the model.
- Parameters
x (Tensor) – Input tensor of shape (B, C, D, H, W) for 3D or (B, C, H, W) for 2D
- Returns
For single-scale models:
torch.Tensor – Output tensor of shape (B, num_classes, D, H, W)
For deep supervision models:
Dict[str, torch.Tensor] – Dictionary with keys:
'output': Main output (full resolution)
'ds_1': Deep supervision output at scale 1 (1/2 resolution)
'ds_2': Deep supervision output at scale 2 (1/4 resolution)
'ds_3': Deep supervision output at scale 3 (1/8 resolution)
'ds_4': Deep supervision output at scale 4 (1/16 resolution)
- Return type
torch.Tensor or Dict[str, torch.Tensor]
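A subclass implementing this interface can be sketched as follows. This is a minimal toy model, not one of the library's architectures, and the pooling-based deep-supervision heads are illustrative only:

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Toy model following the ConnectomicsModel forward contract."""

    def __init__(self, in_channels=1, num_classes=2, deep_supervision=False):
        super().__init__()
        self.deep_supervision = deep_supervision
        self.head = nn.Conv3d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        out = self.head(x)
        if not self.deep_supervision:
            return out  # (B, num_classes, D, H, W)
        # Deep-supervision outputs at progressively halved resolutions
        return {
            "output": out,                               # full resolution
            "ds_1": nn.functional.avg_pool3d(out, 2),    # 1/2 resolution
            "ds_2": nn.functional.avg_pool3d(out, 4),    # 1/4 resolution
        }

x = torch.randn(1, 1, 16, 16, 16)
single = ToyModel()(x)
multi = ToyModel(deep_supervision=True)(x)
```

A training loop can branch on whether the output is a tensor or a dict to apply the multi-scale loss.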
- get_model_info()[source]¶
Get model information and statistics.
- Returns
Dictionary with model metadata:
name: Model class name
deep_supervision: Whether the model supports deep supervision
output_scales: Number of output scales
parameters: Total number of parameters
trainable_parameters: Number of trainable parameters
- Return type
dict
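The statistics above can be computed from any `nn.Module` as sketched below (a hypothetical standalone helper named `model_info`; the actual method lives on `ConnectomicsModel`):

```python
import torch

def model_info(model):
    """Collect basic metadata and parameter counts for a torch module."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {
        "name": type(model).__name__,
        "parameters": total,
        "trainable_parameters": trainable,
    }

# Linear(4, 2) has 4*2 weights + 2 biases = 10 parameters
info = model_info(torch.nn.Linear(4, 2))
```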
- connectomics.models.architectures.get_architecture_builder(name)[source]¶
Get builder function for architecture.
- Parameters
name (str) – Architecture name
- Returns
Builder function that takes cfg and returns a model
- Raises
ValueError – If architecture not found
- Return type
Callable
- connectomics.models.architectures.get_architecture_info()[source]¶
Get information about all registered architectures.
- connectomics.models.architectures.get_available_architectures()[source]¶
Get information about available architectures and their dependencies.
- Returns
Dictionary with keys:
'monai': List of MONAI architectures (if available)
'mednext': List of MedNeXt architectures (if available)
'all': List of all registered architectures
- Return type
Dict[str, List[str]]
- connectomics.models.architectures.is_architecture_available(name)[source]¶
Check if architecture is available.
- connectomics.models.architectures.print_available_architectures()[source]¶
Log a formatted list of available architectures.
- connectomics.models.architectures.register_architecture(name)[source]¶
Decorator to register architecture builders.
Example
@register_architecture('my_model')
def build_my_model(cfg):
    return MyModel(...)
- Parameters
name (str) – Unique name for the architecture
- Returns
Decorator function
- connectomics.models.architectures.unregister_architecture(name)[source]¶
Unregister an architecture (useful for testing).
- Parameters
name (str) – Architecture name
- Raises
ValueError – If architecture not found
- Return type
None
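The decorator mechanics behind register_architecture and unregister_architecture can be sketched with a plain dict registry. This is a hypothetical standalone sketch; the module's actual internals may differ:

```python
# Module-level registry mapping architecture names to builder functions
_REGISTRY = {}

def register_architecture(name):
    """Decorator that records a builder under a unique name."""
    def decorator(builder):
        if name in _REGISTRY:
            raise ValueError(f"Architecture '{name}' already registered")
        _REGISTRY[name] = builder
        return builder  # leave the function usable as-is
    return decorator

def get_architecture_builder(name):
    if name not in _REGISTRY:
        raise ValueError(f"Architecture '{name}' not found")
    return _REGISTRY[name]

def unregister_architecture(name):
    if name not in _REGISTRY:
        raise ValueError(f"Architecture '{name}' not found")
    del _REGISTRY[name]

@register_architecture("my_model")
def build_my_model(cfg):
    # A real builder would instantiate a model from cfg
    return ("MyModel", cfg)
```

Unregistering in a test teardown keeps the global registry clean between test cases, which is why the library exposes it.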
Loss Functions¶
MONAI-native loss functions for PyTorch Connectomics.
This module provides a clean interface for loss function creation using MONAI’s native implementations, with additional connectomics-specific losses.
Design pattern follows the same package structure used across data processing and augmentation.
- class connectomics.models.losses.BinaryRegularization(min_threshold=0.01, apply_sigmoid=True)[source]¶
Regularization encouraging outputs to be binary (close to 0 or 1).
Penalizes predictions that are close to 0.5 (uncertain).
- Parameters
min_threshold (float) – Penalty values below this threshold are ignored (default: 0.01)
apply_sigmoid (bool) – Whether to apply sigmoid to predictions before computing the penalty (default: True)
Example
>>> reg = BinaryRegularization()
>>> pred = torch.sigmoid(torch.randn(1, 1, 64, 64, 64))
>>> loss = reg(pred)
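A plausible sketch of such a penalty, assuming the common p * (1 - p) form that vanishes at 0 and 1 and peaks at 0.5 (the actual BinaryRegularization implementation may differ):

```python
import torch

def binary_regularization(pred, min_threshold=0.01, apply_sigmoid=True):
    """Penalize uncertain predictions: p * (1 - p) is 0 at {0, 1}, max at 0.5."""
    p = torch.sigmoid(pred) if apply_sigmoid else pred
    penalty = p * (1.0 - p)
    # Ignore voxels that are already confident (penalty below threshold)
    penalty = torch.where(penalty > min_threshold, penalty,
                          torch.zeros_like(penalty))
    return penalty.mean()

certain = torch.full((1, 1, 4, 4, 4), 8.0)  # sigmoid ~ 1.0, near-zero penalty
uncertain = torch.zeros(1, 1, 4, 4, 4)      # sigmoid = 0.5, maximal penalty
```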
- class connectomics.models.losses.ContourDistanceConsistency(*args, **kwargs)[source]¶
Consistency regularization between instance contour map and signed distance transform.
Encourages contour predictions (high at boundaries) to be consistent with distance transform predictions (low magnitude at boundaries).
Example
>>> reg = ContourDistanceConsistency()
>>> contour_logits = torch.randn(1, 1, 64, 64, 64)
>>> dt_pred = torch.randn(1, 1, 64, 64, 64)
>>> loss = reg(contour_logits, dt_pred)
- class connectomics.models.losses.ForegroundContourConsistency(kernel_half_size=1, eps=1e-07)[source]¶
Consistency regularization between binary foreground and instance contour maps.
Encourages contour predictions to align with foreground edges detected via Sobel filters.
- Parameters
kernel_half_size (int) – Half size of the edge-detection kernel window (default: 1)
eps (float) – Small constant for numerical stability (default: 1e-07)
Example
>>> reg = ForegroundContourConsistency()
>>> fg_logits = torch.randn(1, 1, 64, 64, 64)
>>> contour_logits = torch.randn(1, 1, 64, 64, 64)
>>> loss = reg(fg_logits, contour_logits)
- class connectomics.models.losses.ForegroundDistanceConsistency(*args, **kwargs)[source]¶
Consistency regularization between binary foreground mask and signed distance transform.
Encourages foreground predictions to be consistent with distance transform predictions.
Example
>>> reg = ForegroundDistanceConsistency()
>>> fg_logits = torch.randn(1, 1, 64, 64, 64)
>>> dt_pred = torch.randn(1, 1, 64, 64, 64)
>>> loss = reg(fg_logits, dt_pred)
- class connectomics.models.losses.GANLoss(gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0)[source]¶
GAN loss for adversarial training.
Supports vanilla, LSGAN, and WGAN-GP objectives. Based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
- Parameters
gan_mode (str) – GAN objective: vanilla, LSGAN, or WGAN-GP (default: 'lsgan')
target_real_label (float) – Label value for real images (default: 1.0)
target_fake_label (float) – Label value for fake images (default: 0.0)
Note
Do not use sigmoid as the last layer of the discriminator. LSGAN needs no sigmoid, and vanilla GANs handle it internally via BCEWithLogitsLoss.
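The three objectives can be sketched as a standalone function (hedged: the actual GANLoss class may structure this differently; the mode strings here follow the referenced CycleGAN codebase):

```python
import torch
import torch.nn as nn

def gan_loss(logits, target_is_real, gan_mode="lsgan",
             real_label=1.0, fake_label=0.0):
    """Compute a GAN loss on raw discriminator logits (no sigmoid applied)."""
    if gan_mode == "wgangp":
        # WGAN-GP critic: maximize output for real, minimize for fake
        return -logits.mean() if target_is_real else logits.mean()
    target = torch.full_like(logits,
                             real_label if target_is_real else fake_label)
    if gan_mode == "lsgan":
        # Least-squares GAN: MSE against the label, no sigmoid needed
        return nn.functional.mse_loss(logits, target)
    # Vanilla GAN: BCE with logits handles the sigmoid internally
    return nn.functional.binary_cross_entropy_with_logits(logits, target)
```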
- class connectomics.models.losses.LossMetadata(name, call_kind='pred_target', target_kind='dense', spatial_weight_arg=None)[source]¶
Static metadata describing how to invoke a loss module.
- class connectomics.models.losses.NonOverlapRegularization(cleft_masked=True)[source]¶
Regularization preventing overlapping predictions.
Specifically designed for synaptic polarity prediction where pre- and post-synaptic masks should not overlap. Optionally masks the regularization by the cleft prediction.
- Parameters
cleft_masked (bool) – Whether to mask regularization by cleft prediction (default: True)
Example
>>> reg = NonOverlapRegularization()
>>> # pred has shape (B, 3, Z, Y, X) with channels: [pre, post, cleft]
>>> pred = torch.randn(2, 3, 32, 64, 64)
>>> loss = reg(pred)
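A sketch of such a penalty (hypothetical; the real class may differ), multiplying pre- and post-synaptic probabilities so that only voxels where both masks fire are penalized:

```python
import torch

def non_overlap_regularization(pred, cleft_masked=True):
    """Penalize voxels where pre- and post-synaptic probabilities overlap."""
    # pred: (B, 3, Z, Y, X) logits with channels [pre, post, cleft]
    pre, post, cleft = torch.sigmoid(pred).unbind(dim=1)
    overlap = pre * post  # high only where both masks are confident
    if cleft_masked:
        overlap = overlap * cleft  # restrict penalty to predicted cleft regions
    return overlap.mean()

# Both channels confident everywhere -> large penalty
both_high = torch.full((1, 3, 2, 2, 2), 8.0)
# Pre confident, post suppressed -> near-zero penalty
disjoint = torch.cat([torch.full((1, 1, 2, 2, 2), 8.0),
                      torch.full((1, 1, 2, 2, 2), -8.0),
                      torch.full((1, 1, 2, 2, 2), 8.0)], dim=1)
```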
- class connectomics.models.losses.WeightedMAELoss(reduction='mean', tanh=False)[source]¶
Weighted mean absolute error loss.
Useful for regression tasks with spatial importance weighting. Supports optional tanh activation for distance transform predictions.
- Parameters
reduction (str) – Loss reduction mode (default: 'mean')
tanh (bool) – Whether to apply tanh to predictions, useful for distance transforms (default: False)
- class connectomics.models.losses.WeightedMSELoss(reduction='mean', tanh=False)[source]¶
Weighted mean-squared error loss.
Useful for regression tasks with spatial importance weighting. Supports optional tanh activation for distance transform predictions.
- Parameters
reduction (str) – Loss reduction mode (default: 'mean')
tanh (bool) – Whether to apply tanh to predictions, useful for distance transforms (default: False)
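A weighted MSE can be sketched as below (an illustrative standalone function; the actual WeightedMSELoss may handle reduction and weighting differently):

```python
import torch

def weighted_mse_loss(pred, target, weight=None, tanh=False, reduction="mean"):
    """Squared error, optionally scaled by a per-voxel importance weight."""
    if tanh:
        pred = torch.tanh(pred)  # squash to [-1, 1] for distance transforms
    err = (pred - target) ** 2
    if weight is not None:
        err = err * weight  # emphasize important regions, e.g. boundaries
    return err.mean() if reduction == "mean" else err.sum()

pred = torch.ones(2, 2)
target = torch.zeros(2, 2)
```

The same pattern with `err.abs()` in place of the square gives the weighted MAE variant.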
- connectomics.models.losses.create_loss(loss_name, **kwargs)[source]¶
Create a single loss function by name.
- Parameters
loss_name (str) – Name of the loss function
**kwargs – Loss-specific parameters
- Returns
Initialized loss function
- Return type
nn.Module
Examples
>>> loss = create_loss('DiceLoss', include_background=False)
>>> loss = create_loss('DiceCELoss', to_onehot_y=True, softmax=True)
>>> loss = create_loss('FocalLoss', gamma=2.0)
- connectomics.models.losses.get_loss_metadata(loss_name)[source]¶
Return registered metadata for a known loss name.
- Parameters
loss_name (str) – Name of the loss function
- Return type
LossMetadata