Shortcuts

Source code for connectomics.data.io.transforms

"""
MONAI transforms for connectomics I/O operations.

This module provides MONAI-compatible transforms for:
- Volume loading (HDF5, TIFF, PNG)
- Volume saving
- Tile-based loading for large datasets
"""

from __future__ import annotations

import os
from typing import Any, Dict, Sequence, Tuple

import numpy as np
from monai.config import KeysCollection
from monai.transforms import MapTransform

from .io import read_volume, save_volume
from .tiles import reconstruct_volume_from_tiles


class LoadVolumed(MapTransform):
    """Dictionary transform that reads connectomics volumes from disk.

    Every selected key whose value is a string is treated as a file path and
    passed to ``read_volume`` (supported formats are whatever that reader
    handles). The entry is replaced by the loaded array, which gets a leading
    channel axis when it is 2-D or 3-D, and a ``"<key>_meta_dict"`` entry
    describing the load is written next to it.

    Args:
        keys: Keys to load from the data dictionary.
        transpose_axes: Axis permutation for the spatial dims
            (e.g., [2,1,0] for xyz->zyx). Applied BEFORE the channel
            dimension is added.
        allow_missing_keys: Allow missing keys.
    """

    def __init__(
        self,
        keys: KeysCollection,
        transpose_axes: Sequence[int] | None = None,
        allow_missing_keys: bool = False,
    ):
        super().__init__(keys, allow_missing_keys)
        # An empty list means "no permutation requested".
        if transpose_axes:
            self.transpose_axes = list(transpose_axes)
        else:
            self.transpose_axes = []

    def _permute_spatial(self, volume):
        """Apply ``self.transpose_axes`` to the spatial axes of ``volume``.

        3-D input is permuted directly; 4-D input keeps axis 0 in place and
        permutes the remaining three. Any other rank is returned untouched.
        """
        axes = self.transpose_axes
        if not axes:
            return volume
        if volume.ndim == 3:
            return np.transpose(volume, axes)
        if volume.ndim == 4:
            # Shift the permutation by one so the leading axis stays first.
            return np.transpose(volume, [0, *(axis + 1 for axis in axes)])
        return volume

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Load every string-valued key and return an updated shallow copy.

        Args:
            data: Input dictionary; it is copied, not modified in place.

        Returns:
            A new dictionary with loaded arrays and their metadata entries.
        """
        result = dict(data)
        for key in self.key_iterator(result):
            source_path = result.get(key)
            # Only string entries are treated as paths; anything else
            # (including an absent key) is left as it is.
            if not isinstance(source_path, str):
                continue

            volume = self._permute_spatial(read_volume(source_path))

            # Ensure channel dim: (H,W)->(1,H,W), (D,H,W)->(1,D,H,W)
            if 2 <= volume.ndim <= 3:
                volume = np.expand_dims(volume, axis=0)
            result[key] = volume

            # Merge into any existing metadata rather than discarding it.
            meta_key = f"{key}_meta_dict"
            final_shape = tuple(volume.shape)
            meta = dict(result.get(meta_key, {}))
            meta.update(
                filename_or_obj=source_path,
                original_shape=final_shape,
                spatial_shape=final_shape[1:],
                channels_first=True,
                transpose_axes=self.transpose_axes or None,
            )
            result[meta_key] = meta
        return result
class SaveVolumed(MapTransform):
    """Dictionary transform that writes array entries to disk.

    Each selected key holding a ``numpy.ndarray`` is written with
    ``save_volume`` to ``<output_dir>/<key>.<output_format>``. The data
    dictionary itself is passed through unchanged (as a shallow copy).

    Args:
        keys: Keys to save from the data dictionary.
        output_dir: Output directory.
        output_format: File format ('h5' or 'png'); also used as the
            file extension.
        allow_missing_keys: Allow missing keys.
    """

    def __init__(
        self,
        keys: KeysCollection,
        output_dir: str,
        output_format: str = "h5",
        allow_missing_keys: bool = False,
    ):
        super().__init__(keys, allow_missing_keys)
        self.output_format = output_format
        self.output_dir = output_dir

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Save every ndarray-valued key and return a shallow copy of ``data``."""
        # The directory is created up front, even if nothing ends up saved.
        os.makedirs(self.output_dir, exist_ok=True)

        result = dict(data)
        for key in self.key_iterator(result):
            array = result.get(key)
            # Entries that are not NumPy arrays are skipped.
            if not isinstance(array, np.ndarray):
                continue
            target = os.path.join(self.output_dir, f"{key}.{self.output_format}")
            save_volume(target, array, file_format=self.output_format)
        return result


class TileLoaderd(MapTransform):
    """Dictionary transform that assembles a volume chunk from image tiles.

    A selected key is processed when its value is a dict containing both
    ``"metadata"`` and ``"chunk_coords"``; the value is then replaced by the
    array returned from ``reconstruct_volume_from_tiles``.

    Args:
        keys: Keys to process.
        allow_missing_keys: Allow missing keys.
    """

    def __init__(
        self,
        keys: Sequence[str],
        allow_missing_keys: bool = False,
    ):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Replace each tile description with its reconstructed volume.

        Args:
            data: Input dictionary; it is copied, not modified in place.

        Returns:
            A new dictionary in which qualifying entries hold loaded arrays.
        """
        result = dict(data)
        for key in self.key_iterator(result):
            tile_info = result.get(key)
            if not isinstance(tile_info, dict):
                continue
            # Both pieces are required; incomplete descriptions are left alone.
            if "metadata" not in tile_info or "chunk_coords" not in tile_info:
                continue
            result[key] = self._load_tiles_for_chunk(
                tile_info["metadata"], tile_info["chunk_coords"]
            )
        return result

    def _load_tiles_for_chunk(
        self,
        metadata: Dict[str, Any],
        coords: Tuple[int, int, int, int, int, int],
    ) -> np.ndarray:
        """Reconstruct the chunk described by ``coords`` from tile files.

        Args:
            metadata: Tile dataset description. Reads the keys ``"image"``,
                ``"depth"``, ``"height"``, ``"width"``, ``"tile_size"`` and
                ``"dtype"``, plus the optional ``"tile_st"`` and
                ``"tile_ratio"`` (default ``1.0``).
            coords: Six values unpacked as ``(z0, z1, y0, y1, x0, x1)``.

        Returns:
            The result of ``reconstruct_volume_from_tiles`` for the chunk.
        """
        z0, z1, y0, y1, x0, x1 = coords

        # Only the tile entries for the requested z-range are passed on.
        section_paths = metadata["image"][z0:z1]
        chunk_bounds = [z0, z1, y0, y1, x0, x1]

        # Full dataset extent as [0, depth, 0, height, 0, width].
        dataset_bounds = []
        for extent_key in ("depth", "height", "width"):
            dataset_bounds += [0, metadata[extent_key]]

        return reconstruct_volume_from_tiles(
            tile_paths=section_paths,
            volume_coords=chunk_bounds,
            tile_coords=dataset_bounds,
            tile_size=metadata["tile_size"],
            data_type=np.dtype(metadata["dtype"]),
            tile_start=metadata.get("tile_st"),
            tile_ratio=metadata.get("tile_ratio", 1.0),
        )