"""
Build MONAI transform pipelines from Hydra configuration.
Modern replacement for monai_compose.py that works with the new Hydra config system.
"""
from __future__ import annotations
from functools import partial
import torch
from monai.transforms import EnsureChannelFirstd # Ensure channel-first format for 2D/3D images
from monai.transforms import LoadImaged # For filename-based datasets (PNG, JPG, etc.)
from monai.transforms import (
BorderPadd,
CenterSpatialCropd,
Compose,
Lambdad,
OneOf,
RandAdjustContrastd,
RandAffined,
RandFlipd,
RandGaussianNoised,
RandRotate90d,
RandShiftIntensityd,
RandSpatialCropd,
Resized,
SpatialPadd,
ToTensord,
)
# Import custom loader for HDF5/TIFF volumes
from connectomics.data.io.transforms import LoadVolumed
from connectomics.data.processing.nnunet_preprocess import NNUNetPreprocessd
from ...config.schema import AugmentationConfig, Config
from .transforms import (
RandAxisPermuted,
RandCopyPasted,
RandCutBlurd,
RandCutNoised,
RandElasticd,
RandMisAlignmentd,
RandMissingPartsd,
RandMissingSectiond,
RandMixupd,
RandMotionBlurd,
RandMulAddIntensityd,
RandRotate90Alld,
RandSliceDropd,
RandSliceDropZd,
RandSliceShiftd,
RandSliceShiftZd,
RandStriped,
ResizeByFactord,
SmartNormalizeIntensityd,
)
def _strict_binarize_mask(mask, threshold: float = 0.0):
"""Binarize mask with strict greater-than semantics (mask > threshold)."""
if torch.is_tensor(mask):
return (mask > threshold).to(dtype=mask.dtype)
return (mask > threshold).astype(mask.dtype, copy=False)
def _target_context(cfg: Config) -> tuple[tuple[int, ...], tuple[int, ...]]:
"""Return ``(pre, post)`` per-axis voxel counts for target_context extension.
Length-3 configs ``[a, b, c]`` are interpreted as trailing-only extension
(``pre = (0,...,0)``, ``post = (a, b, c)``) for legacy compatibility.
Length-6 configs ``[pre0, pre1, pre2, post0, post1, post2]`` are taken
literally.
"""
spatial_ndim = len(cfg.data.dataloader.patch_size) if cfg.data.dataloader.patch_size else 0
raw = getattr(cfg.data.dataloader, "target_context", None) or []
values = [int(v) for v in raw]
if not values:
zero = tuple(0 for _ in range(spatial_ndim))
return zero, zero
if len(values) == spatial_ndim:
return tuple(0 for _ in range(spatial_ndim)), tuple(values)
if len(values) == 2 * spatial_ndim:
return tuple(values[:spatial_ndim]), tuple(values[spatial_ndim:])
raise ValueError(
"data.dataloader.target_context length must be 0, "
f"{spatial_ndim} (trailing-only), or {2 * spatial_ndim} (pre+post). "
f"Got {raw}."
)
def _effective_patch_size(cfg: Config) -> tuple[int, ...] | None:
    """Return the patch size extended by per-axis target_context, or None.

    The effective size is ``patch_size[i] + pre[i] + post[i]`` for every
    spatial axis; returns None when no patch size is configured.
    """
    raw_patch = cfg.data.dataloader.patch_size
    if not raw_patch:
        return None
    pre, post = _target_context(cfg)
    return tuple(int(size) + a + b for size, a, b in zip(raw_patch, pre, post))
def _append_banis_pre_target_transforms(transforms: list, label_cfg) -> None:
if bool(getattr(label_cfg, "relabel_connected_components", False)):
from ..processing.transforms import RelabelConnectedComponentsd
transforms.append(
RelabelConnectedComponentsd(
keys=["label"],
connectivity=int(getattr(label_cfg, "relabel_connectivity", 6)),
)
)
def _label_transform_emits_mask(label_cfg) -> bool:
"""Return True when label_transform.targets includes an affinity task.
Affinity target generation emits a paired ``f"{key}_mask"`` key carrying
the per-channel loss-mask; downstream pipeline steps (target_context crop,
ToTensor, collate) need to know to forward that key.
"""
if label_cfg is None:
return False
targets = getattr(label_cfg, "targets", None) or []
for task in targets:
if isinstance(task, dict):
name = task.get("name") or task.get("task") or task.get("type")
else:
name = (
getattr(task, "name", None)
or getattr(task, "task", None)
or getattr(task, "type", None)
)
if name == "affinity":
return True
return False
def _label_mask_keys(cfg: Config, keys: list[str]) -> list[str]:
    """Return the mask keys that ``MultiTaskLabelTransformd`` will emit."""
    label_cfg = getattr(cfg.data, "label_transform", None)
    if _label_transform_emits_mask(label_cfg):
        # Only the "label" key grows a paired "<key>_mask" companion.
        return [f"{key}_mask" for key in keys if key == "label"]
    return []
def _append_target_context_crop(transforms: list, cfg: Config) -> None:
    """Append a crop that trims target_context extension back to patch size.

    No-op when neither pre- nor post-context is positive on any axis.
    """
    pre, post = _target_context(cfg)
    if all(count <= 0 for count in pre + post):
        return
    from ..processing.transforms import LeadingSpatialCropd
    transforms.append(
        LeadingSpatialCropd(
            roi_size=tuple(cfg.data.dataloader.patch_size),
            roi_start=pre,
        )
    )
def _build_nnunet_preprocess_transform(keys, nnunet_pre_cfg, source_spacing):
    """Build NNUNetPreprocessd transform from config.

    Every optional knob is read off ``nnunet_pre_cfg`` with the same default
    used by the transform; ``source_spacing`` from the config takes precedence
    over the caller-provided fallback.
    """
    option_defaults = {
        "crop_to_nonzero": True,
        "target_spacing": None,
        "normalization": "zscore",
        "normalization_use_nonzero_mask": True,
        "clip_percentile_low": 0.0,
        "clip_percentile_high": 1.0,
        "force_separate_z": None,
        "anisotropy_threshold": 3.0,
        "image_order": 3,
        "label_order": 0,
        "order_z": 0,
    }
    options = {
        name: getattr(nnunet_pre_cfg, name, default)
        for name, default in option_defaults.items()
    }
    spacing = getattr(nnunet_pre_cfg, "source_spacing", None) or source_spacing
    return NNUNetPreprocessd(
        keys=keys,
        image_key="image",
        enabled=True,
        source_spacing=spacing,
        **options,
    )
def _build_eval_transforms_impl(
    cfg: Config, mode: str = "val", keys: list[str] | None = None, skip_loading: bool = False
) -> Compose:
    """
    Internal implementation for building evaluation transforms (validation or test).
    This function contains the shared logic between validation and test transforms,
    with mode-specific branching for key differences.
    Args:
        cfg: Hydra Config object
        mode: 'val', 'test', or 'tune' ('tune' evaluates on the validation split)
        keys: Keys to transform (default: auto-detected based on mode)
        skip_loading: Skip LoadVolumed (for pre-cached datasets)
    Returns:
        Composed MONAI transforms (no augmentation)
    """
    data_cfg = cfg.data
    def _resolve_eval_split():
        """Return the split config this mode evaluates on ('tune' shares val)."""
        if mode == "val":
            return data_cfg.val
        if mode == "tune":
            return data_cfg.val
        return data_cfg.test
    if keys is None:
        # Auto-detect keys based on mode
        if mode == "val":
            # Validation: default to image+label
            keys = ["image", "label"]
            if (getattr(getattr(data_cfg, "val", None), "label_aux", None) is not None) or (
                getattr(getattr(data_cfg, "train", None), "label_aux", None) is not None
            ):
                keys.append("label_aux")
            # Add mask if val_mask or train_mask exists
            if (
                hasattr(data_cfg, "val")
                and hasattr(data_cfg.val, "mask")
                and data_cfg.val.mask is not None
            ) or (
                hasattr(data_cfg, "train")
                and hasattr(data_cfg.train, "mask")
                and data_cfg.train.mask is not None
            ):
                keys.append("mask")
        else:  # mode == "test" or "tune"
            # Test/inference: default to image only; label/label_aux/mask are
            # forwarded only when the resolved split actually provides them.
            eval_split = _resolve_eval_split()
            keys = ["image"]
            if eval_split.label is not None:
                keys.append("label")
            if eval_split.label_aux is not None:
                keys.append("label_aux")
            if eval_split.mask is not None:
                keys.append("mask")
    transforms = []
    # Load images first - use appropriate loader based on dataset type
    # Skip loading if using pre-cached datasets
    if not skip_loading:
        dataset_type = (
            getattr(data_cfg.train, "dataset_type", None)
            or getattr(data_cfg.val, "dataset_type", None)
            or "volume"
        )
        if dataset_type == "filename":
            # For filename-based datasets (PNG, JPG, etc.), use MONAI's LoadImaged
            transforms.append(LoadImaged(keys=keys, image_only=False))
            # Ensure channel-first format [C, H, W] or [C, D, H, W]
            transforms.append(EnsureChannelFirstd(keys=keys))
        else:
            # For volume-based datasets (HDF5, TIFF volumes), use custom LoadVolumed
            transpose_axes = (
                data_cfg.data_transform.val_transpose
                if data_cfg.data_transform.val_transpose
                else []
            )
            transforms.append(
                LoadVolumed(keys=keys, transpose_axes=transpose_axes if transpose_axes else None)
            )
    nnunet_pre_cfg = getattr(data_cfg, "nnunet_preprocessing", None)
    # Source voxel spacing feeds nnU-Net-style resampling; prefer the split's
    # own resolution, falling back to the training resolution for val.
    if mode in {"test", "tune"}:
        source_spacing = _resolve_eval_split().resolution
    elif mode == "val":
        source_spacing = getattr(data_cfg.val, "resolution", None) or getattr(
            data_cfg.train, "resolution", None
        )
    else:
        source_spacing = getattr(data_cfg.train, "resolution", None)
    nnunet_pre_enabled = bool(getattr(nnunet_pre_cfg, "enabled", False))
    if not skip_loading and nnunet_pre_enabled:
        transforms.append(_build_nnunet_preprocess_transform(keys, nnunet_pre_cfg, source_spacing))
    # Apply volumetric split if enabled
    if data_cfg.split_enabled:
        from connectomics.data.datasets.split import ApplyVolumetricSplitd
        transforms.append(ApplyVolumetricSplitd(keys=keys))
    # Apply resize if configured (before cropping).
    # Validation uses target spatial size so patch-based val can mirror train-time
    # crop-then-upsample workflows. Test/tune interpret data_transform.resize as
    # the model-space patch size and derive full-volume scale factors from
    # resize / dataloader.patch_size.
    paired_resize_size = data_cfg.data_transform.resize if mode == "val" else None
    paired_resize_factors = None
    mask_cfg = getattr(data_cfg, "mask_transform", None) or data_cfg.data_transform
    if mode in {"test", "tune"} and data_cfg.data_transform.resize:
        patch_size_cfg = data_cfg.dataloader.patch_size
        if (
            patch_size_cfg
            and len(patch_size_cfg) == len(data_cfg.data_transform.resize)
            and all(float(v) > 0 for v in patch_size_cfg)
        ):
            paired_resize_factors = [
                float(out_size) / float(in_size)
                for out_size, in_size in zip(data_cfg.data_transform.resize, patch_size_cfg)
            ]
    if paired_resize_size:
        # Val path: resize image bilinearly, labels/masks with nearest to keep
        # discrete values intact.
        transforms.append(
            Resized(
                keys=["image"],
                spatial_size=paired_resize_size,
                mode="bilinear",
                align_corners=True,
            )
        )
        label_mask_keys = [k for k in keys if k in ["label", "mask"]]
        if label_mask_keys:
            transforms.append(
                Resized(
                    keys=label_mask_keys,
                    spatial_size=paired_resize_size,
                    mode="nearest",
                    align_corners=None,
                )
            )
    elif paired_resize_factors:
        # Test/tune path: scale the whole volume by the derived factors.
        transforms.append(
            ResizeByFactord(
                keys=["image"],
                scale_factors=paired_resize_factors,
                mode="bilinear",
                align_corners=True,
            )
        )
        label_mask_keys = [k for k in keys if k in ["label", "mask"]]
        if label_mask_keys:
            transforms.append(
                ResizeByFactord(
                    keys=label_mask_keys,
                    scale_factors=paired_resize_factors,
                    mode="nearest",
                    align_corners=None,
                )
            )
    else:
        # Fallback: independent per-key factor-based resizes from image/mask configs.
        image_resize_factors = getattr(data_cfg.image_transform, "resize", None)
        mask_resize_factors = None
        if mode in {"test", "tune"} and mask_cfg.resize is not None:
            mask_resize_factors = mask_cfg.resize
        if image_resize_factors is not None and image_resize_factors:
            transforms.append(
                ResizeByFactord(
                    keys=["image"],
                    scale_factors=image_resize_factors,
                    mode="bilinear",
                    align_corners=True,
                )
            )
            if "label" in keys:
                transforms.append(
                    ResizeByFactord(
                        keys=["label"],
                        scale_factors=image_resize_factors,
                        mode="nearest",
                        align_corners=None,
                    )
                )
            # By default, mask follows image resize unless mask_transform explicitly overrides it.
            if "mask" in keys and mask_resize_factors is None:
                transforms.append(
                    ResizeByFactord(
                        keys=["mask"],
                        scale_factors=image_resize_factors,
                        mode="nearest",
                        align_corners=None,
                    )
                )
        if mask_resize_factors is not None and mask_resize_factors and "mask" in keys:
            transforms.append(
                ResizeByFactord(
                    keys=["mask"],
                    scale_factors=mask_resize_factors,
                    mode="nearest",
                    align_corners=None,
                )
            )
    # Optional mask binarization for inference masks (e.g., enforce mask > 0).
    mask_binarize = False
    mask_threshold = 0.0
    if mode in {"test", "tune"}:
        mask_binarize = bool(getattr(mask_cfg, "binarize", False))
        mask_threshold = float(getattr(mask_cfg, "threshold", 0.0))
    if "mask" in keys and mask_binarize:
        transforms.append(
            Lambdad(
                keys=["mask"],
                func=partial(_strict_binarize_mask, threshold=mask_threshold),
            )
        )
    # Val pads/crops to the context-extended patch size; test/tune keep the raw
    # configured patch size (no cropping later, for sliding-window inference).
    patch_size = (
        _effective_patch_size(cfg)
        if mode == "val"
        else tuple(data_cfg.dataloader.patch_size) if data_cfg.dataloader.patch_size else None
    )
    if patch_size and all(size > 0 for size in patch_size):
        transforms.append(
            SpatialPadd(
                keys=keys,
                spatial_size=patch_size,
                constant_values=0.0,
            )
        )
    if mode in {"test", "tune"}:
        context_pad = getattr(data_cfg.data_transform, "pad_size", None)
        if context_pad and any(int(v) > 0 for v in context_pad):
            # Explicit test-time context padding is only for inference inputs.
            # Labels stay in the original FOV; masks get zero-padded to match the image.
            transforms.append(
                BorderPadd(
                    keys=["image"],
                    spatial_border=tuple(int(v) for v in context_pad),
                    mode=getattr(data_cfg.data_transform, "pad_mode", "reflect"),
                )
            )
            if "mask" in keys:
                # Keep mask context empty outside the source FOV.
                transforms.append(
                    BorderPadd(
                        keys=["mask"],
                        spatial_border=tuple(int(v) for v in context_pad),
                        mode="constant",
                        constant_values=0.0,
                    )
                )
    # Add spatial cropping - MODE-SPECIFIC
    # Validation: Apply center crop for patch-based validation
    # Test: Skip cropping to enable sliding window inference on full volumes
    if mode == "val":
        if patch_size and all(size > 0 for size in patch_size):
            if bool(getattr(data_cfg.dataloader, "val_random_sampling", False)):
                transforms.append(
                    RandSpatialCropd(
                        keys=keys,
                        roi_size=patch_size,
                        random_center=True,
                        random_size=False,
                    )
                )
            else:
                transforms.append(
                    CenterSpatialCropd(
                        keys=keys,
                        roi_size=patch_size,
                    )
                )
    # else: mode == "test" -> no cropping for sliding window inference
    # Normalization - use smart normalization
    # (skipped when nnU-Net preprocessing already normalized the image above)
    image_transform = data_cfg.image_transform
    if (not nnunet_pre_enabled) and image_transform.normalize != "none":
        transforms.append(
            SmartNormalizeIntensityd(
                keys=["image"],
                mode=image_transform.normalize,
                clip_percentile_low=getattr(image_transform, "clip_percentile_low", 0.0),
                clip_percentile_high=getattr(image_transform, "clip_percentile_high", 1.0),
            )
        )
    label_cfg = getattr(data_cfg, "label_transform", None)
    if mode == "val" and "label" in keys and label_cfg is not None:
        _append_banis_pre_target_transforms(transforms, label_cfg)
    # Only process labels if 'label' is in keys
    if "label" in keys:
        # Label transformations (affinity, distance transform, etc.)
        # For test/tune modes: NEVER apply label transforms
        # (keep raw instance labels for evaluation)
        # For val mode: use training label_transform config
        label_cfg = None  # deliberate reset: stays None for test/tune
        if mode == "val":
            # Validation always uses training label_transform
            if hasattr(data_cfg, "label_transform"):
                label_cfg = data_cfg.label_transform
        # Apply label transforms if configured
        if label_cfg is not None:
            from ..processing.build import create_label_transform_pipeline
            from ..processing.transforms import SegErosionInstanced
            # Apply instance erosion first if specified
            if hasattr(label_cfg, "erosion") and label_cfg.erosion > 0:
                transforms.append(SegErosionInstanced(keys=["label"], tsz_h=label_cfg.erosion))
            # Build label transform pipeline directly from label_transform config
            label_transform = create_label_transform_pipeline(label_cfg)
            if isinstance(label_transform, Compose):
                transforms.extend(label_transform.transforms)
            else:
                transforms.append(label_transform)
    if mode == "val":
        _append_target_context_crop(transforms, cfg)
    # NOTE: Do NOT squeeze labels here!
    # - DiceLoss needs (B, 1, H, W) with to_onehot_y=True
    # - CrossEntropyLoss needs (B, H, W)
    # Squeezing is handled in the loss wrapper instead
    # Final conversion to tensor with float32 dtype
    transforms.append(ToTensord(keys=keys, dtype=torch.float32))
    # Affinity tasks emit paired "<key>_mask" loss-masks; forward them as bool
    # tensors (allow_missing_keys since not every sample carries them).
    mask_keys = _label_mask_keys(cfg, keys)
    if mask_keys:
        transforms.append(ToTensord(keys=mask_keys, dtype=torch.bool, allow_missing_keys=True))
    return Compose(transforms)
def _build_augmentations(aug_cfg: AugmentationConfig, keys: list[str], do_2d: bool = False) -> list:
    """
    Build augmentation transforms from config.
    Transforms are appended in a fixed order: geometric -> sparse slice defects
    -> affine/elastic -> intensity -> EM defect family -> cut/stripe -> advanced.
    Args:
        aug_cfg: AugmentationConfig object
        keys: Keys to augment
        do_2d: Whether data is 2D (True) or 3D (False)
    Returns:
        List of MONAI transforms
    """
    transforms = []
    # Standard geometric augmentations
    # (axis permute and rotate90-all only make sense for 3D volumes)
    if aug_cfg.axis_permute.enabled and not do_2d:
        transforms.append(
            RandAxisPermuted(
                keys=keys,
                prob=aug_cfg.axis_permute.prob,
                include_identity=aug_cfg.axis_permute.include_identity,
            )
        )
    if aug_cfg.rotate90_all.enabled and not do_2d:
        transforms.append(
            RandRotate90Alld(
                keys=keys,
                prob=aug_cfg.rotate90_all.prob,
                include_identity=aug_cfg.rotate90_all.include_identity,
            )
        )
    if aug_cfg.flip.enabled:
        spatial_axis = aug_cfg.flip.spatial_axis
        if isinstance(spatial_axis, (list, tuple)):
            # One RandFlipd per axis so each is flipped independently
            for ax in spatial_axis:
                transforms.append(RandFlipd(keys=keys, prob=aug_cfg.flip.prob, spatial_axis=ax))
        else:
            transforms.append(
                RandFlipd(keys=keys, prob=aug_cfg.flip.prob, spatial_axis=spatial_axis)
            )
    if aug_cfg.rotate.enabled:
        spatial_axes_cfg = tuple(getattr(aug_cfg.rotate, "spatial_axes", ()) or ())
        if spatial_axes_cfg:
            spatial_axes = spatial_axes_cfg
        else:
            # Default to in-plane rotation: (0, 1) for 2D, (1, 2) for 3D (z kept fixed).
            spatial_axes = (0, 1) if do_2d else (1, 2)
        transforms.append(
            RandRotate90d(
                keys=keys,
                prob=aug_cfg.rotate.prob,
                spatial_axes=spatial_axes,
            )
        )
    # BANIS-style sparse slice defects are independent augmentations applied to
    # images only. They are separate from the legacy z-only defect family below.
    if aug_cfg.slice_drop.enabled:
        transforms.append(
            RandSliceDropd(
                keys=["image"],
                prob=aug_cfg.slice_drop.prob,
                slice_prob=aug_cfg.slice_drop.slice_prob,
                spatial_axis=aug_cfg.slice_drop.spatial_axis,
                fill_value=aug_cfg.slice_drop.fill_value,
                preserve_boundaries=aug_cfg.slice_drop.preserve_boundaries,
            )
        )
    if aug_cfg.slice_shift.enabled:
        transforms.append(
            RandSliceShiftd(
                keys=["image"],
                prob=aug_cfg.slice_shift.prob,
                slice_prob=aug_cfg.slice_shift.slice_prob,
                shift_magnitude=aug_cfg.slice_shift.shift_magnitude,
                spatial_axis=aug_cfg.slice_shift.spatial_axis,
                wrap=aug_cfg.slice_shift.wrap,
            )
        )
    if aug_cfg.affine.enabled:
        # Adjust affine parameters for 2D vs 3D data
        # For 2D: use only the first element of each range
        # For 3D: use all three elements
        if do_2d:
            rotate_range = (aug_cfg.affine.rotate_range[0],)
            scale_range = (aug_cfg.affine.scale_range[0],)
            shear_range = (aug_cfg.affine.shear_range[0],)
        else:
            rotate_range = aug_cfg.affine.rotate_range
            scale_range = aug_cfg.affine.scale_range
            shear_range = aug_cfg.affine.shear_range
        # Interpolation per key: bilinear for images, nearest for labels/masks
        affine_modes = ["bilinear" if k == "image" else "nearest" for k in keys]
        transforms.append(
            RandAffined(
                keys=keys,
                prob=aug_cfg.affine.prob,
                rotate_range=rotate_range,
                scale_range=scale_range,
                shear_range=shear_range,
                mode=affine_modes,
                padding_mode="reflection",
            )
        )
    if aug_cfg.elastic.enabled:
        # Unified elastic deformation that supports both 2D and 3D
        elastic_modes = ["bilinear" if k == "image" else "nearest" for k in keys]
        transforms.append(
            RandElasticd(
                keys=keys,
                do_2d=do_2d,
                prob=aug_cfg.elastic.prob,
                sigma_range=aug_cfg.elastic.sigma_range,
                magnitude_range=aug_cfg.elastic.magnitude_range,
                mode=elastic_modes,
            )
        )
    # Intensity augmentations (only for images)
    if aug_cfg.intensity.enabled:
        if getattr(aug_cfg.intensity, "banis_style", False):
            # BANIS style: multiply/add jitter plus optional gaussian noise;
            # shift/contrast from the branch below are intentionally skipped.
            transforms.append(
                RandMulAddIntensityd(
                    keys=["image"],
                    prob=aug_cfg.intensity.mul_add_prob,
                    mul_range=aug_cfg.intensity.mul_range,
                    add_range=aug_cfg.intensity.add_range,
                )
            )
            if aug_cfg.intensity.gaussian_noise_prob > 0:
                transforms.append(
                    RandGaussianNoised(
                        keys=["image"],
                        prob=aug_cfg.intensity.gaussian_noise_prob,
                        std=aug_cfg.intensity.gaussian_noise_std,
                        sample_std=True,
                    )
                )
        else:
            if aug_cfg.intensity.gaussian_noise_prob > 0:
                transforms.append(
                    RandGaussianNoised(
                        keys=["image"],
                        prob=aug_cfg.intensity.gaussian_noise_prob,
                        std=aug_cfg.intensity.gaussian_noise_std,
                        sample_std=True,
                    )
                )
            if aug_cfg.intensity.shift_intensity_prob > 0:
                transforms.append(
                    RandShiftIntensityd(
                        keys=["image"],
                        prob=aug_cfg.intensity.shift_intensity_prob,
                        offsets=aug_cfg.intensity.shift_intensity_offset,
                    )
                )
            if aug_cfg.intensity.contrast_prob > 0:
                transforms.append(
                    RandAdjustContrastd(
                        keys=["image"],
                        prob=aug_cfg.intensity.contrast_prob,
                        gamma=aug_cfg.intensity.contrast_range,
                    )
                )
    # EM-specific defect augmentations.
    # When defect_mutex is True, at most one defect fires per sample.
    defect_transforms = []
    if aug_cfg.slice_shift_z.enabled:
        defect_transforms.append(
            RandSliceShiftZd(
                keys=["image"],
                prob=aug_cfg.slice_shift_z.prob,
                displacement=aug_cfg.slice_shift_z.displacement,
                rotate_ratio=aug_cfg.slice_shift_z.rotate_ratio,
            )
        )
    if aug_cfg.slice_drop_z.enabled:
        defect_transforms.append(
            RandSliceDropZd(
                keys=["image"],
                prob=aug_cfg.slice_drop_z.prob,
                num_sections=aug_cfg.slice_drop_z.num_sections,
                full_section_prob=aug_cfg.slice_drop_z.full_section_prob,
                partial_ratio_range=aug_cfg.slice_drop_z.partial_ratio_range,
                fill_value_range=aug_cfg.slice_drop_z.fill_value_range,
            )
        )
    if aug_cfg.misalignment.enabled:
        defect_transforms.append(
            RandMisAlignmentd(
                keys=["image"],
                prob=aug_cfg.misalignment.prob,
                displacement=aug_cfg.misalignment.displacement,
                rotate_ratio=aug_cfg.misalignment.rotate_ratio,
            )
        )
    if aug_cfg.missing_section.enabled:
        defect_transforms.append(
            RandMissingSectiond(
                keys=["image"],
                prob=aug_cfg.missing_section.prob,
                num_sections=aug_cfg.missing_section.num_sections,
                full_section_prob=aug_cfg.missing_section.full_section_prob,
                partial_ratio_range=aug_cfg.missing_section.partial_ratio_range,
                fill_value_range=aug_cfg.missing_section.fill_value_range,
            )
        )
    if aug_cfg.motion_blur.enabled:
        defect_transforms.append(
            RandMotionBlurd(
                keys=["image"],
                prob=aug_cfg.motion_blur.prob,
                sections=aug_cfg.motion_blur.sections,
                kernel_size=aug_cfg.motion_blur.kernel_size,
                sigma_range=aug_cfg.motion_blur.sigma_range,
                full_section_prob=aug_cfg.motion_blur.full_section_prob,
                partial_ratio_range=aug_cfg.motion_blur.partial_ratio_range,
            )
        )
    if aug_cfg.missing_parts.enabled:
        defect_transforms.append(
            RandMissingPartsd(
                keys=["image"],
                prob=aug_cfg.missing_parts.prob,
                hole_range=aug_cfg.missing_parts.hole_range,
            )
        )
    if defect_transforms:
        if getattr(aug_cfg, "defect_mutex", False) and len(defect_transforms) > 1:
            # Mutual exclusion: randomly pick one defect per sample.
            transforms.append(OneOf(transforms=defect_transforms))
        else:
            # Independent defects: each fires on its own probability.
            transforms.extend(defect_transforms)
    if aug_cfg.cut_noise.enabled:
        transforms.append(
            RandCutNoised(
                keys=["image"],
                prob=aug_cfg.cut_noise.prob,
                length_ratio=aug_cfg.cut_noise.length_ratio,
                noise_scale=aug_cfg.cut_noise.noise_scale,
            )
        )
    if aug_cfg.cut_blur.enabled:
        transforms.append(
            RandCutBlurd(
                keys=["image"],
                prob=aug_cfg.cut_blur.prob,
                length_ratio=aug_cfg.cut_blur.length_ratio,
                down_ratio_range=aug_cfg.cut_blur.down_ratio_range,
                downsample_z=aug_cfg.cut_blur.downsample_z,
            )
        )
    if aug_cfg.stripe.enabled:
        transforms.append(
            RandStriped(
                keys=["image"],
                prob=aug_cfg.stripe.prob,
                num_stripes_range=aug_cfg.stripe.num_stripes_range,
                thickness_range=aug_cfg.stripe.thickness_range,
                intensity_range=aug_cfg.stripe.intensity_range,
                angle_range=aug_cfg.stripe.angle_range,
                orientation=aug_cfg.stripe.orientation,
                mode=aug_cfg.stripe.mode,
            )
        )
    # Advanced augmentations
    if aug_cfg.mixup.enabled:
        transforms.append(
            RandMixupd(
                keys=["image"], prob=aug_cfg.mixup.prob, alpha_range=aug_cfg.mixup.alpha_range
            )
        )
    if aug_cfg.copy_paste.enabled:
        transforms.append(
            RandCopyPasted(
                keys=["image"],
                label_key="label",
                prob=aug_cfg.copy_paste.prob,
                max_obj_ratio=aug_cfg.copy_paste.max_obj_ratio,
                rotation_angles=aug_cfg.copy_paste.rotation_angles,
                border=aug_cfg.copy_paste.border,
            )
        )
    return transforms
# Public API of this module.
# NOTE(review): these builder functions are not defined in this chunk — they are
# presumably defined later in the file (wrapping _build_eval_transforms_impl /
# _build_augmentations); confirm they exist.
__all__ = [
    "build_train_transforms",
    "build_val_transforms",
    "build_test_transforms",
]