# Source module: connectomics.data.datasets.dataset_filename
# (Header reconstructed from a Sphinx "Source code for ..." HTML export;
#  the original page navigation text was removed because it is not valid Python.)
"""
Filename-based dataset for PyTorch Connectomics.

Loads individual images from JSON file lists instead of cropping from large
volumes. Ideal for datasets with pre-tiled images like MitoLab, CEM500K, etc.
"""

from __future__ import annotations

import json
import logging
import random
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from monai.data import Dataset
from monai.transforms import Compose

logger = logging.getLogger(__name__)


class MonaiFilenameDataset(Dataset):
    """
    MONAI dataset for loading individual images from JSON file lists.

    JSON format::

        {
            "base_path": "/path/to/data",
            "images": ["relative/path/to/image1.png", ...],
            "masks": ["relative/path/to/mask1.png", ...]
        }

    Args:
        json_path: Path to JSON file containing file lists.
        transforms: MONAI transforms pipeline.
        mode: 'train', 'val', or 'test'.
        images_key: Key in JSON for image file list.
        labels_key: Key in JSON for label file list.
        base_path_key: Key in JSON for base path.
        train_val_split: Fraction for train split (0.0-1.0).
        random_seed: Random seed for train/val split.
        use_labels: Whether to load labels.

    Raises:
        ValueError: If the JSON lists no images, if image and label counts
            differ, or if ``train_val_split`` is outside (0, 1).
    """

    def __init__(
        self,
        json_path: str,
        transforms: Optional[Compose] = None,
        mode: str = "train",
        images_key: str = "images",
        labels_key: str = "masks",
        base_path_key: str = "base_path",
        train_val_split: Optional[float] = None,
        random_seed: int = 42,
        use_labels: bool = True,
    ):
        self.json_path = Path(json_path)
        self.mode = mode

        # JSON is defined as UTF-8; be explicit instead of relying on the
        # platform-default text encoding.
        with open(self.json_path, "r", encoding="utf-8") as f:
            json_data = json.load(f)

        base_path = Path(json_data.get(base_path_key, ""))
        image_files = json_data.get(images_key, [])
        label_files = json_data.get(labels_key, []) if use_labels else []

        if not image_files:
            raise ValueError(f"No images found in JSON under key '{images_key}'")

        # Create paired (image, label) entries; label is None when labels are
        # disabled or absent from the JSON.
        if use_labels and label_files:
            if len(image_files) != len(label_files):
                raise ValueError(
                    f"Image count ({len(image_files)}) != "
                    f"label count ({len(label_files)})"
                )
            pairs = list(zip(image_files, label_files))
        else:
            pairs = [(img, None) for img in image_files]

        # Apply train/val split if requested. A local Random instance with a
        # fixed seed guarantees the same shuffle for the 'train' and 'val'
        # dataset objects, so their subsets are disjoint and complementary.
        if train_val_split is not None:
            if not 0.0 < train_val_split < 1.0:
                raise ValueError(
                    f"train_val_split must be in (0, 1), "
                    f"got {train_val_split}"
                )
            rng = random.Random(random_seed)
            pairs_shuffled = pairs.copy()
            rng.shuffle(pairs_shuffled)
            n_train = int(len(pairs_shuffled) * train_val_split)
            if mode == "train":
                pairs = pairs_shuffled[:n_train]
            elif mode in ("val", "validation"):
                pairs = pairs_shuffled[n_train:]
            else:
                # Any other mode (e.g. 'test') sees the full, shuffled list.
                pairs = pairs_shuffled

        # Create MONAI data dictionaries with absolute paths; the 'label' key
        # is only present when a label file exists for the sample.
        data_dicts = []
        for img_file, label_file in pairs:
            d: Dict[str, Any] = {
                "image": str(base_path / img_file),
            }
            if label_file is not None:
                d["label"] = str(base_path / label_file)
            data_dicts.append(d)

        super().__init__(data=data_dicts, transform=transforms)

        logger.info(
            "MonaiFilenameDataset: mode=%s, samples=%d, base=%s",
            mode,
            len(data_dicts),
            base_path,
        )
def create_filename_datasets(
    json_path: str,
    train_transforms: Optional[Compose] = None,
    val_transforms: Optional[Compose] = None,
    train_val_split: float = 0.9,
    random_seed: int = 42,
    images_key: str = "images",
    labels_key: str = "masks",
    use_labels: bool = True,
) -> Tuple[MonaiFilenameDataset, MonaiFilenameDataset]:
    """Create train and val datasets from a single JSON.

    Both datasets share the same shuffle seed and split fraction, so the
    'train' and 'val' views partition the file list without overlap.
    """
    shared_kwargs: Dict[str, Any] = dict(
        json_path=json_path,
        images_key=images_key,
        labels_key=labels_key,
        train_val_split=train_val_split,
        random_seed=random_seed,
        use_labels=use_labels,
    )
    return (
        MonaiFilenameDataset(transforms=train_transforms, mode="train", **shared_kwargs),
        MonaiFilenameDataset(transforms=val_transforms, mode="val", **shared_kwargs),
    )
# Explicit public API of this module.
__all__ = ["MonaiFilenameDataset", "create_filename_datasets"]