connectomics.models

Model Builder

Modern model builder using architecture registry.

Uses MONAI and MedNeXt native models with automatic configuration. All models are registered in the architecture registry.

This module is intentionally model-only:
  • selects an architecture builder

  • instantiates the model

  • logs model structure info

It does not load checkpoints or move the model to a device.
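Checkpoint loading and device placement therefore stay with the caller. A minimal sketch (the checkpoint path and the 'model' state-dict key are illustrative assumptions, not library defaults):

import torch

model = build_model(cfg)                 # construction only
model = model.to('cuda')                 # device placement is the caller's job
state = torch.load('checkpoint.pt')      # illustrative path
model.load_state_dict(state['model'])    # 'model' key is an assumption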

connectomics.models.build.build_model(cfg)[source]

Build model from configuration using architecture registry.

Parameters

cfg – Hydra config object with model configuration

Returns

Instantiated model (left on default device)

Available architectures:
  • MONAI models: monai_basic_unet3d, monai_unet, monai_unetr, monai_swin_unetr

  • MedNeXt models: mednext, mednext_custom

Example

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'model': {
        'arch': {'type': 'mednext'},
        'in_channels': 1,
        'out_channels': 2,
        'mednext': {'size': 'S', 'kernel_size': 3},
        'deep_supervision': True,
    }
})
model = build_model(cfg)

To see all available architectures:

from connectomics.models.architectures import print_available_architectures
print_available_architectures()

Architecture Registry

Architecture module for connectomics models.

Provides:
  • Registry system for architecture management

  • Base model interface with deep supervision support

  • MONAI model wrappers (BasicUNet, UNet, UNETR, SwinUNETR)

  • MedNeXt model wrappers

Usage:

from connectomics.models.architectures import (
    register_architecture,
    list_architectures,
    ConnectomicsModel,
)

# List available models
print(list_architectures())

# Register custom model
@register_architecture('my_model')
def build_my_model(cfg):
    return MyModel(cfg)

class connectomics.models.architectures.ConnectomicsModel[source]

Base class for all connectomics models.

Provides common interface for: - Forward pass (single or multi-scale outputs) - Deep supervision support - Model information

All models in the architectures module should inherit from this class or at least implement the same interface.
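For example, a minimal single-scale subclass might look like this (a sketch; SimpleSeg3D and its layers are hypothetical, not part of the library):

import torch.nn as nn

class SimpleSeg3D(ConnectomicsModel):
    """Toy single-scale model implementing the ConnectomicsModel interface."""

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, num_classes, kernel_size=1),
        )

    def forward(self, x):
        # Single-scale: return a plain tensor of shape (B, num_classes, D, H, W)
        return self.net(x)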

abstract forward(x)[source]

Forward pass through the model.

Parameters

x (Tensor) – Input tensor of shape (B, C, D, H, W) for 3D or (B, C, H, W) for 2D

Returns

For single-scale models:
torch.Tensor: Output tensor of shape (B, num_classes, D, H, W)

For deep supervision models:
Dict[str, torch.Tensor]: Dictionary with keys:
  • 'output': Main output (full resolution)

  • 'ds_1': Deep supervision output at scale 1 (1/2 resolution)

  • 'ds_2': Deep supervision output at scale 2 (1/4 resolution)

  • 'ds_3': Deep supervision output at scale 3 (1/8 resolution)

  • 'ds_4': Deep supervision output at scale 4 (1/16 resolution)

Return type

Union[torch.Tensor, Dict[str, torch.Tensor]]

get_model_info()[source]

Get model information and statistics.

Returns

Dictionary with model metadata:
  • name: Model class name

  • deep_supervision: Whether the model supports deep supervision

  • output_scales: Number of output scales

  • parameters: Total number of parameters

  • trainable_parameters: Number of trainable parameters

Return type

dict
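Typical usage (assuming a model built via build_model; the keys are those listed above):

info = model.get_model_info()
print(info['name'], info['parameters'], info['deep_supervision'])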

connectomics.models.architectures.get_architecture_builder(name)[source]

Get builder function for architecture.

Parameters

name (str) – Architecture name

Returns

Builder function that takes cfg and returns a model

Raises

ValueError – If architecture not found

Return type

Callable
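Example, using one of the MONAI names registered above:

builder = get_architecture_builder('monai_unet')
model = builder(cfg)   # cfg as accepted by build_model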

connectomics.models.architectures.get_architecture_info()[source]

Get information about all registered architectures.

Returns

Dict mapping architecture names to their metadata

Return type

Dict[str, Dict[str, str]]
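Example:

for name, meta in get_architecture_info().items():
    print(f'{name}: {meta}')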

connectomics.models.architectures.get_available_architectures()[source]

Get information about available architectures and their dependencies.

Returns

Dictionary with:
  • 'monai': List of MONAI architectures (if available)

  • 'mednext': List of MedNeXt architectures (if available)

  • 'all': List of all registered architectures

Return type

Dict[str, List[str]]
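Example (the 'monai' and 'mednext' keys may be absent when the corresponding dependency is not installed):

avail = get_available_architectures()
print(avail['all'])
for group in ('monai', 'mednext'):
    if group in avail:
        print(group, avail[group])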

connectomics.models.architectures.is_architecture_available(name)[source]

Check if architecture is available.

Parameters

name (str) – Architecture name

Returns

True if architecture is registered

Return type

bool
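Example guard before building:

if is_architecture_available('mednext'):
    model = get_architecture_builder('mednext')(cfg)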

connectomics.models.architectures.list_architectures()[source]

List all registered architectures.

Returns

Sorted list of architecture names

Return type

List[str]

connectomics.models.architectures.print_available_architectures()[source]

Log a formatted list of available architectures.

connectomics.models.architectures.register_architecture(name)[source]

Decorator to register architecture builders.

Example

@register_architecture('my_model')
def build_my_model(cfg):
    return MyModel(...)

Parameters

name (str) – Unique name for the architecture

Returns

Decorator function

connectomics.models.architectures.unregister_architecture(name)[source]

Unregister an architecture (useful for testing).

Parameters

name (str) – Architecture name

Raises

ValueError – If architecture not found

Return type

None
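A typical register/unregister pairing in a test (DummyModel is a hypothetical stand-in):

@register_architecture('dummy_model')
def build_dummy_model(cfg):
    return DummyModel(cfg)

# ... exercise the registry ...
unregister_architecture('dummy_model')   # restore the registry state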

Loss Functions

MONAI-native loss functions for PyTorch Connectomics.

This module provides a clean interface for loss function creation using MONAI’s native implementations, with additional connectomics-specific losses.

Design pattern follows the same package structure used across data processing and augmentation.

class connectomics.models.losses.BinaryRegularization(min_threshold=0.01, apply_sigmoid=True)[source]

Regularization encouraging outputs to be binary (close to 0 or 1).

Penalizes predictions that are close to 0.5 (uncertain).

Parameters
  • min_threshold (float) – Minimum threshold for clamping (default: 0.01)

  • apply_sigmoid (bool) – Whether to apply sigmoid to predictions before computing the regularization (default: True)

Example

>>> reg = BinaryRegularization()
>>> pred = torch.sigmoid(torch.randn(1, 1, 64, 64, 64))
>>> loss = reg(pred)

forward(pred, mask=None)[source]

Compute binary regularization loss.

Parameters
  • pred (Tensor) – Predictions (logits or probabilities)

  • mask (Optional[Tensor]) – Optional spatial weight mask

Returns

Regularization loss

Return type

Tensor

class connectomics.models.losses.ContourDistanceConsistency(*args, **kwargs)[source]

Consistency regularization between instance contour map and signed distance transform.

Encourages contour predictions (high at boundaries) to be consistent with distance transform predictions (low magnitude at boundaries).

Example

>>> reg = ContourDistanceConsistency()
>>> contour_logits = torch.randn(1, 1, 64, 64, 64)
>>> dt_pred = torch.randn(1, 1, 64, 64, 64)
>>> loss = reg(contour_logits, dt_pred)

forward(contour_logits, distance_transform, mask=None)[source]

Compute consistency loss between contour and distance transform.

Parameters
  • contour_logits (Tensor) – Instance contour logits

  • distance_transform (Tensor) – Signed distance transform predictions

  • mask (Optional[Tensor]) – Optional spatial weight mask

Returns

Consistency loss

Return type

Tensor

class connectomics.models.losses.ForegroundContourConsistency(kernel_half_size=1, eps=1e-07)[source]

Consistency regularization between binary foreground and instance contour maps.

Encourages contour predictions to align with foreground edges detected via Sobel filters.

Parameters
  • kernel_half_size (int) – Half-size of edge detection kernel (default: 1)

  • eps (float) – Small epsilon for numerical stability (default: 1e-7)

Example

>>> reg = ForegroundContourConsistency()
>>> fg_logits = torch.randn(1, 1, 64, 64, 64)
>>> contour_logits = torch.randn(1, 1, 64, 64, 64)
>>> loss = reg(fg_logits, contour_logits)

forward(foreground_logits, contour_logits, mask=None)[source]

Compute consistency loss between foreground edges and contours.

Parameters
  • foreground_logits (Tensor) – Binary foreground logits

  • contour_logits (Tensor) – Instance contour logits

  • mask (Optional[Tensor]) – Optional spatial weight mask

Returns

Consistency loss

Return type

Tensor

class connectomics.models.losses.ForegroundDistanceConsistency(*args, **kwargs)[source]

Consistency regularization between binary foreground mask and signed distance transform.

Encourages foreground predictions to be consistent with distance transform predictions.

Example

>>> reg = ForegroundDistanceConsistency()
>>> fg_logits = torch.randn(1, 1, 64, 64, 64)
>>> dt_pred = torch.randn(1, 1, 64, 64, 64)
>>> loss = reg(fg_logits, dt_pred)

forward(foreground_logits, distance_transform, mask=None)[source]

Compute consistency loss between foreground and distance transform.

Parameters
  • foreground_logits (Tensor) – Binary foreground logits

  • distance_transform (Tensor) – Signed distance transform predictions

  • mask (Optional[Tensor]) – Optional spatial weight mask

Returns

Consistency loss

Return type

Tensor

class connectomics.models.losses.GANLoss(gan_mode='lsgan', target_real_label=1.0, target_fake_label=0.0)[source]

GAN loss for adversarial training.

Supports vanilla, LSGAN, and WGAN-GP objectives. Based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

Parameters
  • gan_mode (str) – GAN objective type ('vanilla', 'lsgan', 'wgangp')

  • target_real_label (float) – Label for real images

  • target_fake_label (float) – Label for fake images

Note

Do not use a sigmoid as the last layer of the discriminator. LSGAN needs no sigmoid, and the vanilla mode handles it internally via BCEWithLogitsLoss.
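Example

A usage sketch in the style of the other loss examples (shapes are illustrative):

>>> criterion = GANLoss(gan_mode='lsgan')
>>> d_out = torch.randn(4, 1, 30, 30)   # discriminator logits
>>> loss_real = criterion(d_out, True)
>>> loss_fake = criterion(d_out.detach(), False)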

forward(prediction, target_is_real)[source]

Calculate GAN loss.

Parameters
  • prediction (Tensor) – Discriminator output

  • target_is_real (bool) – Whether ground truth labels are for real or fake images

Returns

Calculated loss

Return type

Tensor

get_target_tensor(prediction, target_is_real)[source]

Create label tensors with the same size as the input.

Parameters
  • prediction (Tensor) – Discriminator prediction

  • target_is_real (bool) – Whether the ground truth is real or fake

Returns

Label tensor filled with ground truth labels

Return type

Tensor

class connectomics.models.losses.LossMetadata(name, call_kind='pred_target', target_kind='dense', spatial_weight_arg=None)[source]

Static metadata describing how to invoke a loss module.

Parameters
  • name (str) –

  • call_kind (Literal['pred_target', 'pred_only', 'pred_pred', 'unsupported']) –

  • target_kind (Literal['dense', 'class_index', 'none']) –

  • spatial_weight_arg (Optional[str]) –
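Example

An illustrative instance describing a loss called as loss_fn(pred, target) with an optional spatial weight argument (field values drawn from the Literal choices above):

>>> meta = LossMetadata(
...     name='WeightedMSELoss',
...     call_kind='pred_target',
...     target_kind='dense',
...     spatial_weight_arg='weight',
... )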

class connectomics.models.losses.NonOverlapRegularization(cleft_masked=True)[source]

Regularization preventing overlapping predictions.

Specifically designed for synaptic polarity prediction where pre- and post-synaptic masks should not overlap. Optionally masks the regularization by the cleft prediction.

Parameters

cleft_masked (bool) – Whether to mask regularization by cleft prediction (default: True)

Example

>>> reg = NonOverlapRegularization()
>>> # pred has shape (B, 3, Z, Y, X) with channels: [pre, post, cleft]
>>> pred = torch.randn(2, 3, 32, 64, 64)
>>> loss = reg(pred)

forward(pred)[source]

Compute non-overlap regularization loss.

Parameters

pred (Tensor) – Predictions with shape (B, C, Z, Y, X) where:
  • Channel 0: Pre-synaptic logits

  • Channel 1: Post-synaptic logits

  • Channel 2: Cleft logits (optional, used if cleft_masked=True)

Returns

Non-overlap regularization loss

Return type

Tensor

class connectomics.models.losses.WeightedMAELoss(reduction='mean', tanh=False)[source]

Weighted mean absolute error loss.

Useful for regression tasks with spatial importance weighting. Supports optional tanh activation for distance transform predictions.

Parameters
  • reduction (str) – Reduction method ('mean', 'sum', 'none')

  • tanh (bool) – If True, apply tanh activation to predictions before computing loss. Useful for distance transform targets in range [-1, 1].
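Example

Usage sketch (shapes are illustrative; the target emulates a signed distance transform in [-1, 1]):

>>> loss_fn = WeightedMAELoss(tanh=True)
>>> pred = torch.randn(1, 1, 32, 64, 64)            # raw logits
>>> target = torch.rand(1, 1, 32, 64, 64) * 2 - 1   # SDT-like target in [-1, 1]
>>> loss = loss_fn(pred, target)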

forward(pred, target, weight=None)[source]

Compute weighted MAE loss.

Parameters
  • pred (Tensor) – Predictions (logits if tanh=True, otherwise predictions)

  • target (Tensor) – Ground truth

  • weight (Optional[Tensor]) – Optional spatial weights

Returns

Loss value

Return type

Tensor

class connectomics.models.losses.WeightedMSELoss(reduction='mean', tanh=False)[source]

Weighted mean-squared error loss.

Useful for regression tasks with spatial importance weighting. Supports optional tanh activation for distance transform predictions.

Parameters
  • reduction (str) – Reduction method ('mean', 'sum', 'none')

  • tanh (bool) – If True, apply tanh activation to predictions before computing loss. Useful for distance transform targets in range [-1, 1]. With both pred and target in [-1, 1], MSE should be < 4.
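Example

Usage sketch with an explicit spatial weight map (shapes are illustrative):

>>> loss_fn = WeightedMSELoss(tanh=True)
>>> pred = torch.randn(1, 1, 32, 64, 64)
>>> target = torch.rand(1, 1, 32, 64, 64) * 2 - 1
>>> weight = torch.ones_like(target)   # uniform spatial weights
>>> loss = loss_fn(pred, target, weight)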

forward(pred, target, weight=None)[source]

Compute weighted MSE loss.

Parameters
  • pred (Tensor) – Predictions (logits if tanh=True, otherwise predictions)

  • target (Tensor) – Ground truth (range [-1, 1] for SDT)

  • weight (Optional[Tensor]) – Optional spatial weights

Returns

Loss value (should be < 4 for range [-1, 1])

Return type

Tensor

connectomics.models.losses.create_loss(loss_name, **kwargs)[source]

Create a single loss function by name.

Parameters
  • loss_name (str) – Name of the loss function

  • **kwargs – Loss-specific parameters

Returns

Initialized loss function

Return type

Module

Examples

>>> loss = create_loss('DiceLoss', include_background=False)
>>> loss = create_loss('DiceCELoss', to_onehot_y=True, softmax=True)
>>> loss = create_loss('FocalLoss', gamma=2.0)

connectomics.models.losses.get_loss_metadata(loss_name)[source]

Return registered metadata for a known loss name.

Parameters

loss_name (str) – Name of the loss function

Return type

LossMetadata

connectomics.models.losses.get_loss_metadata_for_module(loss_fn)[source]

Fetch metadata from an annotated module, or infer a safe fallback.

Parameters

loss_fn (Module) – Loss module to inspect

Return type

LossMetadata
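Together, these lookups support metadata-driven invocation. A dispatch sketch over the documented call_kind values (pred, target, and loss_fn are assumed to exist):

meta = get_loss_metadata_for_module(loss_fn)
if meta.call_kind == 'pred_target':
    loss = loss_fn(pred, target)
elif meta.call_kind == 'pred_only':
    loss = loss_fn(pred)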

connectomics.models.losses.list_available_losses()[source]

List all available loss functions.

Return type

List[str]
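Example

A quick look at the registry (output is environment-dependent):

>>> losses = list_available_losses()
>>> print(losses[:5])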