# Source code for connectomics.models.losses.metadata
"""Loss metadata describing how PyTorch loss modules are invoked."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, Optional
import torch.nn as nn
# How a loss module is invoked by the training code. "pred_target" pairs a
# prediction with a target; "pred_only" and "pred_pred" are used by the
# regularization losses (presumably one prediction vs. two predictions, judging
# by the names); "unsupported" marks losses that cannot go through the generic
# supervised path.
LossCallKind = Literal[
    "pred_target",
    "pred_only",
    "pred_pred",
    "unsupported",
]

# Kind of target a loss expects: "class_index" is used for CrossEntropyLoss,
# "none" for losses that take no target, and "dense" otherwise.
TargetKind = Literal[
    "dense",
    "class_index",
    "none",
]
_LOSS_METADATA_BY_NAME = {
# Standard supervised segmentation losses (dense targets unless noted)
"DiceLoss": LossMetadata("DiceLoss"),
"DiceCELoss": LossMetadata("DiceCELoss"),
"DiceFocalLoss": LossMetadata("DiceFocalLoss"),
"GeneralizedDiceLoss": LossMetadata("GeneralizedDiceLoss"),
"FocalLoss": LossMetadata("FocalLoss"),
"TverskyLoss": LossMetadata("TverskyLoss"),
"BCEWithLogitsLoss": LossMetadata("BCEWithLogitsLoss"),
"CrossEntropyLoss": LossMetadata("CrossEntropyLoss", target_kind="class_index"),
"MSELoss": LossMetadata("MSELoss"),
"L1Loss": LossMetadata("L1Loss"),
# Custom supervised losses
"SmoothL1Loss": LossMetadata("SmoothL1Loss", spatial_weight_arg="weight"),
"WeightedBCEWithLogitsLoss": LossMetadata(
"WeightedBCEWithLogitsLoss", spatial_weight_arg="weight"
),
"PerChannelBCEWithLogitsLoss": LossMetadata(
"PerChannelBCEWithLogitsLoss", spatial_weight_arg="weight"
),
"WeightedMSELoss": LossMetadata("WeightedMSELoss", spatial_weight_arg="weight"),
"WeightedMAELoss": LossMetadata("WeightedMAELoss", spatial_weight_arg="weight"),
# GAN is not compatible with the generic supervised orchestrator path
"GANLoss": LossMetadata("GANLoss", call_kind="unsupported", target_kind="none"),
# Regularization losses
"BinaryRegularization": LossMetadata(
"BinaryRegularization", call_kind="pred_only", target_kind="none", spatial_weight_arg="mask"
),
"ForegroundDistanceConsistency": LossMetadata(
"ForegroundDistanceConsistency",
call_kind="pred_pred",
target_kind="none",
spatial_weight_arg="mask",
),
"ContourDistanceConsistency": LossMetadata(
"ContourDistanceConsistency",
call_kind="pred_pred",
target_kind="none",
spatial_weight_arg="mask",
),
"ForegroundContourConsistency": LossMetadata(
"ForegroundContourConsistency",
call_kind="pred_pred",
target_kind="none",
spatial_weight_arg="mask",
),
"NonOverlapRegularization": LossMetadata(
"NonOverlapRegularization", call_kind="pred_only", target_kind="none"
),
# Class name alias (CrossEntropyLossWrapper -> same metadata as CrossEntropyLoss)
"CrossEntropyLossWrapper": LossMetadata("CrossEntropyLoss", target_kind="class_index"),
}
def attach_loss_metadata(loss_fn: nn.Module, loss_name: str) -> nn.Module:
    """Tag a loss module with the metadata registered under ``loss_name``.

    The object returned by ``get_loss_metadata(loss_name)`` is stored on the
    module instance as ``_connectomics_loss_metadata`` so that downstream
    dispatch code can read it back later.

    Args:
        loss_fn: Loss module instance to annotate; it is modified in place.
        loss_name: Name under which the loss's metadata is registered.

    Returns:
        The same ``loss_fn`` object, which allows the call to be chained.

    Note:
        What happens for an unregistered ``loss_name`` is decided by
        ``get_loss_metadata``, which is not shown here.
    """
    metadata = get_loss_metadata(loss_name)
    # Plain assignment goes through nn.Module.__setattr__, exactly like setattr().
    loss_fn._connectomics_loss_metadata = metadata
    return loss_fn
# Public API of this module; everything else is an implementation detail.
__all__ = [
    # Literal type aliases
    "LossCallKind",
    "TargetKind",
    # Metadata record type
    "LossMetadata",
    # Registration / lookup helpers
    "attach_loss_metadata",
    "get_loss_metadata",
    "get_loss_metadata_for_module",
]