Shortcuts

Source code for connectomics.data.datasets.base

"""Shared base class for patch-sampling datasets."""

from __future__ import annotations

import logging
import random
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple

import torch.utils.data
from monai.transforms import Compose
from monai.utils import ensure_tuple_rep

from .crop_sampling import center_crop_position, random_crop_position

logger = logging.getLogger(__name__)


[docs]class PatchDataset(torch.utils.data.Dataset): """ Abstract base for datasets that sample random patches from volumes. Subclasses must implement: _crop_volumes(vol_idx, pos) -> dict with "image" and optional "label"/"mask" _has_labels(vol_idx) -> bool Subclasses must populate ``self.volume_sizes`` during __init__. Provides: - __getitem__ with foreground-aware retry loop - set_epoch / get_sampling_fingerprint for validation reseeding - Shared crop position sampling via crop_sampling.py """ def __init__( self, patch_size: Tuple[int, ...], iter_num: int = 500, transforms: Optional[Compose] = None, mode: str = "train", max_attempts: int = 10, foreground_threshold: float = 0.0, ): super().__init__() ndim = len(patch_size) if ndim not in (2, 3): raise ValueError(f"patch_size must be 2D or 3D, got {ndim}D") self.patch_size = ensure_tuple_rep(patch_size, ndim) self.iter_num = iter_num self.transforms = transforms self.mode = mode self.max_attempts = max_attempts self.foreground_threshold = foreground_threshold # Validation reseeding support self.base_seed = 0 self.current_epoch = 0 # Subclass must populate this during __init__ self.volume_sizes: List[Tuple[int, ...]] = [] @property def num_volumes(self) -> int: return len(self.volume_sizes) @abstractmethod def _crop_volumes(self, vol_idx: int, pos: Tuple[int, ...]) -> Dict[str, Any]: """ Crop image/label/mask from volume at given position. Returns dict with "image" (required), "label" and "mask" (optional, None if absent). Values are numpy arrays with channel dim: (C, ...). """ ... @abstractmethod def _has_labels(self, vol_idx: int) -> bool: """Whether the volume at vol_idx has associated labels.""" ... def __len__(self) -> int: return self.iter_num def __getitem__(self, index: int) -> Dict[str, Any]: vol_idx = random.randint(0, self.num_volumes - 1) use_fg = ( self.mode == "train" and self._has_labels(vol_idx) and self.foreground_threshold > 0 ) data = None if use_fg: for _ in range(self.max_attempts): pos = self._get_random_crop_position(vol_idx) data = self._crop_volumes(vol_idx, pos) label = data.get("label") if label is None: break # no label => no foreground filtering needed fg_frac = float((label > 0).sum()) / float(label.size) # Reject if mask is present but entirely zero in this crop mask = data.get("mask") if mask is not None and not (mask > 0).any(): continue if fg_frac >= self.foreground_threshold: break # If loop exhausted without break, data holds the last attempt if data is None: # Either use_fg was False, or max_attempts==0 pos = ( self._get_random_crop_position(vol_idx) if self.mode == "train" else self._get_center_crop_position(vol_idx) ) data = self._crop_volumes(vol_idx, pos) # Remove None values so downstream code doesn't see phantom entries data = {k: v for k, v in data.items() if v is not None} if self.transforms: data = self.transforms(data) return data # -- Crop position helpers (overridable by subclasses) -- def _get_random_crop_position(self, vol_idx: int) -> Tuple[int, ...]: return random_crop_position(self.volume_sizes[vol_idx], self.patch_size, rng=random) def _get_center_crop_position(self, vol_idx: int) -> Tuple[int, ...]: return center_crop_position(self.volume_sizes[vol_idx], self.patch_size) # -- Validation reseeding --
[docs] def set_epoch(self, epoch: int, base_seed: int = 0): """Set epoch for deterministic validation reseeding.""" if self.mode == "val": self.base_seed = base_seed self.current_epoch = epoch effective_seed = self.base_seed + epoch random.seed(effective_seed) logger.debug( "[Validation] epoch=%s, effective_seed=%s, dataset=%s@%s", epoch, effective_seed, type(self).__name__, id(self), )
[docs] def get_sampling_fingerprint(self, num_samples: int = 5) -> str: """Generate fingerprint of validation sampling for verification.""" if self.mode != "val": return "N/A (training mode)" state = random.getstate() try: samples = [] for _ in range(num_samples): vol_idx = random.randint(0, self.num_volumes - 1) pos = self._get_random_crop_position(vol_idx) samples.append((vol_idx, pos)) return ", ".join(f"v{v}@{p}" for v, p in samples) finally: random.setstate(state)
__all__ = ["PatchDataset"]