"""
Regularization losses for connectomics.
These losses encourage specific properties in the predictions, such as:
- Binary outputs
- Consistency between related prediction tasks
- Non-overlapping predictions
All losses are implemented as nn.Module for consistency with MONAI.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class BinaryRegularization(nn.Module):
    """Encourage predictions to be confidently binary (close to 0 or 1).

    The per-element penalty is the reciprocal of the distance from 0.5, so
    maximally uncertain predictions (near 0.5) incur the largest loss.

    Args:
        min_threshold: Lower clamp on the distance from 0.5; caps the
            per-element penalty at ``1 / min_threshold`` (default: 1e-2).
        apply_sigmoid: If ``True``, treat ``pred`` as logits and apply a
            sigmoid before computing the penalty (default: ``True``).

    Example:
        >>> reg = BinaryRegularization()
        >>> pred = torch.sigmoid(torch.randn(1, 1, 64, 64, 64))
        >>> loss = reg(pred)
    """

    def __init__(self, min_threshold: float = 1e-2, apply_sigmoid: bool = True):
        super().__init__()
        self.min_threshold = min_threshold
        self.apply_sigmoid = apply_sigmoid

    def forward(
        self,
        pred: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the binary regularization loss.

        Args:
            pred: Predictions (logits if ``apply_sigmoid``, else probabilities).
            mask: Optional spatial weight mask, broadcastable to ``pred``.

        Returns:
            Scalar regularization loss (mean over all elements).
        """
        probs = torch.sigmoid(pred) if self.apply_sigmoid else pred
        # Distance from the maximally uncertain value 0.5, floored so the
        # reciprocal penalty below stays bounded.
        distance = torch.clamp((probs - 0.5).abs(), min=self.min_threshold)
        penalty = 1.0 / distance
        if mask is not None:
            penalty = penalty * mask
        return penalty.mean()
class ForegroundDistanceConsistency(nn.Module):
    """Consistency regularization between a binary foreground mask and a
    signed distance transform.

    High foreground probability should coincide with positive (inside)
    distances, and low foreground probability with negative (outside)
    distances; the loss cross-weights the log-probabilities by the
    tanh-normalized distance magnitudes.

    Example:
        >>> reg = ForegroundDistanceConsistency()
        >>> fg_logits = torch.randn(1, 1, 64, 64, 64)
        >>> dt_pred = torch.randn(1, 1, 64, 64, 64)
        >>> loss = reg(fg_logits, dt_pred)
    """

    def forward(
        self,
        foreground_logits: torch.Tensor,
        distance_transform: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute consistency loss between foreground and distance transform.

        Args:
            foreground_logits: Binary foreground logits.
            distance_transform: Signed distance transform predictions.
            mask: Optional spatial weight mask, broadcastable to the loss.

        Returns:
            Scalar consistency loss (mean over all elements).
        """
        # Log-probabilities (logsigmoid for numerical stability).
        log_p_fg = F.logsigmoid(foreground_logits)
        log_p_bg = F.logsigmoid(-foreground_logits)
        # Normalize distances into (-1, 1) and split by sign:
        # relu(d) keeps inside distances, relu(-d) keeps outside magnitudes.
        normalized = torch.tanh(distance_transform)
        inside = F.relu(normalized)
        outside = F.relu(-normalized)
        # Negative log-likelihood weighted by the matching distance component.
        loss = -(log_p_fg * inside + log_p_bg * outside)
        if mask is not None:
            loss = loss * mask
        return loss.mean()
class ContourDistanceConsistency(nn.Module):
    """Consistency regularization between an instance contour map and a
    signed distance transform.

    Contours should fire where the distance transform has low magnitude
    (object boundaries), so the loss penalizes the squared product of
    contour probability and absolute normalized distance.

    Example:
        >>> reg = ContourDistanceConsistency()
        >>> contour_logits = torch.randn(1, 1, 64, 64, 64)
        >>> dt_pred = torch.randn(1, 1, 64, 64, 64)
        >>> loss = reg(contour_logits, dt_pred)
    """

    def forward(
        self,
        contour_logits: torch.Tensor,
        distance_transform: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute consistency loss between contour and distance transform.

        Args:
            contour_logits: Instance contour logits.
            distance_transform: Signed distance transform predictions.
            mask: Optional spatial weight mask, broadcastable to the loss.

        Returns:
            Scalar consistency loss (mean over all elements).

        Raises:
            ValueError: If the two inputs have mismatched shapes.
        """
        prob = torch.sigmoid(contour_logits)
        dist_mag = torch.tanh(distance_transform).abs()
        # Guard: both maps must align voxel-for-voxel.
        if prob.shape != dist_mag.shape:
            raise ValueError(
                f"Shape mismatch: contour_prob={prob.shape} "
                f"vs distance_abs={dist_mag.shape}"
            )
        # Squared product: large only where a confident contour sits on a
        # large-magnitude distance (i.e., away from a boundary).
        loss = torch.square(prob * dist_mag)
        if mask is not None:
            loss = loss * mask
        return loss.mean()
class ForegroundContourConsistency(nn.Module):
    """Consistency regularization between binary foreground and instance
    contour maps.

    Edges of the sigmoided foreground prediction are detected with Sobel-style
    derivative filters along the last two spatial axes, dilated in-plane by
    max pooling, and compared against the predicted contour probabilities
    with an elementwise MSE loss.

    Args:
        kernel_half_size: Half-size of the edge-dilation (max pooling)
            window; the effective window is ``2 * kernel_half_size + 1``
            (default: 1).
        eps: Small epsilon for numerical stability (default: 1e-7).

    Example:
        >>> reg = ForegroundContourConsistency()
        >>> fg_logits = torch.randn(1, 1, 64, 64, 64)
        >>> contour_logits = torch.randn(1, 1, 64, 64, 64)
        >>> loss = reg(fg_logits, contour_logits)
    """

    def __init__(self, kernel_half_size: int = 1, eps: float = 1e-7):
        super().__init__()
        if kernel_half_size < 0:
            raise ValueError(
                f"kernel_half_size must be non-negative, got {kernel_half_size}"
            )
        # Kept so forward() can pad symmetrically to preserve spatial shape.
        self.kernel_half_size = kernel_half_size
        self.kernel_size = 2 * kernel_half_size + 1
        self.eps = eps
        # Central-difference (Sobel 1-D) filters for in-plane edge detection.
        sobel = torch.tensor([1.0, 0.0, -1.0])
        self.register_buffer("sobel_x", sobel.view(1, 1, 1, 1, 3))
        self.register_buffer("sobel_y", sobel.view(1, 1, 1, 3, 1))

    def forward(
        self,
        foreground_logits: torch.Tensor,
        contour_logits: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute consistency loss between foreground edges and contours.

        Args:
            foreground_logits: Binary foreground logits; the 1-input-channel
                filters assume a single-channel 5D tensor (B, 1, Z, Y, X) —
                TODO confirm against callers.
            contour_logits: Instance contour logits, same shape.
            mask: Optional spatial weight mask, broadcastable to the loss.

        Returns:
            Scalar consistency loss (mean over all elements).

        Raises:
            ValueError: If the dilated edge map and contour shapes mismatch.
        """
        fg_prob = torch.sigmoid(foreground_logits)
        contour_prob = torch.sigmoid(contour_logits)
        # In-plane edge responses; conv padding keeps the spatial shape.
        edge_x = F.conv3d(fg_prob, self.sobel_x, padding=(0, 0, 1))
        edge_y = F.conv3d(fg_prob, self.sobel_y, padding=(0, 1, 0))
        # Edge magnitude, clamped away from 0 and 1 for stability.
        edge = torch.sqrt(edge_x**2 + edge_y**2 + self.eps)
        edge = torch.clamp(edge, min=self.eps, max=1.0 - self.eps)
        # Dilate edge regions in-plane. BUGFIX: pad by kernel_half_size
        # (previously a hard-coded 1) so the pooled output retains its shape
        # for any kernel_half_size, not only the default of 1. F.pad's tuple
        # pads the last dim first: (W_left, W_right, H_left, H_right, D_left,
        # D_right).
        pad = self.kernel_half_size
        edge = F.pad(edge, (pad, pad, pad, pad, 0, 0))
        edge = F.max_pool3d(
            edge, kernel_size=(1, self.kernel_size, self.kernel_size), stride=1
        )
        if edge.shape != contour_prob.shape:
            raise ValueError(
                f"Shape mismatch: edge={edge.shape} vs contour_prob={contour_prob.shape}"
            )
        # MSE between detected edges and predicted contours.
        loss = F.mse_loss(edge, contour_prob, reduction="none")
        if mask is not None:
            loss = loss * mask
        return loss.mean()
class NonOverlapRegularization(nn.Module):
    """Regularization preventing overlapping predictions.

    Designed for synaptic polarity prediction: pre- and post-synaptic masks
    should not overlap. The overlap penalty can optionally be restricted to
    regions where the cleft channel predicts a synapse.

    Args:
        cleft_masked: Whether to mask regularization by the cleft prediction
            (default: ``True``).

    Example:
        >>> reg = NonOverlapRegularization()
        >>> # pred has shape (B, 3, Z, Y, X) with channels: [pre, post, cleft]
        >>> pred = torch.randn(2, 3, 32, 64, 64)
        >>> loss = reg(pred)
    """

    def __init__(self, cleft_masked: bool = True):
        super().__init__()
        self.cleft_masked = cleft_masked

    def forward(self, pred: torch.Tensor) -> torch.Tensor:
        """Compute the non-overlap regularization loss.

        Args:
            pred: Predictions with shape (B, C, Z, Y, X) where channel 0 is
                pre-synaptic logits, channel 1 is post-synaptic logits, and
                channel 2 (optional) is cleft logits used when
                ``cleft_masked`` is ``True``.

        Returns:
            Scalar non-overlap regularization loss.

        Raises:
            ValueError: If fewer than 2 channels are provided.
        """
        num_channels = pred.shape[1]
        if num_channels < 2:
            raise ValueError(
                f"Expected at least 2 channels for pre/post predictions, got {pred.shape[1]}"
            )
        # Product of pre/post probabilities is large only where both fire.
        overlap = torch.sigmoid(pred[:, 0]) * torch.sigmoid(pred[:, 1])
        if self.cleft_masked and num_channels >= 3:
            # Detach the cleft channel so this loss cannot be trivially
            # minimized by driving the cleft probability down.
            overlap = overlap * torch.sigmoid(pred[:, 2].detach())
        return overlap.mean()
# Public API of this module: all regularization losses exported for
# `from ... import *` and documentation tooling.
__all__ = [
    "BinaryRegularization",
    "ForegroundDistanceConsistency",
    "ContourDistanceConsistency",
    "ForegroundContourConsistency",
    "NonOverlapRegularization",
]