# Source code for connectomics.models.losses.build

"""
MONAI-native loss functions for PyTorch Connectomics.

This module exposes a name-based loss factory (``create_loss``) over MONAI's
native losses, standard PyTorch losses, and connectomics-specific task and
regularization losses.

Design pattern aligned with the rest of the normalized package layout.
"""

from __future__ import annotations

from typing import Dict, List

import torch.nn as nn

# Import MONAI losses
from monai.losses import (
    DiceCELoss,
    DiceFocalLoss,
    DiceLoss,
    FocalLoss,
    GeneralizedDiceLoss,
    TverskyLoss,
)

# Import custom connectomics losses
from .losses import (
    CrossEntropyLossWrapper,
    GANLoss,
    PerChannelBCEWithLogitsLoss,
    SmoothL1Loss,
    WeightedBCEWithLogitsLoss,
    WeightedMAELoss,
    WeightedMSELoss,
)
from .metadata import (
    LossMetadata,
    attach_loss_metadata,
    get_loss_metadata,
    get_loss_metadata_for_module,
)

# Import regularization losses
from .regularization import (
    BinaryRegularization,
    ContourDistanceConsistency,
    ForegroundContourConsistency,
    ForegroundDistanceConsistency,
    NonOverlapRegularization,
)


def _get_loss_registry() -> Dict[str, type[nn.Module]]:
    """Build the lookup table that maps loss names to loss classes.

    A new dictionary is constructed on every call, so callers may mutate the
    result without affecting later lookups. Insertion order is significant:
    it is the order reported by ``list_available_losses``.

    Returns:
        Mapping from the public loss name to the class (or wrapper) that
        ``create_loss`` instantiates for it.
    """
    table: Dict[str, type[nn.Module]] = {}

    # Losses provided by MONAI: the Dice family first, then the others.
    table.update(
        DiceLoss=DiceLoss,
        DiceCELoss=DiceCELoss,
        DiceFocalLoss=DiceFocalLoss,
        GeneralizedDiceLoss=GeneralizedDiceLoss,
        FocalLoss=FocalLoss,
        TverskyLoss=TverskyLoss,
    )

    # Standard PyTorch losses, exposed under their familiar names for
    # convenience. "CrossEntropyLoss" deliberately resolves to the project
    # wrapper, which the original code uses for shape handling.
    table.update(
        BCEWithLogitsLoss=nn.BCEWithLogitsLoss,
        CrossEntropyLoss=CrossEntropyLossWrapper,
        MSELoss=nn.MSELoss,
        L1Loss=nn.L1Loss,
        SmoothL1Loss=SmoothL1Loss,
    )

    # Connectomics-specific task losses defined in this package.
    table.update(
        WeightedBCEWithLogitsLoss=WeightedBCEWithLogitsLoss,
        PerChannelBCEWithLogitsLoss=PerChannelBCEWithLogitsLoss,
        WeightedMSELoss=WeightedMSELoss,
        WeightedMAELoss=WeightedMAELoss,
        GANLoss=GANLoss,
    )

    # Regularization terms defined in this package.
    table.update(
        BinaryRegularization=BinaryRegularization,
        ForegroundDistanceConsistency=ForegroundDistanceConsistency,
        ContourDistanceConsistency=ContourDistanceConsistency,
        ForegroundContourConsistency=ForegroundContourConsistency,
        NonOverlapRegularization=NonOverlapRegularization,
    )

    return table


def create_loss(loss_name: str, **kwargs) -> nn.Module:
    """
    Build a loss module from its registered name.

    The name is looked up in the table returned by ``_get_loss_registry``.
    The matching class is constructed with ``**kwargs``, and the instance is
    passed to ``attach_loss_metadata`` together with ``loss_name``.

    Args:
        loss_name: Registered name of the loss (see ``list_available_losses``).
        **kwargs: Keyword arguments forwarded unchanged to the loss constructor.

    Returns:
        Whatever ``attach_loss_metadata`` returns for the constructed loss
        (presumably the same module tagged with metadata; confirm in the
        metadata module).

    Raises:
        ValueError: If ``loss_name`` is not a registered name. The message
            lists every available name.
        Any exception raised by the loss constructor itself (for example,
        for unexpected keyword arguments) propagates unchanged.

    Examples:
        >>> loss = create_loss('DiceLoss', include_background=False)
        >>> loss = create_loss('DiceCELoss', to_onehot_y=True, softmax=True)
        >>> loss = create_loss('FocalLoss', gamma=2.0)
    """
    registry = _get_loss_registry()
    # Registry values are classes and never None, so .get() plus a None check
    # is equivalent to a membership test followed by indexing.
    loss_cls = registry.get(loss_name)
    if loss_cls is None:
        # List every valid name so that typos in configs are easy to diagnose.
        available = list(registry)
        raise ValueError(f"Unknown loss: {loss_name}. Available losses: {available}")

    instance = loss_cls(**kwargs)
    return attach_loss_metadata(instance, loss_name)
def list_available_losses() -> List[str]:
    """Return every name accepted by ``create_loss``, in registry order."""
    names = [*_get_loss_registry()]
    return names
# Public API of this module: the factory helpers plus the re-exported metadata
# accessors. Individual loss classes are not exported here; they are reached
# by name through ``create_loss``.
__all__ = [
    "create_loss",
    "list_available_losses",
    "LossMetadata",
    "get_loss_metadata",
    "get_loss_metadata_for_module",
]