"""
PyTorch Lightning DataModules for connectomics datasets.
Provides ConnectomicsDataModule (MONAI transform-based) and
SimpleDataModule (wraps pre-built dataloaders).
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
import numpy as np
import pytorch_lightning as pl
import torch
from monai.data import CacheDataset, Dataset
from monai.transforms import Compose
from torch.utils.data import DataLoader, Sampler
[docs]class ConnectomicsDataModule(pl.LightningDataModule):
"""
Lightning DataModule using MONAI Dataset/CacheDataset.
Used as a fallback when pre-loaded cache is not enabled.
Transforms (including loading and cropping) are applied on-the-fly.
Args:
train_data_dicts: Training data dictionaries.
val_data_dicts: Validation data dictionaries.
test_data_dicts: Test data dictionaries.
transforms: Dict of Compose for 'train'/'val'/'test'.
dataset_type: 'standard' or 'cached'.
batch_size: Batch size for dataloaders.
num_workers: Number of dataloader workers.
pin_memory: Pin memory for GPU transfer.
persistent_workers: Keep workers alive between epochs.
cache_rate: Cache rate for CacheDataset.
val_steps_per_epoch: Override validation dataset length.
seed: Random seed for validation reseeding.
distributed_tta_sharding: Keep test samples replicated on all ranks so
TTA passes can be partitioned inside inference rather than by sampler.
distributed_window_sharding: Keep the single test sample replicated on all
ranks so lazy sliding-window patches can be partitioned inside inference.
distributed_chunked_raw_sharding: Keep the single test sample replicated on all
ranks so raw prediction chunks can be partitioned inside inference.
**dataset_kwargs: Extra args (iter_num, sample_size, etc.).
"""
def __init__(
self,
train_data_dicts: List[Dict[str, Any]],
val_data_dicts: Optional[List[Dict[str, Any]]] = None,
test_data_dicts: Optional[List[Dict[str, Any]]] = None,
transforms: Optional[Dict[str, Compose]] = None,
dataset_type: str = "standard",
batch_size: int = 1,
num_workers: int = 0,
pin_memory: bool = True,
persistent_workers: bool = False,
cache_rate: float = 1.0,
val_steps_per_epoch: Optional[int] = None,
seed: int = 0,
distributed_tta_sharding: bool = False,
distributed_window_sharding: bool = False,
distributed_chunked_raw_sharding: bool = False,
**dataset_kwargs,
):
super().__init__()
self.train_data_dicts = train_data_dicts
self.val_data_dicts = val_data_dicts
self.test_data_dicts = test_data_dicts
self.skip_validation = not val_data_dicts or len(val_data_dicts) == 0
self.transforms = transforms or {}
self.dataset_type = dataset_type
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers
self.cache_rate = cache_rate
self.val_steps_per_epoch = val_steps_per_epoch
self.seed = seed
self.distributed_tta_sharding = distributed_tta_sharding
self.distributed_window_sharding = distributed_window_sharding
self.distributed_chunked_raw_sharding = distributed_chunked_raw_sharding
self.dataset_kwargs = dataset_kwargs
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
[docs] def setup(self, stage: Optional[str] = None):
if stage in ("fit", None):
if self.train_data_dicts:
self.train_dataset = self._create_dataset(
self.train_data_dicts,
self.transforms.get("train"),
"train",
)
if self.val_data_dicts:
self.val_dataset = self._create_dataset(
self.val_data_dicts,
self.transforms.get("val"),
"val",
)
if stage in ("test", None):
if self.test_data_dicts:
self.test_dataset = self._create_dataset(
self.test_data_dicts,
self.transforms.get("test"),
"test",
)
def _create_dataset(self, data_dicts, transforms, mode):
iter_num = self.dataset_kwargs.get("iter_num", -1)
if mode == "val" and self.val_steps_per_epoch is not None:
iter_num = self.val_steps_per_epoch
if self.dataset_type == "cached":
ds = CacheDataset(
data=data_dicts,
transform=transforms,
cache_rate=self.cache_rate,
)
else:
ds = Dataset(data=data_dicts, transform=transforms)
if iter_num and iter_num > 0:
ds = _IterNumDataset(ds, iter_num)
return ds
[docs] def train_dataloader(self):
return self._create_dataloader(self.train_dataset, shuffle=True)
[docs] def val_dataloader(self):
if self.skip_validation:
return []
return self._create_dataloader(self.val_dataset, shuffle=False)
[docs] def test_dataloader(self):
sampler = None
if self.test_dataset is not None and _distributed_world_size() > 1:
replicated_single_volume = (
self.distributed_tta_sharding
or self.distributed_window_sharding
or self.distributed_chunked_raw_sharding
)
if replicated_single_volume:
if len(self.test_dataset) != 1:
raise RuntimeError(
"Distributed single-volume inference sharding requires exactly one "
"test sample replicated on every rank. Disable "
"inference.test_time_augmentation.distributed_sharding and "
"inference.sliding_window.distributed_sharding, or use single-GPU "
"chunked raw inference, for multi-volume test datasets."
)
else:
sampler = DistributedEvaluationSampler(self.test_dataset)
return self._create_dataloader(
self.test_dataset,
shuffle=False,
collate_fn=collate_dict_list,
sampler=sampler,
)
def _create_dataloader(self, dataset, shuffle, collate_fn=None, sampler=None):
if dataset is None:
return None
if collate_fn is None:
collate_fn = collate_dict
return DataLoader(
dataset=dataset,
batch_size=self.batch_size,
shuffle=shuffle if sampler is None else False,
sampler=sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
persistent_workers=(self.persistent_workers and self.num_workers > 0),
collate_fn=collate_fn,
)
class SimpleDataModule(pl.LightningDataModule):
"""Wraps pre-built train/val/test dataloaders."""
def __init__(
self,
train_loader=None,
val_loader=None,
test_loader=None,
):
super().__init__()
self._train = train_loader
self._val = val_loader
self._test = test_loader
def train_dataloader(self):
return self._train
def val_dataloader(self):
return self._val if self._val is not None else []
def test_dataloader(self):
return self._test if self._test is not None else []
class _IterNumDataset(torch.utils.data.Dataset):
"""Wraps a dataset to override __len__ with iter_num.
Uses modulo indexing: when iter_num > len(dataset), indices wrap around
to the beginning of the underlying dataset.
"""
def __init__(self, dataset, iter_num: int):
self.dataset = dataset
self._len = iter_num
def __len__(self):
return self._len
def __getitem__(self, index):
return self.dataset[index % len(self.dataset)]
def _is_distributed_evaluation_active() -> bool:
return torch.distributed.is_available() and torch.distributed.is_initialized()
def _distributed_world_size() -> int:
if not _is_distributed_evaluation_active():
return 1
return int(torch.distributed.get_world_size())
class DistributedEvaluationSampler(Sampler[int]):
"""Shard evaluation samples across DDP ranks without padding or duplication."""
def __init__(
self,
dataset,
*,
rank: Optional[int] = None,
world_size: Optional[int] = None,
):
if rank is None or world_size is None:
if not _is_distributed_evaluation_active():
raise RuntimeError(
"DistributedEvaluationSampler requires an initialized distributed process "
"group or explicit rank/world_size."
)
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
if world_size <= 0:
raise ValueError(f"world_size must be positive, got {world_size}.")
if rank < 0 or rank >= world_size:
raise ValueError(f"rank must satisfy 0 <= rank < world_size, got {rank}/{world_size}.")
self.rank = int(rank)
self.world_size = int(world_size)
self.indices = list(range(len(dataset)))[self.rank :: self.world_size]
def __iter__(self):
return iter(self.indices)
def __len__(self):
return len(self.indices)
def collate_dict(
batch: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Stack dict-of-arrays batch into dict-of-tensors."""
if not batch:
return {}
result = {}
for key in batch[0]:
values = [sample[key] for sample in batch]
if isinstance(values[0], torch.Tensor):
result[key] = torch.stack(values)
elif isinstance(values[0], np.ndarray):
result[key] = torch.stack([torch.from_numpy(v) for v in values])
else:
result[key] = values
return result
def collate_dict_list(
batch: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Preserve per-sample values as lists for variable-shape test volumes."""
if not batch:
return {}
result = {}
for key in batch[0]:
values = [sample[key] for sample in batch]
if isinstance(values[0], np.ndarray):
result[key] = [torch.from_numpy(v) for v in values]
else:
result[key] = values
return result
__all__ = [
"ConnectomicsDataModule",
"DistributedEvaluationSampler",
"SimpleDataModule",
"collate_dict",
"collate_dict_list",
]