# Source code for connectomics.data.augmentation.transforms
# (Sphinx navigation residue removed; module docstring follows.)

"""
MONAI-native transforms for connectomics-specific augmentations.

Each transform is a thin MONAI wrapper that delegates business logic
to pure functions in augment_ops.py.
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torchvision.transforms.functional as tvf
from monai.config import KeysCollection
from monai.transforms import MapTransform, RandomizableTransform
from scipy.ndimage import binary_dilation, generate_binary_structure

from . import augment_ops
from .transform_utils import from_numpy as _from_numpy
from .transform_utils import has_channel_axis as _has_channel_axis
from .transform_utils import infer_depth_axis as _infer_depth_axis
from .transform_utils import infer_spatial_rank as _infer_spatial_rank
from .transform_utils import sample_count as _sample_count
from .transform_utils import sample_non_identity_permutation as _sample_non_identity_permutation
from .transform_utils import sample_non_identity_rotate_ks as _sample_non_identity_rotate_ks
from .transform_utils import sample_spatial_axis as _sample_spatial_axis
from .transform_utils import spatial_axis_to_array_axis as _spatial_axis_to_array_axis
from .transform_utils import to_numpy as _to_numpy


def _as_range(value: Union[float, Tuple[float, float], List[float]]) -> Tuple[float, float]:
    """Normalize a scalar or 2-tuple/list into a ``(low, high)`` pair for uniform sampling."""
    if isinstance(value, (tuple, list)):
        if len(value) != 2:
            raise ValueError(f"Expected a (low, high) pair, got {value!r}")
        low, high = float(value[0]), float(value[1])
    else:
        low = high = float(value)
    if low > high:
        low, high = high, low
    return low, high


class RandAxisPermuted(RandomizableTransform, MapTransform):
    """Randomly permute the three spatial axes of a cubic 3D volume.

    One permutation is sampled per call (lazily, on the first eligible key)
    and shared by every key so paired image/label volumes stay aligned.
    Non-3D or non-cubic entries are left untouched.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 1.0,
        include_identity: bool = True,
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        # When False, the identity permutation is excluded from sampling.
        self.include_identity = include_identity

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Disabled transform: bail out without consuming RNG state.
        if self.prob <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        # Sampled on the first eligible key, then reused for the rest.
        permutation: Optional[np.ndarray] = None
        for key in self.key_iterator(d):
            if key not in d:
                continue
            arr, was_tensor, device = _to_numpy(d[key])
            # Only cubic 3D volumes can be permuted without changing shape.
            if _infer_spatial_rank(arr) != 3 or len(set(arr.shape[-3:])) != 1:
                continue
            if permutation is None:
                if self.include_identity:
                    permutation = self.R.permutation(3).astype(np.int64)
                else:
                    permutation = _sample_non_identity_permutation(self.R, 3)
            result = augment_ops.permute_spatial_axes(
                arr,
                permutation,
                has_channel_axis=_has_channel_axis(arr),
            )
            d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class RandRotate90Alld(RandomizableTransform, MapTransform):
    """Apply random quarter-turn rotations over all three 3D plane pairs.

    One triple of quarter-turn counts is sampled per call (lazily, on the
    first eligible key) and shared by every key so paired volumes stay
    aligned. Non-3D or non-cubic entries are left untouched.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 1.0,
        include_identity: bool = True,
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        # When False, the (0, 0, 0) identity rotation is excluded.
        self.include_identity = include_identity

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Disabled transform: bail out without consuming RNG state.
        if self.prob <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        # Sampled on the first eligible key, then reused for the rest.
        rotate_ks: Optional[Tuple[int, int, int]] = None
        for key in self.key_iterator(d):
            if key not in d:
                continue
            arr, was_tensor, device = _to_numpy(d[key])
            # Quarter turns over all plane pairs require a cubic 3D volume.
            if _infer_spatial_rank(arr) != 3 or len(set(arr.shape[-3:])) != 1:
                continue
            if rotate_ks is None:
                if self.include_identity:
                    rotate_ks = tuple(int(self.R.randint(0, 4)) for _ in range(3))
                else:
                    rotate_ks = _sample_non_identity_rotate_ks(self.R)
            result = augment_ops.apply_rotate90_all(arr, rotate_ks)
            d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class RandSliceDropd(RandomizableTransform, MapTransform):
    """BANIS-style per-slice dropping along one sampled spatial axis.

    Each slice along the chosen axis is independently selected with
    probability ``slice_prob`` and filled with ``fill_value``. The axis and
    the selected slice indices are sampled once (on the first eligible key)
    and shared across keys so paired volumes stay aligned.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.5,
        slice_prob: float = 0.05,
        spatial_axis: Union[int, str, Tuple[int, ...], List[int]] = (0, 1, 2),
        fill_value: float = 0.0,
        preserve_boundaries: bool = False,
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.slice_prob = slice_prob  # per-slice drop probability
        self.spatial_axis = spatial_axis  # candidate axes to sample from
        self.fill_value = fill_value
        self.preserve_boundaries = preserve_boundaries  # never drop first/last slice

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Disabled transform: bail out without consuming RNG state.
        if self.prob <= 0 or self.slice_prob <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        # Sampled on the first eligible key, then reused for the rest.
        chosen_spatial_axis: Optional[int] = None
        selected_indices: Optional[np.ndarray] = None
        for key in self.key_iterator(d):
            if key not in d:
                continue
            arr, was_tensor, device = _to_numpy(d[key])
            spatial_rank = _infer_spatial_rank(arr)
            if spatial_rank != 3:
                continue
            if chosen_spatial_axis is None:
                chosen_spatial_axis = _sample_spatial_axis(self.R, self.spatial_axis, spatial_rank)
            array_axis = _spatial_axis_to_array_axis(arr, chosen_spatial_axis)
            if selected_indices is None:
                depth = arr.shape[array_axis]
                candidates = np.arange(depth, dtype=np.int64)
                if self.preserve_boundaries and depth > 2:
                    candidates = candidates[1:-1]
                if candidates.size == 0:
                    continue
                selected = self.R.rand(candidates.size) < self.slice_prob
                selected_indices = candidates[selected]
            # Nothing selected (possible even when the transform fired).
            if selected_indices is None or selected_indices.size == 0:
                continue
            result = augment_ops.fill_sections(
                arr,
                selected_indices,
                fill_value=float(self.fill_value),
                depth_axis=array_axis,
            )
            d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class RandSliceShiftd(RandomizableTransform, MapTransform):
    """BANIS-style independent per-slice in-plane shifts along one sampled axis.

    Selected slices are shifted by per-slice random (dy, dx) offsets in
    ``[-shift_magnitude, shift_magnitude]``; the ``wrap`` behaviour is
    delegated to ``augment_ops.apply_slice_roll_shifts``. Axis, slice
    indices, and shifts are sampled once (on the first eligible key) and
    shared across keys.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.5,
        slice_prob: float = 0.05,
        shift_magnitude: int = 10,
        spatial_axis: Union[int, str, Tuple[int, ...], List[int]] = (0, 1, 2),
        wrap: bool = True,
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.slice_prob = slice_prob  # per-slice selection probability
        self.shift_magnitude = shift_magnitude  # max |shift| per in-plane axis
        self.spatial_axis = spatial_axis
        self.wrap = wrap

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Disabled transform: bail out without consuming RNG state.
        if self.prob <= 0 or self.slice_prob <= 0 or self.shift_magnitude <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        # Sampled on the first eligible key, then reused for the rest.
        chosen_spatial_axis: Optional[int] = None
        selected_indices: Optional[np.ndarray] = None
        selected_shifts: Optional[List[Tuple[int, int]]] = None
        for key in self.key_iterator(d):
            if key not in d:
                continue
            arr, was_tensor, device = _to_numpy(d[key])
            spatial_rank = _infer_spatial_rank(arr)
            if spatial_rank != 3:
                continue
            if chosen_spatial_axis is None:
                chosen_spatial_axis = _sample_spatial_axis(self.R, self.spatial_axis, spatial_rank)
            array_axis = _spatial_axis_to_array_axis(arr, chosen_spatial_axis)
            if selected_indices is None:
                depth = arr.shape[array_axis]
                candidates = np.arange(depth, dtype=np.int64)
                if candidates.size == 0:
                    continue
                selected = self.R.rand(candidates.size) < self.slice_prob
                selected_indices = candidates[selected]
                # One (dy, dx) pair per selected slice.
                selected_shifts = [
                    (
                        int(self.R.randint(-self.shift_magnitude, self.shift_magnitude + 1)),
                        int(self.R.randint(-self.shift_magnitude, self.shift_magnitude + 1)),
                    )
                    for _ in range(selected_indices.size)
                ]
            if selected_indices is None or selected_indices.size == 0 or selected_shifts is None:
                continue
            result = augment_ops.apply_slice_roll_shifts(
                arr,
                slice_axis=array_axis,
                indices=selected_indices,
                shifts=selected_shifts,
                wrap=self.wrap,
            )
            d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class RandMulAddIntensityd(RandomizableTransform, MapTransform): """BANIS-style joint multiplicative and additive intensity jitter.""" def __init__( self, keys: KeysCollection, prob: float = 0.5, mul_range: Tuple[float, float] = (0.9, 1.1), add_range: Tuple[float, float] = (-0.1, 0.1), allow_missing_keys: bool = False, ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) self.mul_range = tuple(float(v) for v in mul_range) self.add_range = tuple(float(v) for v in add_range) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: if self.prob <= 0: self._do_transform = False return data d = dict(data) self.randomize(None) if not self._do_transform: return d mul_low, mul_high = self.mul_range add_low, add_high = self.add_range if mul_high < mul_low: mul_low, mul_high = mul_high, mul_low if add_high < add_low: add_low, add_high = add_high, add_low mul = float(self.R.uniform(mul_low, mul_high)) add = float(self.R.uniform(add_low, add_high)) for key in self.key_iterator(d): if key not in d: continue arr, was_tensor, device = _to_numpy(d[key]) result = arr.astype(np.float32, copy=False) * mul + add d[key] = _from_numpy(result, was_tensor, device) return d def randomize(self, _: Any = None) -> None: self._do_transform = self.R.rand() < self.prob
class RandMisAlignmentd(RandomizableTransform, MapTransform):
    """Random misalignment augmentation simulating EM section alignment artifacts.

    Splits the stack at a random interior section and displaces one side
    ("translation") or a single slice ("slip"), optionally via rotation.

    NOTE(review): split index, mode, and displacements are re-sampled per
    key, so image and label keys receive independent misalignments — confirm
    whether shared parameters across keys are intended.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.1,
        displacement: int = 16,
        rotate_ratio: float = 0.0,
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.displacement = displacement  # max pixel displacement
        self.rotate_ratio = rotate_ratio  # chance of rotation vs. translation

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Disabled transform: bail out without consuming RNG state.
        if self.prob <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        use_rotation = self.R.rand() < self.rotate_ratio
        for key in self.key_iterator(d):
            if key in d:
                arr, was_tensor, device = _to_numpy(d[key])
                if arr.ndim < 3:
                    continue
                depth_axis = _infer_depth_axis(arr)
                depth = arr.shape[depth_axis]
                # Need at least one interior section to split at.
                if depth <= 2:
                    continue
                split_idx = int(self.R.choice(np.arange(1, depth - 1)))
                mode = "slip" if self.R.rand() < 0.5 else "translation"
                if use_rotation:
                    height = arr.shape[-2]
                    # Angle bound derived so rotation stays within `displacement`.
                    angle_range = augment_ops.compute_misalignment_angle_range(
                        self.displacement, height
                    )
                    rand_angle = (self.R.rand() - 0.5) * 2.0 * angle_range
                    result = augment_ops.apply_misalignment_rotation(
                        arr,
                        self.displacement,
                        rand_angle,
                        split_idx,
                        mode,
                        depth_axis=depth_axis,
                    )
                else:
                    # Independent (dy, dx) for each side of the split.
                    dy0 = int(self.R.randint(-self.displacement, self.displacement + 1))
                    dx0 = int(self.R.randint(-self.displacement, self.displacement + 1))
                    dy1 = int(self.R.randint(-self.displacement, self.displacement + 1))
                    dx1 = int(self.R.randint(-self.displacement, self.displacement + 1))
                    result = augment_ops.apply_misalignment_translation(
                        arr,
                        self.displacement,
                        dy0,
                        dx0,
                        dy1,
                        dx1,
                        split_idx,
                        mode,
                        depth_axis=depth_axis,
                    )
                d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class RandMissingSectiond(RandomizableTransform, MapTransform):
    """Random missing section augmentation with paper-style fill values.

    Fills a random number of interior sections either fully or over a random
    rectangular sub-region, each with its own uniform fill value.

    NOTE(review): sections and regions are re-sampled per key, so image and
    label keys receive independent corruptions — confirm this is applied to
    image keys only.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.1,
        num_sections: Union[int, Tuple[int, int]] = 2,
        full_section_prob: float = 0.5,
        partial_ratio_range: Tuple[float, float] = (0.25, 0.75),
        fill_value_range: Tuple[float, float] = (0.0, 1.0),
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.num_sections = num_sections  # count or (low, high) range
        self.full_section_prob = full_section_prob  # full vs. partial fill
        self.partial_ratio_range = partial_ratio_range  # hole size fraction
        self.fill_value_range = fill_value_range

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Disabled transform: bail out without consuming RNG state.
        if self.prob <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        for key in self.key_iterator(d):
            if key in d:
                arr, was_tensor, device = _to_numpy(d[key])
                if arr.ndim < 3:
                    continue
                depth_axis = _infer_depth_axis(arr)
                depth = arr.shape[depth_axis]
                if depth <= 3:
                    continue
                # Only interior sections are eligible (never first/last).
                num_to_fill = _sample_count(self.R, self.num_sections, depth - 2)
                if num_to_fill == 0:
                    continue
                indices = self.R.choice(np.arange(1, depth - 1), size=num_to_fill, replace=False)
                result = arr
                for idx in indices:
                    fill_value = float(self.R.uniform(*self.fill_value_range))
                    if self.R.rand() < self.full_section_prob:
                        result = augment_ops.fill_sections(
                            result,
                            np.asarray([idx]),
                            fill_value=fill_value,
                            depth_axis=depth_axis,
                        )
                        continue
                    # Partial fill: random rectangle sized by partial_ratio_range.
                    hole_h = max(
                        1,
                        int(arr.shape[-2] * self.R.uniform(*self.partial_ratio_range)),
                    )
                    hole_w = max(
                        1,
                        int(arr.shape[-1] * self.R.uniform(*self.partial_ratio_range)),
                    )
                    y_start = int(self.R.randint(0, max(1, arr.shape[-2] - hole_h + 1)))
                    x_start = int(self.R.randint(0, max(1, arr.shape[-1] - hole_w + 1)))
                    result = augment_ops.fill_region(
                        result,
                        y_start,
                        x_start,
                        hole_h,
                        hole_w,
                        section_axis=depth_axis,
                        section_idx=int(idx),
                        fill_value=fill_value,
                    )
                d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
[docs]class RandSliceShiftZd(RandMisAlignmentd): """Clearer alias for the legacy z-only misalignment augmentation."""
[docs]class RandSliceDropZd(RandMissingSectiond): """Clearer alias for the legacy z-only missing-section augmentation."""
class RandMissingPartsd(RandomizableTransform, MapTransform):
    """Random missing parts — creates rectangular holes in sections."""

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.1,
        hole_range: Tuple[float, float] = (0.1, 0.3),
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.hole_range = hole_range  # hole size as a fraction of H/W

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Disabled transform: bail out without consuming RNG state.
        if self.prob <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        for key in self.key_iterator(d):
            if key in d:
                arr, was_tensor, device = _to_numpy(d[key])
                if arr.ndim < 2:
                    continue
                # Choose which axis (if any) indexes sections. section_axis
                # stays None for 2D and 2D channel-first inputs; the exact
                # hole semantics are delegated to augment_ops.create_missing_hole.
                section_axis: Optional[int] = None
                section_idx: Optional[int] = None
                if arr.ndim == 2:
                    pass
                elif arr.ndim == 3 and arr.shape[0] <= 4:
                    pass  # 2D channel-first
                elif arr.ndim >= 4 and arr.shape[0] <= 4:
                    section_axis = 1  # (C, Z, Y, X): sections along axis 1
                else:
                    section_axis = 0  # (Z, Y, X): sections along axis 0
                hole_ratio = self.R.uniform(*self.hole_range)
                hole_h = max(1, int(arr.shape[-2] * hole_ratio))
                hole_w = max(1, int(arr.shape[-1] * hole_ratio))
                y_start = int(self.R.randint(0, max(1, arr.shape[-2] - hole_h + 1)))
                x_start = int(self.R.randint(0, max(1, arr.shape[-1] - hole_w + 1)))
                if section_axis is not None:
                    section_idx = int(self.R.randint(0, arr.shape[section_axis]))
                result = augment_ops.create_missing_hole(
                    arr, y_start, x_start, hole_h, hole_w, section_axis, section_idx
                )
                d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class RandMotionBlurd(RandomizableTransform, MapTransform):
    """Legacy name for paper-style out-of-focus Gaussian blur augmentation.

    Blurs a random subset of sections, either fully or over a random
    rectangular sub-region, via ``augment_ops.blur_sections`` /
    ``augment_ops.blur_region``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.1,
        sections: Union[int, Tuple[int, int]] = 2,
        kernel_size: int = 11,
        sigma_range: Tuple[float, float] = (1.0, 3.0),
        full_section_prob: float = 0.5,
        partial_ratio_range: Tuple[float, float] = (0.25, 0.75),
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.sections = sections  # count or (low, high) range
        self.kernel_size = kernel_size
        self.sigma_range = sigma_range
        self.full_section_prob = full_section_prob
        self.partial_ratio_range = partial_ratio_range

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Consistency fix: like every other transform in this module, skip a
        # disabled transform early without consuming RNG state (previously the
        # randomize() draw happened even when prob <= 0).
        if self.prob <= 0:
            self._do_transform = False
            return data
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        for key in self.key_iterator(d):
            if key in d:
                arr, was_tensor, device = _to_numpy(d[key])
                if arr.ndim < 3:
                    continue
                depth_axis = _infer_depth_axis(arr)
                depth = arr.shape[depth_axis]
                num_sections = _sample_count(self.R, self.sections, depth)
                if num_sections == 0:
                    continue
                section_indices = self.R.choice(depth, size=num_sections, replace=False)
                result = arr
                for idx in section_indices:
                    sigma = float(self.R.uniform(*self.sigma_range))
                    if self.R.rand() < self.full_section_prob:
                        result = augment_ops.blur_sections(
                            result,
                            np.asarray([idx]),
                            kernel_size=self.kernel_size,
                            sigma=sigma,
                            depth_axis=depth_axis,
                        )
                        continue
                    # Partial blur: random rectangle sized by partial_ratio_range.
                    hole_h = max(
                        1,
                        int(arr.shape[-2] * self.R.uniform(*self.partial_ratio_range)),
                    )
                    hole_w = max(
                        1,
                        int(arr.shape[-1] * self.R.uniform(*self.partial_ratio_range)),
                    )
                    y_start = int(self.R.randint(0, max(1, arr.shape[-2] - hole_h + 1)))
                    x_start = int(self.R.randint(0, max(1, arr.shape[-1] - hole_w + 1)))
                    result = augment_ops.blur_region(
                        result,
                        section_idx=int(idx),
                        y_start=y_start,
                        x_start=x_start,
                        hole_h=hole_h,
                        hole_w=hole_w,
                        kernel_size=self.kernel_size,
                        sigma=sigma,
                        depth_axis=depth_axis,
                    )
                d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class RandCutNoised(RandomizableTransform, MapTransform):
    """Random cut noise — adds noise to random cuboid regions.

    The length ratio and noise scale are sampled once per call; the cuboid
    position and noise values are re-sampled for each key.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.1,
        length_ratio: Union[float, Tuple[float, float]] = (0.1, 0.4),
        noise_scale: Union[float, Tuple[float, float]] = (0.05, 0.15),
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        # Normalized to (low, high) pairs; scalars become degenerate ranges.
        self.length_ratio = _as_range(length_ratio)
        self.noise_scale = _as_range(noise_scale)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        length_ratio = float(self.R.uniform(*self.length_ratio))
        noise_scale = float(self.R.uniform(*self.noise_scale))
        for key in self.key_iterator(d):
            if key in d:
                arr, was_tensor, device = _to_numpy(d[key])
                if arr.ndim < 2 or any(s == 0 for s in arr.shape):
                    continue
                # Build cuboid slices (channel dim + spatial dims); axis 0 is
                # passed through untouched, spatial axes get a random window.
                spatial_shape = arr.shape[1:]
                slices = [slice(None)]
                noise_shape = [arr.shape[0]]
                for s in spatial_shape:
                    length = max(1, int(length_ratio * s))
                    start = int(self.R.randint(0, max(1, s - length + 1)))
                    slices.append(slice(start, start + length))
                    noise_shape.append(length)
                noise = self.R.uniform(-noise_scale, noise_scale, noise_shape)
                result = augment_ops.apply_cut_noise(arr, slices, noise)
                d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class RandCutBlurd(RandomizableTransform, MapTransform):
    """Random CutBlur — downsample+upsample cuboid regions for super-resolution learning.

    The cuboid bounds and down-ratio are sampled once from the first
    available key and shared by all keys.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.5,
        length_ratio: Union[float, Tuple[float, float]] = (0.1, 0.4),
        down_ratio_range: Tuple[float, float] = (2.0, 8.0),
        downsample_z: bool = False,
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.length_ratio = _as_range(length_ratio)
        self.down_ratio_range = down_ratio_range
        self.downsample_z = downsample_z  # also degrade the z axis if True

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        # Get random params once from first available key
        random_params = None
        for key in self.key_iterator(d):
            if key in d:
                random_params = self._get_random_params(d[key])
                break
        if random_params is None:
            return d
        for key in self.key_iterator(d):
            if key in d:
                arr, was_tensor, device = _to_numpy(d[key])
                zl, zh, yl, yh, xl, xh, down_ratio = random_params
                result = augment_ops.apply_cutblur(
                    arr, zl, zh, yl, yh, xl, xh, down_ratio, self.downsample_z
                )
                d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob

    def _get_random_params(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple:
        """Sample cuboid bounds and down-ratio from this image's shape.

        NOTE(review): assumes an unbatched, channel-free layout — 3D (Z, Y, X)
        or 2D (Y, X) — confirm against the pipeline's data layout.
        """
        shape = img.shape
        zdim = shape[0] if len(shape) == 3 else 1
        length_ratio = float(self.R.uniform(*self.length_ratio))
        if zdim > 1:
            zl, zh = self._random_region(shape[0], length_ratio)
        else:
            # No z extent: cutblur is applied in-plane only.
            zl, zh = None, None
        yl, yh = self._random_region(
            shape[1] if len(shape) == 3 else shape[0], length_ratio
        )
        xl, xh = self._random_region(
            shape[2] if len(shape) == 3 else shape[1], length_ratio
        )
        down_ratio = self.R.uniform(*self.down_ratio_range)
        return zl, zh, yl, yh, xl, xh, down_ratio

    def _random_region(self, vol_len: int, length_ratio: float) -> Tuple[int, int]:
        """Return (low, high) bounds of a random sub-interval of ``vol_len``."""
        cuboid_len = max(1, int(length_ratio * vol_len))
        low = int(self.R.randint(0, max(1, vol_len - cuboid_len + 1)))
        return low, low + cuboid_len
class RandMixupd(RandomizableTransform, MapTransform):
    """Random Mixup — linear interpolation between batch samples.

    Warning:
        This transform requires a batch dimension (ndim >= 4) and at least 2
        samples along that dimension. In standard per-sample MONAI pipelines
        (where each dict is one sample with ndim=3), this is a no-op. For true
        cross-sample mixup, use a collate-level or batch-level transform
        instead.
    """

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.5,
        alpha_range: Tuple[float, float] = (0.7, 0.9),
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.alpha_range = alpha_range  # interpolation weight bounds

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        for key in self.key_iterator(d):
            if key in d:
                d[key] = self._apply_mixup(d[key])
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob

    def _apply_mixup(
        self, volume: Union[np.ndarray, torch.Tensor]
    ) -> Union[np.ndarray, torch.Tensor]:
        """Apply mixup to a batched volume (requires batch dim)."""
        if volume.ndim < 4:
            return volume
        is_numpy = isinstance(volume, np.ndarray)
        if is_numpy:
            volume = torch.from_numpy(volume)
        batch_size = volume.shape[0]
        if batch_size < 2:
            return volume.numpy() if is_numpy else volume
        alpha = self.R.uniform(*self.alpha_range)
        # Bug fix: draw the pairing permutation from this transform's seeded
        # RNG (self.R) rather than torch's global RNG (torch.randperm), so
        # results are reproducible under MONAI's set_random_state seeding.
        indices = torch.from_numpy(self.R.permutation(batch_size))
        mixed = alpha * volume + (1 - alpha) * volume[indices]
        return mixed.numpy() if is_numpy else mixed
class RandCopyPasted(RandomizableTransform, MapTransform):
    """Random Copy-Paste — copies transformed objects to non-overlapping regions.

    The label volume is flipped/rotated to find the placement with minimal
    overlap against the original objects; the corresponding masked intensity
    content is then pasted into the image keys.
    """

    def __init__(
        self,
        keys: KeysCollection,
        label_key: str = "label",
        prob: float = 0.5,
        max_obj_ratio: float = 0.7,
        rotation_angles: Optional[List[int]] = None,
        border: int = 3,
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.label_key = label_key
        self.max_obj_ratio = max_obj_ratio  # skip if objects already cover more
        # Bug fix: the previous signature used a mutable default argument
        # (`rotation_angles: List[int] = list(range(30, 360, 30))`), shared
        # across instances. Default to None and build a fresh list instead.
        self.rotation_angles = (
            list(range(30, 360, 30)) if rotation_angles is None else list(rotation_angles)
        )
        self.border = border  # dilation iterations around existing objects
        self.dil_struct = self._generate_binary_structure()

    @staticmethod
    def _generate_binary_structure():
        # Full 3x3x3 connectivity for the exclusion-zone dilation.
        return generate_binary_structure(3, 3)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        if self.label_key not in d:
            return d
        label = d[self.label_key]
        if isinstance(label, torch.Tensor):
            obj_ratio = label.float().mean().item()
        else:
            obj_ratio = float(label.astype(np.float32).mean())
        # Too crowded: no room to paste without heavy overlap.
        if obj_ratio > self.max_obj_ratio:
            return d
        for key in self.key_iterator(d):
            if key in d and key != self.label_key:
                d[key], d[self.label_key] = self._apply_copy_paste(d[key], label)
        return d

    def _apply_copy_paste(
        self,
        volume: Union[np.ndarray, torch.Tensor],
        label: Union[np.ndarray, torch.Tensor],
    ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
        """Apply copy-paste augmentation."""
        is_numpy = isinstance(volume, np.ndarray)
        if is_numpy:
            volume = torch.from_numpy(volume.copy())
            label = torch.from_numpy(label.copy())
        label = label.bool()
        # Only 3D labels with 3D or channel-first 4D volumes are supported.
        if label.ndim != 3 or volume.ndim not in [3, 4]:
            return volume.numpy() if is_numpy else volume, (label.numpy() if is_numpy else label)
        label_flipped = label.flip(0)
        if volume.ndim == 4:
            neuron_tensor = volume * label.unsqueeze(0)
        else:
            neuron_tensor = volume * label
        neuron_tensor, label_paste = self._find_best_paste(neuron_tensor, label, label_flipped)
        if volume.ndim == 4:
            label_paste = label_paste.unsqueeze(0)
            volume = volume * (~label_paste) + neuron_tensor * label_paste
        else:
            volume = volume * (~label_paste) + neuron_tensor * label_paste
        if is_numpy:
            return volume.numpy(), (
                label_paste.squeeze().numpy() if label_paste.ndim > 3 else label_paste.numpy()
            )
        return volume, label_paste.squeeze() if label_paste.ndim > 3 else label_paste

    def _find_best_paste(
        self,
        neuron_tensor: torch.Tensor,
        label_orig: torch.Tensor,
        label_flipped: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Find best rotation and position with minimal overlap."""
        labels = torch.stack([label_orig, label_flipped])
        best_overlap = torch.logical_and(label_flipped, label_orig).int().sum()
        best_angle = 0
        best_idx = 1
        # Scan every candidate angle for both orientations; keep the minimum.
        for angle in self.rotation_angles:
            rotated = self._rotate_3d(labels, angle)
            overlap0 = torch.logical_and(rotated[0], label_orig).int().sum()
            overlap1 = torch.logical_and(rotated[1], label_orig).int().sum()
            if overlap0 < best_overlap:
                best_overlap = overlap0
                best_angle = angle
                best_idx = 0
            if overlap1 < best_overlap:
                best_overlap = overlap1
                best_angle = angle
                best_idx = 1
        if best_idx == 1:
            # Flipped orientation won: flip intensities along the z axis too.
            neuron_tensor = (
                neuron_tensor.flip(0) if neuron_tensor.ndim == 3 else neuron_tensor.flip(1)
            )
        label_paste = labels[best_idx : best_idx + 1]
        if best_angle != 0:
            label_paste = self._rotate_3d(label_paste, best_angle)
            if neuron_tensor.ndim == 4:
                neuron_tensor = self._rotate_3d(neuron_tensor.unsqueeze(0), best_angle).squeeze(0)
            else:
                neuron_tensor = self._rotate_3d(neuron_tensor.unsqueeze(0), best_angle).squeeze(0)
        label_paste = label_paste.squeeze(0)
        # Exclude a dilated border around existing objects from the paste.
        gt_dilated = torch.tensor(
            binary_dilation(label_orig.numpy(), structure=self.dil_struct, iterations=self.border)
        )
        overlap_mask = torch.logical_and(label_paste, gt_dilated)
        label_paste[overlap_mask] = False
        if neuron_tensor.ndim == 4:
            neuron_tensor[:, overlap_mask] = 0
        else:
            neuron_tensor[overlap_mask] = 0
        return neuron_tensor, label_paste

    @staticmethod
    def _rotate_3d(tensor: torch.Tensor, angle: float) -> torch.Tensor:
        """Rotate 3D volume around z-axis."""
        if tensor.ndim == 4:  # (C, Z, Y, X)
            c, z, y, x = tensor.shape
            # Fold C*Z into the channel dim so tvf.rotate works per-slice.
            reshaped = tensor.reshape(1, c * z, y, x)
            rotated = tvf.rotate(reshaped, angle)
            return rotated.reshape(c, z, y, x)
        elif tensor.ndim == 5:  # (B, C, Z, Y, X)
            b, c, z, y, x = tensor.shape
            rotated_list = []
            for i in range(b):
                reshaped = tensor[i].reshape(1, c * z, y, x)
                rot = tvf.rotate(reshaped, angle)
                rotated_list.append(rot.reshape(c, z, y, x))
            return torch.stack(rotated_list)
        return tensor

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob
class NormalizeLabelsd(MapTransform):
    """Convert labels to binary {0, 1} integers."""

    def __init__(
        self,
        keys: KeysCollection,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        for key in self.key_iterator(d):
            if key in d:
                # Anything positive becomes 1; zero/negative becomes 0.
                if isinstance(d[key], np.ndarray):
                    d[key] = (d[key] > 0).astype(np.int32)
                elif isinstance(d[key], torch.Tensor):
                    d[key] = (d[key] > 0).int()
        return d


class SmartNormalizeIntensityd(MapTransform):
    """Smart intensity normalization with multiple modes and percentile clipping.

    Modes: "none", "normal" (z-score), "0-1" (min-max), "divide-K" (divide by K).
    """

    def __init__(
        self,
        keys: KeysCollection,
        mode: str = "0-1",
        clip_percentile_low: float = 0.0,
        clip_percentile_high: float = 1.0,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.divide_value = None
        if mode.startswith("divide-"):
            # "divide-K" encodes the divisor in the mode string.
            try:
                self.divide_value = float(mode.split("-", 1)[1])
                self.mode = "divide"
            except ValueError:
                raise ValueError(
                    f"Invalid divide mode '{mode}'. Format should be 'divide-K' "
                    f"where K is a number (e.g., 'divide-255')"
                )
        elif mode not in ["none", "normal", "0-1"]:
            raise ValueError(
                f"Invalid mode '{mode}'. Must be 'none', 'normal', '0-1', or 'divide-K'"
            )
        else:
            self.mode = mode
        self.clip_percentile_low = clip_percentile_low
        self.clip_percentile_high = clip_percentile_high

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        for key in self.key_iterator(d):
            if key in d:
                arr, was_tensor, device = _to_numpy(d[key])
                result = augment_ops.smart_normalize(
                    arr,
                    self.mode,
                    self.divide_value,
                    self.clip_percentile_low,
                    self.clip_percentile_high,
                )
                d[key] = _from_numpy(result, was_tensor, device)
        return d


class RandStriped(RandomizableTransform, MapTransform):
    """Random stripe augmentation simulating EM curtaining/scan line artifacts."""

    def __init__(
        self,
        keys: KeysCollection,
        prob: float = 0.3,
        num_stripes_range: Tuple[int, int] = (2, 10),
        thickness_range: Tuple[int, int] = (1, 5),
        intensity_range: Tuple[float, float] = (-0.2, 0.2),
        angle_range: Optional[Tuple[float, float]] = None,
        orientation: str = "random",
        mode: str = "add",
        allow_missing_keys: bool = False,
    ) -> None:
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.num_stripes_range = num_stripes_range
        self.thickness_range = thickness_range
        self.intensity_range = intensity_range
        self.angle_range = angle_range  # explicit angles override orientation
        if orientation not in ["horizontal", "vertical", "random"]:
            raise ValueError(
                f"Invalid orientation '{orientation}'. Must be 'horizontal', "
                f"'vertical', or 'random'"
            )
        self.orientation = orientation
        if mode not in ["add", "replace"]:
            raise ValueError(f"Invalid mode '{mode}'. Must be 'add' or 'replace'")
        self.mode = mode

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        for key in self.key_iterator(d):
            if key in d:
                arr, was_tensor, device = _to_numpy(d[key])
                if arr.ndim < 2:
                    continue
                # Determine angle
                if self.angle_range is not None:
                    angle = float(self.R.uniform(*self.angle_range))
                elif self.orientation == "random":
                    angle = 0.0 if self.R.rand() > 0.5 else 90.0
                elif self.orientation == "horizontal":
                    angle = 0.0
                else:
                    angle = 90.0
                # Number of stripes
                if self.num_stripes_range[0] == self.num_stripes_range[1]:
                    num_stripes = self.num_stripes_range[0]
                else:
                    num_stripes = int(
                        self.R.randint(self.num_stripes_range[0], self.num_stripes_range[1] + 1)
                    )
                # Generate stripe parameters.
                # We need coord range for positioning — compute from a representative slice
                h, w = arr.shape[-2:]
                angle_rad = np.deg2rad(angle)
                y_coords, x_coords = np.ogrid[:h, :w]
                rotated_coords = x_coords * np.sin(angle_rad) - y_coords * np.cos(angle_rad)
                coord_min = float(rotated_coords.min())
                coord_max = float(rotated_coords.max())
                stripe_params = []
                for _ in range(num_stripes):
                    center = float(self.R.uniform(coord_min, coord_max))
                    if self.thickness_range[0] == self.thickness_range[1]:
                        thickness = self.thickness_range[0]
                    else:
                        thickness = int(
                            self.R.randint(self.thickness_range[0], self.thickness_range[1] + 1)
                        )
                    intensity = float(self.R.uniform(*self.intensity_range))
                    stripe_params.append((center, thickness, intensity))
                result = augment_ops.apply_stripes(arr, stripe_params, angle, self.mode)
                d[key] = _from_numpy(result, was_tensor, device)
        return d

    def randomize(self, _: Any = None) -> None:
        self._do_transform = self.R.rand() < self.prob


class ResizeByFactord(MapTransform):
    """Resize images by scale factors using F.interpolate."""

    def __init__(
        self,
        keys: KeysCollection,
        scale_factors: List[float],
        mode: str = "bilinear",
        align_corners: Optional[bool] = None,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        self.scale_factors = scale_factors
        self.mode = mode
        self.align_corners = align_corners

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        d = dict(data)
        for key in self.key_iterator(d):
            if key not in d:
                continue
            arr = d[key]
            is_numpy = isinstance(arr, np.ndarray)
            if is_numpy:
                arr = torch.from_numpy(arr)
            # Add a batch dim; interpolate needs (N, C, spatial...).
            inp = arr.unsqueeze(0).float()
            spatial_dims = inp.ndim - 2
            # Promote "bilinear" to the mode matching the spatial rank.
            interp_mode = self.mode
            if interp_mode == "bilinear":
                if spatial_dims == 3:
                    interp_mode = "trilinear"
                elif spatial_dims == 1:
                    interp_mode = "linear"
            out = torch.nn.functional.interpolate(
                inp,
                scale_factor=[float(f) for f in self.scale_factors],
                mode=interp_mode,
                align_corners=None if interp_mode == "nearest" else self.align_corners,
            ).squeeze(0)
            # NOTE(review): numpy inputs come back as float arrays while
            # tensor inputs are cast back to their original dtype — confirm
            # the asymmetry is intended.
            d[key] = out.numpy() if is_numpy else out.to(arr.dtype)
        return d


class RandElasticd(MapTransform, RandomizableTransform):
    """Unified elastic deformation wrapping MONAI's Rand2DElasticd/Rand3DElasticd."""

    def __init__(
        self,
        keys,
        do_2d: bool = False,
        prob: float = 0.5,
        sigma_range: tuple = (5.0, 8.0),
        magnitude_range: tuple = (50.0, 150.0),
        allow_missing_keys: bool = False,
        mode: str = "bilinear",
        padding_mode: str = "reflection",
    ):
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        self.do_2d = do_2d
        self.sigma_range = sigma_range
        self.magnitude_range = magnitude_range
        self.mode = mode
        self.padding_mode = padding_mode

    def _build_inner_transform(self):
        # Imported lazily so the module loads without the elastic extras.
        from monai.transforms import Rand2DElasticd, Rand3DElasticd

        if self.do_2d:
            # NOTE(review): sigma_range is forwarded as Rand2DElasticd's
            # `spacing` parameter — confirm this mapping is intended.
            return Rand2DElasticd(
                keys=self.keys,
                prob=1.0,
                spacing=self.sigma_range,
                magnitude_range=self.magnitude_range,
                mode=self.mode,
                padding_mode=self.padding_mode,
                allow_missing_keys=self.allow_missing_keys,
            )
        return Rand3DElasticd(
            keys=self.keys,
            prob=1.0,
            sigma_range=self.sigma_range,
            magnitude_range=self.magnitude_range,
            mode=self.mode,
            padding_mode=self.padding_mode,
            allow_missing_keys=self.allow_missing_keys,
        )

    def __call__(self, data):
        d = dict(data)
        self.randomize(None)
        if not self._do_transform:
            return d
        # Probability is gated here; the cached inner transform runs at prob=1.
        if not hasattr(self, "_inner_transform"):
            self._inner_transform = self._build_inner_transform()
        return self._inner_transform(d)


# Consistency fix: __all__ previously omitted seven transforms defined and
# exported by this module (RandAxisPermuted, RandRotate90Alld, RandSliceDropd,
# RandSliceShiftd, RandMulAddIntensityd, RandSliceShiftZd, RandSliceDropZd).
__all__ = [
    "RandAxisPermuted",
    "RandRotate90Alld",
    "RandSliceDropd",
    "RandSliceShiftd",
    "RandMulAddIntensityd",
    "RandMisAlignmentd",
    "RandMissingSectiond",
    "RandSliceShiftZd",
    "RandSliceDropZd",
    "RandMissingPartsd",
    "RandMotionBlurd",
    "RandCutNoised",
    "RandCutBlurd",
    "RandMixupd",
    "RandCopyPasted",
    "RandStriped",
    "NormalizeLabelsd",
    "SmartNormalizeIntensityd",
    "ResizeByFactord",
    "RandElasticd",
]