# Source code for connectomics.data.datasets.split

"""
Volume splitting utilities for train/val splits.

Inspired by DeepEM's approach of using spatial splits along the Z-axis
for training and validation, with automatic padding to model input size.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import torch

logger = logging.getLogger(__name__)


def split_volume_train_val(
    volume_shape: Tuple[int, int, int],
    train_ratio: float = 0.8,
    axis: int = 0,
    min_val_size: Optional[int] = None,
) -> Tuple[Tuple[slice, ...], Tuple[slice, ...]]:
    """
    Split a volume into training and validation regions along a specified axis.

    This follows DeepEM's approach of spatial splitting where:
    - First 80% (or specified ratio) of volume is used for training
    - Last 20% is used for validation
    - Split is along Z-axis by default (axis=0 for [D,H,W] volumes)

    Args:
        volume_shape: Shape of the volume (D, H, W)
        train_ratio: Ratio of volume to use for training (default: 0.8)
        axis: Axis along which to split (0=D, 1=H, 2=W). Default: 0 (Z-axis)
        min_val_size: Minimum size for validation split (default: None)

    Returns:
        train_slices: Tuple of slices for training region
        val_slices: Tuple of slices for validation region

    Raises:
        ValueError: If ``train_ratio`` is not strictly between 0 and 1, if
            ``axis`` is out of range for ``volume_shape``, or if
            ``min_val_size`` would leave no room for training data.

    Example:
        >>> volume_shape = (100, 256, 256)  # [D, H, W]
        >>> train_slices, val_slices = split_volume_train_val(volume_shape, train_ratio=0.8)
        >>> # train_slices = (slice(0, 80), slice(None), slice(None))
        >>> # val_slices = (slice(80, 100), slice(None), slice(None))
    """
    # Validate inputs
    if not 0 < train_ratio < 1:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")
    if axis < 0 or axis >= len(volume_shape):
        raise ValueError(f"axis must be between 0 and {len(volume_shape) - 1}, got {axis}")

    # Calculate split point (floor, so validation gets the remainder)
    split_dim = volume_shape[axis]
    train_size = int(split_dim * train_ratio)
    val_size = split_dim - train_size

    # Ensure minimum validation size
    if min_val_size is not None and val_size < min_val_size:
        # Bug fix: previously a min_val_size >= split_dim silently produced an
        # empty (or negative-length) training slice; reject it explicitly.
        if min_val_size >= split_dim:
            raise ValueError(
                f"min_val_size={min_val_size} must be smaller than the split "
                f"dimension size {split_dim}"
            )
        train_size = split_dim - min_val_size
        val_size = min_val_size
        actual_ratio = train_size / split_dim
        logger.warning(
            "Adjusted train_ratio from %.2f to %.2f to satisfy min_val_size=%d",
            train_ratio,
            actual_ratio,
            min_val_size,
        )

    # Create full-volume slices, then restrict the split axis only.
    train_slices = [slice(None)] * len(volume_shape)
    val_slices = [slice(None)] * len(volume_shape)
    train_slices[axis] = slice(0, train_size)
    val_slices[axis] = slice(train_size, split_dim)
    return tuple(train_slices), tuple(val_slices)
def create_split_masks( volume_shape: Tuple[int, int, int], train_ratio: float = 0.8, axis: int = 0, min_val_size: Optional[int] = None, dtype: np.dtype = np.uint8, ) -> Tuple[np.ndarray, np.ndarray]: """ Create binary masks for train/val splits. Similar to DeepEM's msk_train.h5 and msk_val.h5 approach, but computed automatically based on train_ratio. Args: volume_shape: Shape of the volume (D, H, W) train_ratio: Ratio of volume to use for training (default: 0.8) axis: Axis along which to split (0=D, 1=H, 2=W). Default: 0 min_val_size: Minimum size for validation split (default: None) dtype: Data type for masks (default: np.uint8) Returns: train_mask: Binary mask for training region (1=valid, 0=invalid) val_mask: Binary mask for validation region (1=valid, 0=invalid) Example: >>> volume_shape = (100, 256, 256) >>> train_mask, val_mask = create_split_masks(volume_shape, train_ratio=0.8) >>> print(train_mask.shape, train_mask.sum()) # (100, 256, 256), first 80 slices = 1 >>> print(val_mask.shape, val_mask.sum()) # (100, 256, 256), last 20 slices = 1 """ train_slices, val_slices = split_volume_train_val(volume_shape, train_ratio, axis, min_val_size) # Create masks train_mask = np.zeros(volume_shape, dtype=dtype) val_mask = np.zeros(volume_shape, dtype=dtype) train_mask[train_slices] = 1 val_mask[val_slices] = 1 return train_mask, val_mask def pad_volume_to_size( volume: Union[np.ndarray, torch.Tensor], target_size: Tuple[int, int, int], mode: str = "reflect", constant_value: float = 0.0, ) -> Union[np.ndarray, torch.Tensor]: """ Pad a volume to match a target size. This is useful when validation data is smaller than the model's input size and needs to be padded for inference. Args: volume: Input volume of shape (D, H, W) or (C, D, H, W) target_size: Target size (D, H, W) mode: Padding mode. 
Options: - 'constant': Pad with constant value - 'reflect': Reflect values at boundaries - 'replicate': Replicate edge values - 'circular': Circular padding constant_value: Value to use for constant padding (default: 0.0) Returns: Padded volume of shape matching target_size (+ channel dim if input had it) Example: >>> volume = np.random.randn(18, 160, 160) # Small validation volume >>> target_size = (32, 256, 256) # Model input size >>> padded = pad_volume_to_size(volume, target_size, mode='reflect') >>> print(padded.shape) # (32, 256, 256) """ is_tensor = isinstance(volume, torch.Tensor) # Handle channel dimension has_channel = volume.ndim == 4 if has_channel: c, d, h, w = volume.shape spatial_shape = (d, h, w) else: d, h, w = volume.shape spatial_shape = (d, h, w) # Calculate padding needed pad_d = max(0, target_size[0] - spatial_shape[0]) pad_h = max(0, target_size[1] - spatial_shape[1]) pad_w = max(0, target_size[2] - spatial_shape[2]) # No padding needed if pad_d == 0 and pad_h == 0 and pad_w == 0: return volume # Calculate symmetric padding (before, after) for each dimension pad_d_before = pad_d // 2 pad_d_after = pad_d - pad_d_before pad_h_before = pad_h // 2 pad_h_after = pad_h - pad_h_before pad_w_before = pad_w // 2 pad_w_after = pad_w - pad_w_before if is_tensor: # PyTorch padding: (left, right, top, bottom, front, back) # Order is reversed: last dim first padding = ( pad_w_before, pad_w_after, # W pad_h_before, pad_h_after, # H pad_d_before, pad_d_after, # D ) # Map mode names if mode == "constant": padded = torch.nn.functional.pad(volume, padding, mode="constant", value=constant_value) elif mode == "reflect": padded = torch.nn.functional.pad(volume, padding, mode="reflect") elif mode == "replicate": padded = torch.nn.functional.pad(volume, padding, mode="replicate") elif mode == "circular": padded = torch.nn.functional.pad(volume, padding, mode="circular") else: raise ValueError(f"Unknown padding mode: {mode}") else: # NumPy padding if has_channel: 
pad_width = ( (0, 0), # C - no padding (pad_d_before, pad_d_after), # D (pad_h_before, pad_h_after), # H (pad_w_before, pad_w_after), # W ) else: pad_width = ( (pad_d_before, pad_d_after), # D (pad_h_before, pad_h_after), # H (pad_w_before, pad_w_after), # W ) # Map mode names np_mode = mode if mode == "reflect": np_mode = "reflect" elif mode == "replicate": np_mode = "edge" elif mode == "circular": np_mode = "wrap" elif mode == "constant": padded = np.pad(volume, pad_width, mode="constant", constant_values=constant_value) return padded else: raise ValueError(f"Unknown padding mode: {mode}") padded = np.pad(volume, pad_width, mode=np_mode) return padded def split_and_pad_volume( volume: Union[np.ndarray, torch.Tensor], train_ratio: float = 0.8, target_size: Optional[Tuple[int, int, int]] = None, axis: int = 0, pad_mode: str = "reflect", min_val_size: Optional[int] = None, ) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]: """ Split volume into train/val and pad validation if needed. This combines splitting and padding in one operation, following DeepEM's approach of using 80% for training and padding the smaller validation volume to match model input requirements. Args: volume: Input volume (D, H, W) or (C, D, H, W) train_ratio: Ratio for training split (default: 0.8) target_size: Target size for padding validation (default: None, no padding) axis: Split axis (default: 0 for Z-axis) pad_mode: Padding mode for validation volume (default: 'reflect') min_val_size: Minimum validation size (default: None) Returns: train_volume: Training portion (unchanged) val_volume: Validation portion (padded to target_size if specified) Example: >>> # Volume with 100 slices >>> volume = np.random.randn(100, 256, 256) >>> >>> # Split 80/20 and pad validation to model size >>> train_vol, val_vol = split_and_pad_volume( ... volume, ... train_ratio=0.8, ... target_size=(32, 256, 256), # Model input size ... pad_mode='reflect' ... 
) >>> >>> print(train_vol.shape) # (80, 256, 256) - first 80 slices >>> print(val_vol.shape) # (32, 256, 256) - last 20 slices, padded to 32 """ # Determine spatial shape has_channel = volume.ndim == 4 if has_channel: spatial_shape = volume.shape[1:] else: spatial_shape = volume.shape # Get split slices train_slices, val_slices = split_volume_train_val( spatial_shape, train_ratio, axis, min_val_size ) # Split volume if has_channel: train_volume = volume[(slice(None),) + train_slices] val_volume = volume[(slice(None),) + val_slices] else: train_volume = volume[train_slices] val_volume = volume[val_slices] # Pad validation if target size specified if target_size is not None: val_volume = pad_volume_to_size(val_volume, target_size, mode=pad_mode) return train_volume, val_volume # Convenience function for saving masks to HDF5 (DeepEM style) def save_split_masks_h5( output_dir: Union[str, Path], volume_shape: Tuple[int, int, int], train_ratio: float = 0.8, axis: int = 0, train_filename: str = "msk_train.h5", val_filename: str = "msk_val.h5", ): """ Save train/val split masks to HDF5 files (DeepEM compatible). 
Args: output_dir: Directory to save mask files volume_shape: Shape of the volume (D, H, W) train_ratio: Training split ratio (default: 0.8) axis: Split axis (default: 0) train_filename: Filename for training mask (default: 'msk_train.h5') val_filename: Filename for validation mask (default: 'msk_val.h5') """ import h5py output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Create masks train_mask, val_mask = create_split_masks(volume_shape, train_ratio, axis) # Save to HDF5 train_path = output_dir / train_filename val_path = output_dir / val_filename with h5py.File(train_path, "w") as f: f.create_dataset("main", data=train_mask, compression="gzip") with h5py.File(val_path, "w") as f: f.create_dataset("main", data=val_mask, compression="gzip") logger.info("Saved training mask to: %s", train_path) logger.info("Saved validation mask to: %s", val_path) logger.info("Training region: %d voxels", train_mask.sum()) logger.info("Validation region: %d voxels", val_mask.sum()) # ============================================================================ # MONAI Transform for applying splits during data loading # ============================================================================ try: from monai.config import KeysCollection from monai.transforms import MapTransform HAS_MONAI = True except ImportError: HAS_MONAI = False if HAS_MONAI: class ApplyVolumetricSplitd(MapTransform): """ MONAI transform to apply volumetric split during data loading. This transform extracts a spatial region from loaded volumes based on split metadata stored in the data dictionary. It supports automatic padding for validation regions. Expected metadata in data dict: - 'split_slices': Tuple of slice objects defining the region - 'split_mode': 'train' or 'val' - 'split_pad' (optional): Whether to pad validation - 'split_pad_size' (optional): Target size for padding - 'split_pad_mode' (optional): Padding mode ('reflect', etc.) 
Args: keys: Keys of data dict to apply split to (e.g., ['image', 'label']) allow_missing_keys: Don't raise error if key is missing Example: >>> from monai.transforms import Compose >>> transforms = Compose([ ... LoadImaged(keys=['image', 'label']), ... ApplyVolumetricSplitd(keys=['image', 'label']), ... # ... other transforms ... ]) """ def __init__( self, keys: KeysCollection, allow_missing_keys: bool = False, ): super().__init__(keys, allow_missing_keys) def __call__(self, data: dict) -> dict: """Apply volumetric split to data dictionary.""" d = dict(data) # Check if split metadata exists if "split_slices" not in d: return d # No split to apply split_slices = d["split_slices"] split_mode = d.get("split_mode", "train") # Apply split to each key for key in self.key_iterator(d): if key in d: volume = d[key] # Apply slicing # Handle both (C, D, H, W) and (D, H, W) shapes if volume.ndim == 4: # Has channel dimension - apply split to spatial dims d[key] = volume[:, split_slices[0], split_slices[1], split_slices[2]] elif volume.ndim == 3: # No channel dimension - apply directly d[key] = volume[split_slices] else: raise ValueError(f"Expected 3D or 4D volume, got shape {volume.shape}") # Apply padding if validation mode and padding requested if split_mode == "val" and d.get("split_pad", False): target_size = d.get("split_pad_size") pad_mode = d.get("split_pad_mode", "reflect") if target_size is not None: for key in self.key_iterator(d): if key in d: d[key] = pad_volume_to_size( d[key], target_size=target_size, mode=pad_mode, ) return d def apply_volumetric_split( volume: Union[np.ndarray, torch.Tensor], split_slices: Tuple[slice, ...], pad: bool = False, pad_size: Optional[Tuple[int, ...]] = None, pad_mode: str = "reflect", ) -> Union[np.ndarray, torch.Tensor]: """ Apply volumetric split to a single volume (convenience function). 
Args: volume: Input volume (D, H, W) or (C, D, H, W) split_slices: Tuple of slice objects pad: Whether to pad after splitting pad_size: Target size for padding (only if pad=True) pad_mode: Padding mode ('reflect', 'replicate', 'constant') Returns: Split (and optionally padded) volume """ # Apply split if volume.ndim == 4: result = volume[:, split_slices[0], split_slices[1], split_slices[2]] elif volume.ndim == 3: result = volume[split_slices] else: raise ValueError(f"Expected 3D or 4D volume, got shape {volume.shape}") # Apply padding if requested if pad and pad_size is not None: result = pad_volume_to_size(result, target_size=pad_size, mode=pad_mode) return result __all__ = [ "split_volume_train_val", "create_split_masks", "pad_volume_to_size", "split_and_pad_volume", "save_split_masks_h5", "apply_volumetric_split", ] if HAS_MONAI: __all__.append("ApplyVolumetricSplitd")