"""
Connectomics-specific loss functions.
Custom losses that provide functionality not available in MONAI.
Only includes losses that are truly unique to connectomics use cases.
"""
from __future__ import annotations
from typing import List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
def _reduce_weighted_tensor(
loss_tensor: torch.Tensor,
weight: torch.Tensor | None,
reduction: str,
) -> torch.Tensor:
"""Reduce weighted loss values while excluding invalid (weight<=0) voxels for mean."""
if reduction == "none":
return loss_tensor
if reduction == "sum":
return loss_tensor.sum()
# reduction == "mean"
if weight is None:
return loss_tensor.mean()
valid = weight > 0
if valid.shape != loss_tensor.shape:
try:
valid = torch.broadcast_to(valid, loss_tensor.shape)
except RuntimeError as e:
raise ValueError(
"Weight mask shape is not broadcastable to loss tensor shape: "
f"weight={tuple(weight.shape)}, loss={tuple(loss_tensor.shape)}"
) from e
if not torch.any(valid):
return loss_tensor.new_tensor(0.0)
return loss_tensor[valid].mean()
def _soft_erode_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable morphological erosion using min-pool via max-pool."""
if prob.ndim == 5:
# Use an axis-aligned cross (3x1x1, 1x3x1, 1x1x3), matching clDice-style soft morphology.
p1 = -F.max_pool3d(-prob, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
p2 = -F.max_pool3d(-prob, kernel_size=(1, 3, 1), stride=1, padding=(0, 1, 0))
p3 = -F.max_pool3d(-prob, kernel_size=(1, 1, 3), stride=1, padding=(0, 0, 1))
return torch.minimum(torch.minimum(p1, p2), p3)
if prob.ndim == 4:
p1 = -F.max_pool2d(-prob, kernel_size=(3, 1), stride=1, padding=(1, 0))
p2 = -F.max_pool2d(-prob, kernel_size=(1, 3), stride=1, padding=(0, 1))
return torch.minimum(p1, p2)
raise ValueError(f"Expected 4D/5D tensor for soft erosion, got shape {tuple(prob.shape)}")
def _soft_dilate_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable morphological dilation."""
if prob.ndim == 5:
return F.max_pool3d(prob, kernel_size=3, stride=1, padding=1)
if prob.ndim == 4:
return F.max_pool2d(prob, kernel_size=3, stride=1, padding=1)
raise ValueError(f"Expected 4D/5D tensor for soft dilation, got shape {tuple(prob.shape)}")
def _soft_open_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable opening (erode followed by dilate)."""
return _soft_dilate_pool(_soft_erode_pool(prob))
def _soft_skeletonize_pool(prob: torch.Tensor, num_iters: int) -> torch.Tensor:
"""Iterative soft skeletonization from clDice-style morphology."""
opened = _soft_open_pool(prob)
skeleton = F.relu(prob - opened)
for _ in range(num_iters):
prob = _soft_erode_pool(prob)
opened = _soft_open_pool(prob)
delta = F.relu(prob - opened)
skeleton = skeleton + F.relu(delta - skeleton * delta)
return skeleton
class CrossEntropyLossWrapper(nn.Module):
"""
Wrapper for CrossEntropyLoss that handles shape conversion.
Expects labels in format [B, 1, D, H, W] and converts to [B, D, H, W]
for compatibility with PyTorch's CrossEntropyLoss.
Args:
weight: Class weights
ignore_index: Index to ignore
reduction: Reduction method
label_smoothing: Label smoothing factor (0.0 = no smoothing, 0.1 = 10% smoothing)
"""
def __init__(self, weight=None, ignore_index=-100, reduction="mean", label_smoothing=0.0):
super().__init__()
self.ce_loss = nn.CrossEntropyLoss(
weight=weight,
ignore_index=ignore_index,
reduction=reduction,
label_smoothing=label_smoothing,
)
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute cross-entropy loss with automatic shape handling.
Args:
input: Model output
2D: [B, C, H, W]
3D: [B, C, D, H, W]
target: Ground truth
2D: [B, 1, H, W] or [B, H, W]
3D: [B, 1, D, H, W] or [B, D, H, W]
Returns:
Loss value
"""
# Squeeze channel dimension if present
# 2D: [B, 1, H, W] -> [B, H, W]
# 3D: [B, 1, D, H, W] -> [B, D, H, W]
if target.dim() == 4 and target.shape[1] == 1:
target = target.squeeze(1) # 2D case
elif target.dim() == 5 and target.shape[1] == 1:
target = target.squeeze(1) # 3D case
# Convert to long type for cross-entropy
target = target.long()
return self.ce_loss(input, target)
[docs]class WeightedMSELoss(nn.Module):
"""
Weighted mean-squared error loss.
Useful for regression tasks with spatial importance weighting.
Supports optional tanh activation for distance transform predictions.
Args:
reduction: Reduction method ('mean', 'sum', 'none')
tanh: If True, apply tanh activation to predictions before computing loss.
Useful for distance transform targets in range [-1, 1].
With both pred and target in [-1, 1], MSE should be < 4.
"""
def __init__(self, reduction: str = "mean", tanh: bool = False):
super().__init__()
self.reduction = reduction
self.tanh = tanh
[docs] def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor = None,
) -> torch.Tensor:
"""
Compute weighted MSE loss.
Args:
pred: Predictions (logits if tanh=True, otherwise predictions)
target: Ground truth (range [-1, 1] for SDT)
weight: Optional spatial weights
Returns:
Loss value (should be < 4 for range [-1, 1])
"""
# Apply tanh activation if enabled (constrains pred to [-1, 1])
if self.tanh:
pred = torch.tanh(pred)
# Compute MSE (for range [-1,1], max error is (1-(-1))^2 = 4)
mse = (pred - target) ** 2
if weight is not None:
mse = mse * weight
loss_value = _reduce_weighted_tensor(mse, weight, self.reduction)
return loss_value
class WeightedBCEWithLogitsLoss(nn.Module):
"""
Wrapper for BCEWithLogitsLoss with optional class and spatial weighting.
Supports static ``pos_weight`` (configured once) and optional per-call
``pos_weight`` override (used by auto class-ratio mode in orchestration).
Args:
pos_weight: Optional positive-class weight (scalar or tensor)
reduction: Reduction method ('mean', 'sum', 'none')
"""
def __init__(
self,
pos_weight: Union[float, torch.Tensor, None] = None,
reduction: str = "mean",
**kwargs,
):
super().__init__()
if kwargs:
unexpected = ", ".join(sorted(kwargs.keys()))
raise TypeError(f"Unexpected argument(s) for WeightedBCEWithLogitsLoss: {unexpected}")
self.reduction = reduction
if pos_weight is not None and isinstance(pos_weight, (int, float)):
if float(pos_weight) <= 0:
raise ValueError(f"pos_weight must be > 0, got {float(pos_weight)}")
self.register_buffer("pos_weight", torch.tensor([float(pos_weight)]))
elif pos_weight is not None:
self.register_buffer("pos_weight", pos_weight)
else:
self.pos_weight = None
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor = None,
pos_weight: Union[float, torch.Tensor, None] = None,
) -> torch.Tensor:
"""
Compute weighted BCE with logits loss.
Args:
input: Model output (logits) [B, C, ...]
target: Ground truth [B, C, ...]
weight: Optional spatial weights/mask.
pos_weight: Optional per-call positive-class weight override.
Returns:
Loss value
"""
if pos_weight is not None and isinstance(pos_weight, (int, float)):
if float(pos_weight) <= 0:
raise ValueError(f"pos_weight must be > 0, got {float(pos_weight)}")
effective_pos_weight = torch.tensor(
[float(pos_weight)],
device=input.device,
dtype=input.dtype,
)
elif pos_weight is not None:
effective_pos_weight = pos_weight.to(device=input.device, dtype=input.dtype)
elif self.pos_weight is not None:
effective_pos_weight = self.pos_weight.to(device=input.device, dtype=input.dtype)
else:
effective_pos_weight = None
bce = F.binary_cross_entropy_with_logits(
input,
target,
pos_weight=effective_pos_weight,
reduction="none",
)
if weight is not None:
bce = bce * weight
return _reduce_weighted_tensor(bce, weight, self.reduction)
class PerChannelBCEWithLogitsLoss(nn.Module):
"""BCE loss computed independently per channel with per-channel class balancing.
Equivalent to summing C separate WeightedBCEWithLogitsLoss instances (one
per output channel) but expressed as a single loss entry for config brevity.
Matches the DeepEM per-edge loss structure where each affinity channel has
its own class-balanced BCE.
Args:
auto_pos_weight: If True, compute per-channel pos_weight as
min(n_neg / n_pos, max_pos_weight) from the current batch.
max_pos_weight: Cap for auto pos_weight to avoid extreme values.
reduction: Per-channel reduction before summing ('mean' or 'sum').
"""
def __init__(
self,
auto_pos_weight: bool = True,
max_pos_weight: float = 10.0,
reduction: str = "mean",
):
super().__init__()
self.auto_pos_weight = auto_pos_weight
self.max_pos_weight = max_pos_weight
self.reduction = reduction
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor = None,
) -> torch.Tensor:
"""Compute per-channel BCE loss.
Args:
input: Logits [B, C, ...].
target: Ground truth [B, C, ...].
weight: Optional spatial mask [B, C, ...] (e.g. affinity valid mask).
Returns:
Sum of per-channel reduced BCE losses.
"""
C = input.shape[1]
# Vectorised per-channel pos_weight computation.
pos_weight_tensor = None
if self.auto_pos_weight:
valid = weight > 0 if weight is not None else torch.ones_like(target, dtype=torch.bool)
pos = (target > 0) & valid
neg = (target <= 0) & valid
# Reduce over batch + spatial dims → (C,)
reduce_dims = (0,) + tuple(range(2, target.ndim))
pos_counts = pos.sum(dim=reduce_dims).float()
neg_counts = neg.sum(dim=reduce_dims).float()
pw = torch.ones(C, device=input.device, dtype=torch.float32)
nonzero = pos_counts > 0
pw[nonzero] = torch.clamp(
neg_counts[nonzero] / pos_counts[nonzero],
max=self.max_pos_weight,
)
# (1, C, 1, ..., 1) for broadcasting; cast to input dtype for AMP
pos_weight_tensor = pw.reshape(1, C, *([1] * (target.ndim - 2))).to(input.dtype)
bce = F.binary_cross_entropy_with_logits(
input,
target,
pos_weight=pos_weight_tensor,
reduction="none",
)
if weight is not None:
bce = bce * weight
# Per-channel reduction, then sum across channels.
total = input.new_tensor(0.0)
for c in range(C):
ch_bce = bce[:, c : c + 1, ...]
ch_w = weight[:, c : c + 1, ...] if weight is not None else None
total = total + _reduce_weighted_tensor(ch_bce, ch_w, self.reduction)
return total
[docs]class SoftClDiceLoss(nn.Module):
"""
Soft clDice loss using differentiable skeletonization.
Supports optional activation on logits (`sigmoid=True` or `softmax=True`),
MONAI-style, before topology computation.
Targets must be dense maps (one-hot or soft labels); single-channel class-index
targets are accepted and converted to one-hot for multi-class predictions.
Args:
num_iters: Number of soft-skeleton erosion iterations.
mode: ``"binary"`` (single foreground channel) or ``"multi"`` (all classes except
``background_index``).
reduction: ``"none"``, ``"mean"``, or ``"sum"``.
smooth: Smoothing constant in topology precision/sensitivity fractions.
foreground_channel: Foreground channel index used in ``mode="binary"`` when C>1.
background_index: Background channel index excluded in ``mode="multi"``.
sigmoid: Apply sigmoid activation to predictions inside ``forward``.
softmax: Apply softmax activation to predictions inside ``forward``.
clamp_probabilities: Clamp predictions/targets to ``[0, 1]`` before skeletonization.
validate_inputs: When True, enforce probability-range and minimal spatial-size checks.
validation_tolerance: Tolerance used by probability-range checks.
"""
def __init__(
self,
num_iters: int = 5,
mode: str = "binary",
reduction: str = "mean",
smooth: float = 1.0,
foreground_channel: int = 1,
background_index: int = 0,
sigmoid: bool = False,
softmax: bool = False,
clamp_probabilities: bool = False,
validate_inputs: bool = True,
validation_tolerance: float = 1e-5,
):
super().__init__()
if num_iters < 0:
raise ValueError(f"num_iters must be >= 0, got {num_iters}")
if mode not in {"binary", "multi"}:
raise ValueError(f"mode must be 'binary' or 'multi', got {mode!r}")
if reduction not in {"none", "mean", "sum"}:
raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction!r}")
if smooth <= 0:
raise ValueError(f"smooth must be > 0, got {smooth}")
if sigmoid and softmax:
raise ValueError("sigmoid and softmax are mutually exclusive")
if validation_tolerance < 0:
raise ValueError(f"validation_tolerance must be >= 0, got {validation_tolerance}")
self.num_iters = int(num_iters)
self.mode = mode
self.reduction = reduction
self.smooth = float(smooth)
self.foreground_channel = int(foreground_channel)
self.background_index = int(background_index)
self.sigmoid = bool(sigmoid)
self.softmax = bool(softmax)
self.clamp_probabilities = bool(clamp_probabilities)
self.validate_inputs = bool(validate_inputs)
self.validation_tolerance = float(validation_tolerance)
def _prepare_target(self, target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
if target.ndim == pred.ndim - 1:
target = target.unsqueeze(1)
if target.ndim != pred.ndim:
raise ValueError(
f"Target ndim ({target.ndim}) does not match prediction ndim ({pred.ndim})"
)
if target.shape[0] != pred.shape[0] or target.shape[2:] != pred.shape[2:]:
raise ValueError(
"Target shape must match prediction shape except for channel dimension: "
f"target={tuple(target.shape)}, pred={tuple(pred.shape)}"
)
if target.shape[1] == pred.shape[1]:
return target.to(device=pred.device, dtype=pred.dtype)
if target.shape[1] == 1 and pred.shape[1] > 1:
class_index = target.squeeze(1).long()
min_label = int(class_index.min().item())
max_label = int(class_index.max().item())
if min_label < 0 or max_label >= pred.shape[1]:
raise ValueError(
f"Class-index targets must be in [0, {pred.shape[1] - 1}], "
f"got min={min_label}, max={max_label}"
)
one_hot = F.one_hot(class_index, num_classes=pred.shape[1]).movedim(-1, 1)
return one_hot.to(device=pred.device, dtype=pred.dtype)
raise ValueError(
"Target channel count is incompatible with prediction: "
f"target_channels={target.shape[1]}, pred_channels={pred.shape[1]}"
)
def _apply_activation(self, pred: torch.Tensor) -> torch.Tensor:
if self.sigmoid:
return torch.sigmoid(pred)
if self.softmax:
if pred.shape[1] < 2:
raise ValueError("softmax=True requires prediction with at least 2 channels")
return F.softmax(pred, dim=1)
return pred
def _validate_probability_range(self, tensor: torch.Tensor, name: str) -> None:
if not self.validate_inputs:
return
tol = self.validation_tolerance
min_val = float(tensor.min().item())
max_val = float(tensor.max().item())
if min_val < -tol or max_val > (1.0 + tol):
raise ValueError(
f"{name} must be probabilities in [0, 1] (tolerance={tol}), "
f"got min={min_val:.6f}, max={max_val:.6f}. "
"Pass sigmoid=True/softmax=True for logits."
)
def _validate_spatial_shape(self, pred: torch.Tensor) -> None:
if not self.validate_inputs:
return
spatial_shape = tuple(pred.shape[2:])
if any(dim < 3 for dim in spatial_shape):
raise ValueError(
"SoftClDiceLoss expects each spatial dimension >= 3 for stable morphology, "
f"got spatial shape {spatial_shape}"
)
def _select_foreground_channels(
self, pred: torch.Tensor, target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, List[int]]:
channels = pred.shape[1]
if self.mode == "binary":
fg_idx = 0 if channels == 1 else self.foreground_channel
if fg_idx < 0 or fg_idx >= channels:
raise ValueError(
f"foreground_channel={self.foreground_channel} is invalid "
f"for {channels} channels"
)
return pred[:, fg_idx : fg_idx + 1], target[:, fg_idx : fg_idx + 1], [fg_idx]
if channels == 1:
return pred, target, [0]
background_index = self.background_index
if background_index < 0:
background_index += channels
if background_index < 0 or background_index >= channels:
raise ValueError(
f"background_index={self.background_index} is invalid for {channels} channels"
)
foreground_indices = [idx for idx in range(channels) if idx != background_index]
if not foreground_indices:
raise ValueError(
f"No foreground classes available: channels={channels}, "
f"background_index={self.background_index}"
)
index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long)
return (
torch.index_select(pred, dim=1, index=index_tensor),
torch.index_select(target, dim=1, index=index_tensor),
foreground_indices,
)
def _prepare_weight(
self,
weight: torch.Tensor,
pred: torch.Tensor,
foreground_indices: List[int],
num_fg_channels: int,
) -> torch.Tensor:
if weight.ndim == pred.ndim - 1:
weight = weight.unsqueeze(1)
if weight.ndim != pred.ndim:
raise ValueError(f"Weight ndim ({weight.ndim}) must match pred ndim ({pred.ndim})")
if weight.shape[0] != pred.shape[0] or weight.shape[2:] != pred.shape[2:]:
raise ValueError(
"Weight shape must match prediction shape except for channel dimension: "
f"weight={tuple(weight.shape)}, pred={tuple(pred.shape)}"
)
weight = weight.to(device=pred.device, dtype=pred.dtype)
if weight.shape[1] == num_fg_channels:
return weight
if weight.shape[1] == 1:
return weight.expand(weight.shape[0], num_fg_channels, *weight.shape[2:])
if weight.shape[1] == pred.shape[1]:
if num_fg_channels == pred.shape[1]:
return weight
index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long)
return torch.index_select(weight, dim=1, index=index_tensor)
raise ValueError(
"Weight channel count must be 1, foreground-channel count, "
"or prediction-channel count; "
f"got {weight.shape[1]}"
)
[docs] def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor | None = None,
) -> torch.Tensor:
if pred.ndim not in {4, 5}:
raise ValueError(f"SoftClDiceLoss expects 4D/5D tensors, got {tuple(pred.shape)}")
pred = self._apply_activation(pred)
target = self._prepare_target(target, pred)
if self.clamp_probabilities:
pred = pred.clamp(0.0, 1.0)
target = target.clamp(0.0, 1.0)
self._validate_spatial_shape(pred)
self._validate_probability_range(pred, "pred")
self._validate_probability_range(target, "target")
pred_fg, target_fg, foreground_indices = self._select_foreground_channels(pred, target)
if self.clamp_probabilities:
pred_fg = pred_fg.clamp(0.0, 1.0)
target_fg = target_fg.clamp(0.0, 1.0)
fg_weight = None
if weight is not None:
fg_weight = self._prepare_weight(weight, pred, foreground_indices, pred_fg.shape[1])
if self.clamp_probabilities:
fg_weight = fg_weight.clamp_min(0.0)
pred_skeleton = _soft_skeletonize_pool(pred_fg, self.num_iters)
target_skeleton = _soft_skeletonize_pool(target_fg, self.num_iters)
if fg_weight is not None:
pred_eval = pred_fg * fg_weight
target_eval = target_fg * fg_weight
pred_skeleton_eval = pred_skeleton * fg_weight
target_skeleton_eval = target_skeleton * fg_weight
else:
pred_eval = pred_fg
target_eval = target_fg
pred_skeleton_eval = pred_skeleton
target_skeleton_eval = target_skeleton
spatial_dims = tuple(range(2, pred_fg.ndim))
topology_precision = (
(pred_skeleton_eval * target_eval).sum(dim=spatial_dims) + self.smooth
) / (pred_skeleton_eval.sum(dim=spatial_dims) + self.smooth)
topology_sensitivity = (
(target_skeleton_eval * pred_eval).sum(dim=spatial_dims) + self.smooth
) / (target_skeleton_eval.sum(dim=spatial_dims) + self.smooth)
cl_dice = (
2.0
* topology_precision
* topology_sensitivity
/ (topology_precision + topology_sensitivity + self.smooth)
)
loss = 1.0 - cl_dice
if self.reduction == "none":
return loss
if self.reduction == "sum":
return loss.sum()
return loss.mean()
[docs]class WeightedMAELoss(nn.Module):
"""
Weighted mean absolute error loss.
Useful for regression tasks with spatial importance weighting.
Supports optional tanh activation for distance transform predictions.
Args:
reduction: Reduction method ('mean', 'sum', 'none')
tanh: If True, apply tanh activation to predictions before computing loss.
Useful for distance transform targets in range [-1, 1].
"""
def __init__(self, reduction: str = "mean", tanh: bool = False):
super().__init__()
self.reduction = reduction
self.tanh = tanh
[docs] def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor = None,
) -> torch.Tensor:
"""
Compute weighted MAE loss.
Args:
pred: Predictions (logits if tanh=True, otherwise predictions)
target: Ground truth
weight: Optional spatial weights
Returns:
Loss value
"""
# Apply tanh activation if enabled
if self.tanh:
pred = torch.tanh(pred)
mae = torch.abs(pred - target)
if weight is not None:
mae = mae * weight
return _reduce_weighted_tensor(mae, weight, self.reduction)
class SmoothL1Loss(nn.Module):
"""
Smooth L1 (Huber) loss with optional tanh activation and spatial weighting.
Useful for distance transform regression where large outliers should be
down-weighted relative to MSE.
"""
def __init__(self, beta: float = 1.0, reduction: str = "mean", tanh: bool = False):
super().__init__()
self.beta = beta
self.reduction = reduction
self.tanh = tanh
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor = None,
) -> torch.Tensor:
if self.tanh:
pred = torch.tanh(pred)
loss = F.smooth_l1_loss(pred, target, beta=self.beta, reduction="none")
if weight is not None:
loss = loss * weight
return _reduce_weighted_tensor(loss, weight, self.reduction)
[docs]class GANLoss(nn.Module):
"""
GAN loss for adversarial training.
Supports vanilla, LSGAN, and WGAN-GP objectives.
Based on https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
Args:
gan_mode: GAN objective type ('vanilla', 'lsgan', 'wgangp')
target_real_label: Label for real images
target_fake_label: Label for fake images
Note:
Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. Vanilla GANs will handle it with BCEWithLogitsLoss.
"""
def __init__(
self,
gan_mode: str = "lsgan",
target_real_label: float = 1.0,
target_fake_label: float = 0.0,
):
super().__init__()
self.register_buffer("real_label", torch.tensor(target_real_label))
self.register_buffer("fake_label", torch.tensor(target_fake_label))
self.gan_mode = gan_mode
if gan_mode == "lsgan":
self.loss = nn.MSELoss()
elif gan_mode == "vanilla":
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ["wgangp"]:
self.loss = None
else:
raise NotImplementedError(f"GAN mode {gan_mode} not implemented")
[docs] def get_target_tensor(
self,
prediction: torch.Tensor,
target_is_real: bool,
) -> torch.Tensor:
"""
Create label tensors with the same size as the input.
Args:
prediction: Discriminator prediction
target_is_real: Whether the ground truth is real or fake
Returns:
Label tensor filled with ground truth labels
"""
if target_is_real:
target_tensor = self.real_label
else:
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)
[docs] def forward(
self,
prediction: torch.Tensor,
target_is_real: bool,
) -> torch.Tensor:
"""
Calculate GAN loss.
Args:
prediction: Discriminator output
target_is_real: Whether ground truth labels are for real or fake images
Returns:
Calculated loss
"""
if self.gan_mode in ["lsgan", "vanilla"]:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
elif self.gan_mode == "wgangp":
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
return loss
__all__ = [
"CrossEntropyLossWrapper",
"GANLoss",
"PerChannelBCEWithLogitsLoss",
"SmoothL1Loss",
"SoftClDiceLoss",
"WeightedBCEWithLogitsLoss",
"WeightedMAELoss",
"WeightedMSELoss",
]