"""
Connectomics-specific loss functions.
Custom losses that provide functionality not available in MONAI.
Only includes losses that are truly unique to connectomics use cases.
"""
from __future__ import annotations
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
def _reduce_weighted_tensor(
loss_tensor: torch.Tensor,
weight: torch.Tensor | None,
reduction: str,
) -> torch.Tensor:
"""Reduce weighted loss values while excluding invalid (weight<=0) voxels for mean."""
if reduction == "none":
return loss_tensor
if reduction == "sum":
return loss_tensor.sum()
# reduction == "mean"
if weight is None:
return loss_tensor.mean()
valid = weight > 0
if valid.shape != loss_tensor.shape:
try:
valid = torch.broadcast_to(valid, loss_tensor.shape)
except RuntimeError as e:
raise ValueError(
"Weight mask shape is not broadcastable to loss tensor shape: "
f"weight={tuple(weight.shape)}, loss={tuple(loss_tensor.shape)}"
) from e
if not torch.any(valid):
return loss_tensor.new_tensor(0.0)
return loss_tensor[valid].mean()
class CrossEntropyLossWrapper(nn.Module):
    """
    Wrapper for ``nn.CrossEntropyLoss`` that handles label shape conversion.

    Labels may carry a singleton channel dimension (``[B, 1, H, W]`` or
    ``[B, 1, D, H, W]``). It is removed so the target matches the
    ``[B, H, W]`` / ``[B, D, H, W]`` class-index layout PyTorch expects.

    Args:
        weight: Optional per-class weights, forwarded to ``nn.CrossEntropyLoss``.
        ignore_index: Target value to ignore.
        reduction: Reduction method ('mean', 'sum', 'none').
        label_smoothing: Label smoothing factor (0.0 = no smoothing, 0.1 = 10% smoothing).
    """

    def __init__(self, weight=None, ignore_index=-100, reduction="mean", label_smoothing=0.0):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=ignore_index,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute cross-entropy loss with automatic target shape handling.

        Args:
            input: Model output (logits)
                2D: [B, C, H, W]
                3D: [B, C, D, H, W]
            target: Ground-truth class indices
                2D: [B, 1, H, W] or [B, H, W]
                3D: [B, 1, D, H, W] or [B, D, H, W]

        Returns:
            Loss value
        """
        # Drop dim 1 only when the target has the same rank as the logits,
        # i.e. dim 1 really is a singleton channel axis. Checking the target's
        # rank alone would mistake a 3D target [B, D, H, W] with D == 1 for a
        # 2D target with a channel axis and squeeze away the depth dimension.
        if target.dim() >= 2 and target.dim() == input.dim() and target.shape[1] == 1:
            target = target.squeeze(1)
        # Class-index targets must be int64 for CrossEntropyLoss.
        target = target.long()
        return self.ce_loss(input, target)
class WeightedMSELoss(nn.Module):
    """
    Squared-error regression loss with optional spatial weighting.

    Every element contributes ``(pred - target) ** 2``, optionally scaled by a
    per-element ``weight``. With ``tanh=True`` the raw predictions are first
    passed through ``torch.tanh``, keeping them in [-1, 1]. This suits
    distance-transform targets in that range, where the per-element squared
    error cannot exceed 4.

    Args:
        reduction: Reduction method ('mean', 'sum', 'none').
        tanh: If True, apply tanh to predictions before computing the error.
    """

    def __init__(self, reduction: str = "mean", tanh: bool = False):
        super().__init__()
        self.tanh = tanh
        self.reduction = reduction

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Compute the (optionally weighted) MSE loss.

        Args:
            pred: Predictions; treated as logits when ``tanh=True``.
            target: Regression target (e.g. an SDT in [-1, 1]).
            weight: Optional spatial weights. For ``reduction="mean"`` the
                elements with ``weight <= 0`` are left out of the average.

        Returns:
            The reduced loss (the element-wise tensor for ``"none"``).
        """
        squashed = torch.tanh(pred) if self.tanh else pred
        per_element = (squashed - target) ** 2
        if weight is None:
            return _reduce_weighted_tensor(per_element, None, self.reduction)
        return _reduce_weighted_tensor(per_element * weight, weight, self.reduction)
class WeightedBCEWithLogitsLoss(nn.Module):
    """
    Wrapper for BCE-with-logits with optional class and spatial weighting.

    Supports a static ``pos_weight`` (configured once) and an optional per-call
    ``pos_weight`` override (used by auto class-ratio mode in orchestration).

    A 1-D ``pos_weight`` whose length equals the channel count ``C`` of a
    ``[B, C, *spatial]`` input is applied per channel by reshaping it to
    ``(C, 1, ..., 1)``. Without this, PyTorch broadcasting would align it with
    the last spatial axis.

    Args:
        pos_weight: Optional positive-class weight: a positive scalar, a
            tensor, or a sequence of floats (converted to a float32 tensor).
        reduction: Reduction method ('mean', 'sum', 'none').

    Raises:
        TypeError: If unexpected keyword arguments are passed.
        ValueError: If a scalar ``pos_weight`` is not strictly positive.
    """

    def __init__(
        self,
        pos_weight: Union[float, torch.Tensor, None] = None,
        reduction: str = "mean",
        **kwargs,
    ):
        super().__init__()
        if kwargs:
            unexpected = ", ".join(sorted(kwargs.keys()))
            raise TypeError(f"Unexpected argument(s) for WeightedBCEWithLogitsLoss: {unexpected}")
        self.reduction = reduction
        if pos_weight is not None and isinstance(pos_weight, (int, float)):
            if float(pos_weight) <= 0:
                raise ValueError(f"pos_weight must be > 0, got {float(pos_weight)}")
            self.register_buffer("pos_weight", torch.tensor([float(pos_weight)]))
        elif pos_weight is not None:
            if not isinstance(pos_weight, torch.Tensor):
                # e.g. a list of per-channel weights read from a config file;
                # register_buffer only accepts tensors.
                pos_weight = torch.as_tensor(pos_weight, dtype=torch.float32)
            self.register_buffer("pos_weight", pos_weight)
        else:
            self.pos_weight = None

    @staticmethod
    def _align_pos_weight(pos_weight, input: torch.Tensor) -> torch.Tensor:
        """Convert ``pos_weight`` to a tensor on ``input``'s device/dtype that broadcasts per channel."""
        pos_weight = torch.as_tensor(pos_weight, device=input.device, dtype=input.dtype)
        # Broadcasting aligns trailing dimensions, so a 1-D per-channel vector
        # would otherwise line up with the last spatial axis, not dim 1.
        if (
            pos_weight.dim() == 1
            and pos_weight.numel() > 1
            and input.dim() > 2
            and pos_weight.numel() == input.shape[1]
        ):
            pos_weight = pos_weight.reshape(-1, *([1] * (input.dim() - 2)))
        return pos_weight

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weight: torch.Tensor | None = None,
        pos_weight: Union[float, torch.Tensor, None] = None,
    ) -> torch.Tensor:
        """
        Compute weighted BCE with logits loss.

        Args:
            input: Model output (logits) [B, C, ...]
            target: Ground truth [B, C, ...]
            weight: Optional spatial weights/mask. For ``reduction="mean"``,
                elements with ``weight <= 0`` are left out of the average.
            pos_weight: Optional per-call positive-class weight override
                (scalar, tensor, or sequence). Takes precedence over the
                weight configured at construction.

        Returns:
            Loss value

        Raises:
            ValueError: If a scalar per-call ``pos_weight`` is not positive.
        """
        if pos_weight is not None and isinstance(pos_weight, (int, float)):
            if float(pos_weight) <= 0:
                raise ValueError(f"pos_weight must be > 0, got {float(pos_weight)}")
            effective_pos_weight = torch.tensor(
                [float(pos_weight)],
                device=input.device,
                dtype=input.dtype,
            )
        elif pos_weight is not None:
            effective_pos_weight = self._align_pos_weight(pos_weight, input)
        elif self.pos_weight is not None:
            effective_pos_weight = self._align_pos_weight(self.pos_weight, input)
        else:
            effective_pos_weight = None
        bce = F.binary_cross_entropy_with_logits(
            input,
            target,
            pos_weight=effective_pos_weight,
            reduction="none",
        )
        if weight is not None:
            bce = bce * weight
        return _reduce_weighted_tensor(bce, weight, self.reduction)
class PerChannelBCEWithLogitsLoss(nn.Module):
    """BCE loss computed independently per channel with per-channel class balancing.

    Equivalent to summing C separate WeightedBCEWithLogitsLoss instances (one
    per output channel) but expressed as a single loss entry for config brevity.
    Matches the DeepEM per-edge loss structure where each affinity channel has
    its own class-balanced BCE.

    Args:
        auto_pos_weight: If True, compute per-channel pos_weight as
            min(n_neg / n_pos, max_pos_weight) from the current batch
            (1.0 for channels with no valid positive voxel).
        max_pos_weight: Cap for auto pos_weight to avoid extreme values.
        reduction: Per-channel reduction before summing ('mean' or 'sum').

    Raises:
        ValueError: If ``reduction`` is not 'mean' or 'sum'.
    """

    def __init__(
        self,
        auto_pos_weight: bool = True,
        max_pos_weight: float = 10.0,
        reduction: str = "mean",
    ):
        super().__init__()
        if reduction not in ("mean", "sum"):
            # 'none' would silently sum per-channel element tensors into a
            # meaningless [B, 1, ...] result, so reject it up front.
            raise ValueError(
                f"PerChannelBCEWithLogitsLoss reduction must be 'mean' or 'sum', got {reduction!r}"
            )
        self.auto_pos_weight = auto_pos_weight
        self.max_pos_weight = max_pos_weight
        self.reduction = reduction

    def _batch_pos_weight(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weight: torch.Tensor | None,
    ) -> torch.Tensor:
        """Return per-channel pos_weight shaped (1, C, 1, ..., 1) from valid-voxel class counts."""
        num_channels = input.shape[1]
        if weight is not None:
            valid = weight > 0
        else:
            valid = torch.ones_like(target, dtype=torch.bool)
        pos = (target > 0) & valid
        neg = (target <= 0) & valid
        # Reduce over batch + spatial dims -> (C,)
        reduce_dims = (0,) + tuple(range(2, target.ndim))
        pos_counts = pos.sum(dim=reduce_dims).float()
        neg_counts = neg.sum(dim=reduce_dims).float()
        pw = torch.ones(num_channels, device=input.device, dtype=torch.float32)
        has_pos = pos_counts > 0
        # NOTE: a channel with positives but no negatives gets pos_weight 0,
        # as implied by the documented n_neg / n_pos formula.
        pw[has_pos] = torch.clamp(
            neg_counts[has_pos] / pos_counts[has_pos],
            max=self.max_pos_weight,
        )
        # (1, C, 1, ..., 1) for broadcasting; cast to input dtype for AMP
        return pw.reshape(1, num_channels, *([1] * (target.ndim - 2))).to(input.dtype)

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weight: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute per-channel BCE loss.

        Args:
            input: Logits [B, C, ...].
            target: Ground truth [B, C, ...].
            weight: Optional spatial mask, [B, C, ...] (e.g. affinity valid
                mask) or any shape broadcastable to it, such as a [B, 1, ...]
                mask shared by all channels.

        Returns:
            Sum of per-channel reduced BCE losses.
        """
        num_channels = input.shape[1]
        pos_weight_tensor = (
            self._batch_pos_weight(input, target, weight) if self.auto_pos_weight else None
        )
        bce = F.binary_cross_entropy_with_logits(
            input,
            target,
            pos_weight=pos_weight_tensor,
            reduction="none",
        )
        channel_weight = None
        if weight is not None:
            bce = bce * weight
            # Expand a channel-shared mask to the full loss shape; otherwise
            # weight[:, c:c+1] is empty for every c >= 1.
            channel_weight = torch.broadcast_to(weight, bce.shape)
        # Per-channel reduction, then sum across channels.
        total = input.new_tensor(0.0)
        for c in range(num_channels):
            ch_bce = bce[:, c : c + 1, ...]
            ch_w = channel_weight[:, c : c + 1, ...] if channel_weight is not None else None
            total = total + _reduce_weighted_tensor(ch_bce, ch_w, self.reduction)
        return total
class WeightedMAELoss(nn.Module):
    """
    Absolute-error regression loss with optional spatial weighting.

    Every element contributes ``|pred - target|``, optionally scaled by a
    per-element ``weight``. With ``tanh=True`` the raw predictions are first
    squashed by ``torch.tanh`` into [-1, 1], which suits distance-transform
    targets in that range.

    Args:
        reduction: Reduction method ('mean', 'sum', 'none').
        tanh: If True, apply tanh to predictions before computing the error.
    """

    def __init__(self, reduction: str = "mean", tanh: bool = False):
        super().__init__()
        self.tanh = tanh
        self.reduction = reduction

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Compute the (optionally weighted) MAE loss.

        Args:
            pred: Predictions; treated as logits when ``tanh=True``.
            target: Ground truth.
            weight: Optional spatial weights. For ``reduction="mean"`` the
                elements with ``weight <= 0`` are left out of the average.

        Returns:
            The reduced loss (the element-wise tensor for ``"none"``).
        """
        squashed = torch.tanh(pred) if self.tanh else pred
        per_element = (squashed - target).abs()
        if weight is None:
            return _reduce_weighted_tensor(per_element, None, self.reduction)
        return _reduce_weighted_tensor(per_element * weight, weight, self.reduction)
class SmoothL1Loss(nn.Module):
    """
    Smooth L1 (Huber-style) loss with optional tanh activation and spatial weighting.

    Intended for distance-transform regression, where large outliers should
    weigh less than they would under MSE.

    Args:
        beta: Transition threshold between the quadratic and linear regimes,
            forwarded to ``F.smooth_l1_loss``.
        reduction: Reduction method ('mean', 'sum', 'none').
        tanh: If True, apply tanh to predictions before computing the loss.
    """

    def __init__(self, beta: float = 1.0, reduction: str = "mean", tanh: bool = False):
        super().__init__()
        self.tanh = tanh
        self.reduction = reduction
        self.beta = beta

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        weight: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute the element-wise smooth L1 loss, optionally weight it, then reduce it."""
        squashed = torch.tanh(pred) if self.tanh else pred
        per_element = F.smooth_l1_loss(squashed, target, reduction="none", beta=self.beta)
        if weight is None:
            return _reduce_weighted_tensor(per_element, None, self.reduction)
        return _reduce_weighted_tensor(per_element * weight, weight, self.reduction)
class GANLoss(nn.Module):
    """
    GAN loss for adversarial training.

    Supports vanilla, LSGAN, and WGAN-GP objectives.
    Based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

    Args:
        gan_mode: GAN objective type ('vanilla', 'lsgan', 'wgangp')
        target_real_label: Label for real images (ints are accepted and stored as float)
        target_fake_label: Label for fake images (ints are accepted and stored as float)

    Raises:
        NotImplementedError: If ``gan_mode`` is not one of the supported modes.

    Note:
        Do not use sigmoid as the last layer of Discriminator.
        LSGAN needs no sigmoid. Vanilla GANs will handle it with BCEWithLogitsLoss.
    """

    def __init__(
        self,
        gan_mode: str = "lsgan",
        target_real_label: float = 1.0,
        target_fake_label: float = 0.0,
    ):
        super().__init__()
        # float() guards against integer labels (e.g. ``1`` from a config),
        # which would create int64 buffers that BCE/MSE reject as targets.
        self.register_buffer("real_label", torch.tensor(float(target_real_label)))
        self.register_buffer("fake_label", torch.tensor(float(target_fake_label)))
        self.gan_mode = gan_mode
        if gan_mode == "lsgan":
            self.loss = nn.MSELoss()
        elif gan_mode == "vanilla":
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == "wgangp":
            # WGAN-GP uses the raw critic score; no criterion module needed.
            self.loss = None
        else:
            raise NotImplementedError(f"GAN mode {gan_mode} not implemented")

    def get_target_tensor(
        self,
        prediction: torch.Tensor,
        target_is_real: bool,
    ) -> torch.Tensor:
        """
        Create a label tensor with the same size as the input.

        Args:
            prediction: Discriminator prediction
            target_is_real: Whether the ground truth is real or fake

        Returns:
            Label tensor (an expanded view of the stored scalar label)
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(prediction)

    def forward(
        self,
        prediction: torch.Tensor,
        target_is_real: bool,
    ) -> torch.Tensor:
        """
        Calculate GAN loss.

        Args:
            prediction: Discriminator output
            target_is_real: Whether ground truth labels are for real or fake images

        Returns:
            Calculated loss

        Raises:
            NotImplementedError: If ``self.gan_mode`` was changed to an
                unsupported value after construction.
        """
        if self.gan_mode in ("lsgan", "vanilla"):
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            return self.loss(prediction, target_tensor)
        if self.gan_mode == "wgangp":
            return -prediction.mean() if target_is_real else prediction.mean()
        raise NotImplementedError(f"GAN mode {self.gan_mode} not implemented")
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    # Classification-style losses
    "CrossEntropyLossWrapper",
    "WeightedBCEWithLogitsLoss",
    "PerChannelBCEWithLogitsLoss",
    # Regression losses
    "WeightedMSELoss",
    "WeightedMAELoss",
    "SmoothL1Loss",
    # Adversarial loss
    "GANLoss",
]