# Source code for connectomics.metrics.nerl

"""Reusable NERL scoring helpers.

This module owns PyTC's low-level adapter around ``em_erl``: graph loading,
NetworkX skeleton conversion, segmentation normalization, and score extraction.
Evaluation-stage logging and report wiring live in ``connectomics.evaluation``.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np

# Module-level logger; used to warn when an ERLGraph cache cannot be read or written.
logger = logging.getLogger(__name__)


[docs]@dataclass(frozen=True) class NerlGraphOptions: """Options for converting NetworkX skeleton pickles into ERL graphs.""" skeleton_id_attribute: str = "id" skeleton_position_attribute: str = "index_position" skeleton_edge_length_attribute: str = "edge_length" skeleton_position_order: str = "xyz" prediction_position_order: str | None = None
@dataclass(frozen=True)
class NerlScoreResult:
    """Immutable bundle of NERL outputs handed to stage-specific adapters."""

    # Normalized ERL: pred_erl / gt_erl (NaN when gt_erl is not positive).
    nerl: float
    # Aggregate ERL of the predicted segmentation.
    pred_erl: float
    # Aggregate ERL of the ground-truth skeletons.
    gt_erl: float
    # Number of ground-truth skeletons scored.
    num_skeletons: int
    # Per-skeleton ERL values; extract_nerl_score_outputs validates shape [N, 2].
    per_gt_erl: np.ndarray
    # The graph object the score was computed against.
    graph: Any
def _materialize_for_parallel(segment, mask, num_workers):
    """Write ndarray inputs to temp HDF5 files so em_erl's multi-process path can use them.

    em_erl's parallel ``compute_segment_lut`` requires path-based inputs because
    workers each open their own VolumeSource. When the caller hands us an
    in-memory ndarray (as the test pipeline does), it is written to a named
    temporary HDF5 file (dataset ``"main"``) so the parallel path is actually
    exercised. Inputs that are not ndarrays are passed through unchanged.

    Args:
        segment: Segmentation volume (ndarray) or a source em_erl accepts directly.
        mask: Optional mask volume (ndarray), other source, or ``None``.
        num_workers: Worker count; with ``<= 1`` nothing is written.

    Returns:
        ``(seg_arg, mask_arg, tempfiles)``. The caller must delete every path in
        ``tempfiles`` when em_erl is done. If writing fails, the files created so
        far are removed here before the exception propagates.
    """
    if num_workers <= 1:
        return segment, mask, []

    import os
    import tempfile

    import h5py

    # Cluster $TMPDIR/$HOME often lives on a network FS where HDF5 POSIX
    # file locking returns ENOSPC even when there's plenty of space. Prefer
    # /dev/shm (tmpfs in RAM) when available; otherwise fall back to the
    # default temp dir.
    shm_dir = "/dev/shm"
    temp_dir = shm_dir if os.path.isdir(shm_dir) and os.access(shm_dir, os.W_OK) else None
    tempfiles: list[str] = []

    def _maybe_write(arr, prefix):
        if not isinstance(arr, np.ndarray):
            return arr
        fh = tempfile.NamedTemporaryFile(suffix=".h5", prefix=prefix, dir=temp_dir, delete=False)
        fh.close()
        # Register before writing so a failed write is still cleaned up.
        tempfiles.append(fh.name)
        with h5py.File(fh.name, "w") as fid:
            fid.create_dataset("main", data=np.ascontiguousarray(arr))
        return fh.name

    try:
        seg_arg = _maybe_write(segment, "nerl_seg_")
        mask_arg = _maybe_write(mask, "nerl_mask_") if mask is not None else None
    except BaseException:
        # The caller never receives ``tempfiles`` on failure, so clean up here
        # (these may sit in RAM-backed /dev/shm).
        for path in tempfiles:
            try:
                os.unlink(path)
            except OSError:
                pass
        raise
    return seg_arg, mask_arg, tempfiles


def import_em_erl():
    """Return ``(ERLGraph, compute_erl_score, compute_segment_lut)`` from em_erl.

    When em_erl is not importable, retries after prepending the vendored checkout
    ``Path(__file__).resolve().parents[2] / "lib" / "em_erl"`` to ``sys.path``
    (if that directory exists). If the retry also fails, its
    ``ModuleNotFoundError`` propagates.
    """
    try:
        from em_erl import ERLGraph, compute_erl_score, compute_segment_lut
    except ModuleNotFoundError:
        import sys

        repo_root = Path(__file__).resolve().parents[2]
        em_erl_root = repo_root / "lib" / "em_erl"
        if em_erl_root.exists() and str(em_erl_root) not in sys.path:
            sys.path.insert(0, str(em_erl_root))
        from em_erl import ERLGraph, compute_erl_score, compute_segment_lut
    return ERLGraph, compute_erl_score, compute_segment_lut


def reorder_coordinate_axes(
    coords: np.ndarray,
    *,
    source_order: str,
    target_order: str | None,
) -> np.ndarray:
    """Permute the columns of an ``[N, 3]`` coordinate array between axis orders.

    Args:
        coords: Array whose last (second) axis holds the three coordinates.
        source_order: Axis order of ``coords``, a permutation of ``"xyz"`` (any case).
        target_order: Desired axis order; ``None`` means "same as ``source_order``".

    Raises:
        ValueError: if either order is not a permutation of ``x``, ``y``, ``z``.
    """
    source_order = str(source_order).lower()
    target_order = source_order if target_order is None else str(target_order).lower()
    valid_axes = {"x", "y", "z"}
    if len(source_order) != 3 or set(source_order) != valid_axes:
        raise ValueError(f"Invalid skeleton coordinate order: {source_order!r}")
    if len(target_order) != 3 or set(target_order) != valid_axes:
        raise ValueError(f"Invalid prediction coordinate order: {target_order!r}")
    axis_indices = [source_order.index(axis) for axis in target_order]
    return np.asarray(coords)[:, axis_indices]


def networkx_skeleton_to_erl_graph(
    skeleton: Any,
    options: NerlGraphOptions | None = None,
    resolution: Any = None,
):
    """Convert a NetworkX skeleton graph into an em_erl ``ERLGraph``.

    Every node must carry ``options.skeleton_id_attribute`` (its skeleton id) and
    ``options.skeleton_position_attribute`` (coordinates in
    ``options.skeleton_position_order``). Coordinates are reordered to
    ``options.prediction_position_order`` and stored as float32 in the graph's
    ``node_coords_zyx`` field, whatever the actual axis order is.

    Edge lengths come from ``options.skeleton_edge_length_attribute`` when present.
    Otherwise they are the Euclidean distance between the endpoints, scaled per
    axis by ``resolution`` (applied in the reordered axis order) when given.
    Edges between different skeletons, or touching unknown nodes, are dropped.
    Skeletons are indexed in first-seen node order. Edges are grouped per skeleton,
    and ``edge_ptr`` holds the cumulative per-skeleton edge counts, starting at 0.

    Raises:
        ValueError: if the skeleton has no nodes or ``resolution`` is not 3 values.
        KeyError: if a node lacks the id or position attribute.
    """
    ERLGraph, _, _ = import_em_erl()
    options = options or NerlGraphOptions()

    node_ids = list(skeleton.nodes)
    if not node_ids:
        raise ValueError("NERL skeleton has no nodes")

    raw_skeleton_ids = []
    node_coords = []
    for node_id in node_ids:
        node_data = skeleton.nodes[node_id]
        raw_skeleton_ids.append(node_data[options.skeleton_id_attribute])
        node_coords.append(node_data[options.skeleton_position_attribute])

    # Unique skeleton ids in first-seen order -> dense indices 0..K-1.
    skeleton_ids = list(dict.fromkeys(raw_skeleton_ids))
    skeleton_index_by_id = {skeleton_id: i for i, skeleton_id in enumerate(skeleton_ids)}
    node_index_by_id = {node_id: i for i, node_id in enumerate(node_ids)}
    node_skeleton_index = np.asarray(
        [skeleton_index_by_id[skeleton_id] for skeleton_id in raw_skeleton_ids],
        dtype=np.uint32,
    )
    node_coords_arr = reorder_coordinate_axes(
        np.asarray(node_coords, dtype=np.float32),
        source_order=options.skeleton_position_order,
        target_order=options.prediction_position_order,
    )

    # Fallback edge lengths are measured in these (optionally scaled) coordinates.
    edge_coords = node_coords_arr
    if resolution is not None:
        res = np.asarray(resolution, dtype=np.float64).reshape(-1)
        if res.size != 3:
            raise ValueError(f"NERL resolution must have 3 elements, got {res.size}")
        edge_coords = node_coords_arr.astype(np.float64) * res

    edge_buckets: list[list[tuple[int, int, float]]] = [[] for _ in skeleton_ids]
    skeleton_len: np.ndarray = np.zeros(len(skeleton_ids), dtype=np.float64)
    for u, v, edge_data in skeleton.edges(data=True):
        if u not in node_index_by_id or v not in node_index_by_id:
            continue
        u_idx = node_index_by_id[u]
        v_idx = node_index_by_id[v]
        skel_idx = int(node_skeleton_index[u_idx])
        if skel_idx != int(node_skeleton_index[v_idx]):
            continue
        if options.skeleton_edge_length_attribute in edge_data:
            edge_len = float(edge_data[options.skeleton_edge_length_attribute])
        else:
            edge_len = float(np.linalg.norm(edge_coords[u_idx] - edge_coords[v_idx]))
        edge_buckets[skel_idx].append((u_idx, v_idx, edge_len))
        skeleton_len[skel_idx] += edge_len

    # Flatten the per-skeleton buckets; edge_ptr[k]:edge_ptr[k+1] spans skeleton k.
    edge_ptr = [0]
    edge_u = []
    edge_v = []
    edge_lens: list[float] = []
    for bucket in edge_buckets:
        for u_idx, v_idx, length in bucket:
            edge_u.append(u_idx)
            edge_v.append(v_idx)
            edge_lens.append(length)
        edge_ptr.append(len(edge_u))

    return ERLGraph(
        skeleton_id=np.asarray(skeleton_ids),
        skeleton_len=skeleton_len,
        node_skeleton_index=node_skeleton_index,
        node_coords_zyx=node_coords_arr,
        edge_u=np.asarray(edge_u, dtype=np.uint32),
        edge_v=np.asarray(edge_v, dtype=np.uint32),
        edge_len=np.asarray(edge_lens, dtype=np.float32),
        edge_ptr=np.asarray(edge_ptr, dtype=np.uint64),
    )


# ERLGraph attributes persisted in the on-disk cache (plus a ``voxel_coords`` flag).
_ERL_CACHE_FIELDS = (
    "skeleton_id",
    "skeleton_len",
    "node_skeleton_index",
    "node_coords_zyx",
    "edge_u",
    "edge_v",
    "edge_len",
    "edge_ptr",
)


def _resolution_cache_tag(resolution: Any) -> str:
    """Return a filename-safe tag for ``resolution`` (``"vox"`` when ``None``)."""
    if resolution is None:
        return "vox"
    res = np.asarray(resolution, dtype=np.float64).reshape(-1)
    return "res" + "_".join(f"{r:g}" for r in res)


def _graph_options_cache_tag(graph_options: Any) -> str:
    """Return a short, stable digest of the graph-conversion options.

    Options change the converted graph (axis order, attribute names), so they
    must be part of the cache key. ``None`` is resolved to the default options
    first, so it shares a cache with an explicit ``NerlGraphOptions()``.
    """
    import hashlib

    options = graph_options if graph_options is not None else NerlGraphOptions()
    return "opt" + hashlib.sha256(repr(options).encode("utf-8")).hexdigest()[:12]


def _erl_cache_path(source: Path, resolution: Any = None, graph_options: Any = None) -> Path:
    """Return the cache file path for ``source`` converted with these settings.

    Caches written before graph options were part of the key use a different
    name and are simply ignored (rebuilt once).
    """
    tag = f"{_resolution_cache_tag(resolution)}.{_graph_options_cache_tag(graph_options)}"
    return source.with_name(f"{source.name}.erl_cache.{tag}.npz")


def _save_erl_cache(graph: Any, voxel_coords: bool, cache_path: Path) -> None:
    """Best-effort save of ``graph``'s cached fields and ``voxel_coords`` to ``cache_path``.

    The archive is written to a uniquely named temporary file next to the target
    and moved into place with ``os.replace``. Concurrent readers therefore never
    see a half-written cache. ``OSError`` is logged and otherwise ignored.
    """
    import os
    import uuid

    cache_path = Path(cache_path)
    tmp_path = cache_path.with_name(f"{cache_path.name}.{os.getpid()}.{uuid.uuid4().hex}.tmp")
    try:
        # Writing through a file object stops numpy from appending ".npz".
        with open(tmp_path, "wb") as fh:
            np.savez_compressed(
                fh,
                voxel_coords=np.asarray(int(voxel_coords)),
                **{name: getattr(graph, name) for name in _ERL_CACHE_FIELDS},
            )
        os.replace(tmp_path, cache_path)
    except OSError as exc:
        logger.warning("Failed to write ERLGraph cache %s: %s", cache_path, exc)
    finally:
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass


def _load_erl_cache(cache_path: Path) -> tuple[Any, bool]:
    """Load ``(ERLGraph, voxel_coords)`` from a cache written by ``_save_erl_cache``."""
    ERLGraph, _, _ = import_em_erl()
    data = np.load(cache_path, allow_pickle=False)
    try:
        voxel_coords = bool(int(np.asarray(data["voxel_coords"]).item()))
        graph = ERLGraph(**{name: data[name] for name in _ERL_CACHE_FIELDS})
    finally:
        # NpzFile keeps the archive open until closed; plain arrays have no close().
        close = getattr(data, "close", None)
        if close is not None:
            close()
    return graph, voxel_coords


def load_nerl_graph(
    graph_source: Any,
    graph_options: NerlGraphOptions | None = None,
    resolution: Any = None,
):
    """Resolve ``graph_source`` into ``(graph, voxel_coords)``.

    Accepted sources:
      * an ``ERLGraph`` (or any object with ``node_coords_zyx`` and ``edge_ptr``),
        returned as-is;
      * a ``.npz`` path, loaded with ``ERLGraph.from_npz``;
      * a ``.pkl``/``.pickle`` NetworkX skeleton, converted with
        ``networkx_skeleton_to_erl_graph``. The result is cached next to the
        pickle, keyed by resolution and graph options, and the cache is reused
        while it is at least as new as the pickle.

    ``voxel_coords`` is ``True`` only for graphs obtained from a pickle (fresh or
    cached), and ``False`` otherwise.

    Raises:
        ValueError: for any other file suffix.
    """
    import pickle
    import zipfile
    import zlib

    ERLGraph, _, _ = import_em_erl()
    if isinstance(graph_source, ERLGraph):
        return graph_source, False
    if hasattr(graph_source, "node_coords_zyx") and hasattr(graph_source, "edge_ptr"):
        return graph_source, False

    graph_path = Path(graph_source)
    suffix = graph_path.suffix.lower()
    if suffix == ".npz":
        return ERLGraph.from_npz(graph_path), False
    if suffix not in {".pkl", ".pickle"}:
        raise ValueError(
            "NERL skeleton must be an ERLGraph .npz or "
            f"NetworkX skeleton pickle, got {graph_path}"
        )

    cache_path = _erl_cache_path(graph_path, resolution, graph_options)
    if cache_path.exists() and cache_path.stat().st_mtime >= graph_path.stat().st_mtime:
        try:
            return _load_erl_cache(cache_path)
        except (KeyError, OSError, ValueError, EOFError, zipfile.BadZipFile, zlib.error) as exc:
            # Truncated/partial archives surface as BadZipFile/EOFError/zlib.error.
            logger.warning("Ignoring corrupt ERLGraph cache %s: %s", cache_path, exc)

    # SECURITY: pickle.load can execute arbitrary code; only load trusted skeleton files.
    with open(graph_path, "rb") as f:
        skeleton = pickle.load(f)
    graph = networkx_skeleton_to_erl_graph(skeleton, graph_options, resolution=resolution)
    _save_erl_cache(graph, True, cache_path)
    return graph, True


def prepare_nerl_segmentation(decoded_predictions: np.ndarray) -> np.ndarray:
    """Normalize a decoded prediction to the 3D integer volume em_erl expects.

    Leading singleton axes are stripped first. If the array is still above 3D,
    all remaining singleton axes are squeezed. Non-integer dtypes are cast to uint32.

    Raises:
        ValueError: if the result is not 3D.
    """
    seg = np.asarray(decoded_predictions)
    while seg.ndim > 3 and seg.shape[0] == 1:
        seg = seg[0]
    if seg.ndim > 3:
        singleton_axes = tuple(i for i, size in enumerate(seg.shape) if size == 1)
        if singleton_axes:
            seg = np.squeeze(seg, axis=singleton_axes)
    if seg.ndim != 3:
        raise ValueError(f"NERL expects a 3D decoded instance volume, got shape {seg.shape}")
    if not np.issubdtype(seg.dtype, np.integer):
        seg = seg.astype(np.uint32, copy=False)
    return seg


def extract_nerl_score_outputs(score: Any) -> tuple[float, float, int, np.ndarray]:
    """Return aggregate and per-GT ERL values from an em_erl score object.

    Aggregates come from ``score.pred_erl`` / ``score.gt_erl`` when present, else
    from ``score.erl[0]`` / ``score.erl[1]`` (using the first row if ``score.erl``
    is 2D). The skeleton count is ``score.erl[2]`` when available, else
    ``len(score.skeleton_len)``.

    The per-GT array is the first non-None of ``per_gt_erl``, ``gt_segment_erl``,
    ``skeleton_erl_pair`` and ``skeleton_erl_pairs``. Otherwise, a 2D
    ``skeleton_pred_erl`` (or ``skeleton_erl``) contributes its first two columns.
    Failing that, it is stacked column-wise with ``skeleton_gt_erl`` (or
    ``skeleton_len``).

    Returns:
        ``(pred_erl, gt_erl, num_skeletons, per_gt_erl)``; ``per_gt_erl`` is float64 [N, 2].

    Raises:
        ValueError: if the per-GT values cannot be shaped to [N, 2].
    """
    score_erl = np.asarray(score.erl)
    if score_erl.ndim > 1:
        score_erl = score_erl[0]
    pred_erl = getattr(score, "pred_erl", None)
    gt_erl = getattr(score, "gt_erl", None)
    if pred_erl is None:
        pred_erl = score_erl[0]
    if gt_erl is None:
        gt_erl = score_erl[1]
    num_skeletons = int(score_erl[2]) if score_erl.size > 2 else int(len(score.skeleton_len))

    per_gt_erl = None
    for attr_name in (
        "per_gt_erl",
        "gt_segment_erl",
        "skeleton_erl_pair",
        "skeleton_erl_pairs",
    ):
        attr_value = getattr(score, attr_name, None)
        if attr_value is not None:
            per_gt_erl = np.asarray(attr_value, dtype=np.float64)
            break

    if per_gt_erl is None:
        skeleton_pred_erl = getattr(score, "skeleton_pred_erl", None)
        if skeleton_pred_erl is None:
            skeleton_pred_erl = score.skeleton_erl
        skeleton_gt_erl = getattr(score, "skeleton_gt_erl", None)
        if skeleton_gt_erl is None:
            skeleton_gt_erl = score.skeleton_len
        skeleton_pred_erl = np.asarray(skeleton_pred_erl, dtype=np.float64)
        skeleton_gt_erl = np.asarray(skeleton_gt_erl, dtype=np.float64)
        if skeleton_pred_erl.ndim == 2 and skeleton_pred_erl.shape[1] >= 2:
            per_gt_erl = skeleton_pred_erl[:, :2]
        else:
            per_gt_erl = np.column_stack([skeleton_pred_erl, skeleton_gt_erl])

    if per_gt_erl.ndim == 1:
        per_gt_erl = per_gt_erl.reshape(0, 2) if per_gt_erl.size == 0 else per_gt_erl.reshape(1, -1)
    if per_gt_erl.ndim != 2 or per_gt_erl.shape[1] != 2:
        raise ValueError(f"NERL per-GT ERL array must have shape [N, 2], got {per_gt_erl.shape}")
    return float(pred_erl), float(gt_erl), num_skeletons, per_gt_erl
def compute_nerl_score_details(
    segmentation: np.ndarray,
    skeleton_value: Any,
    *,
    skeleton_mask_value: Any = None,
    resolution: Any = None,
    merge_threshold: int = 1,
    chunk_num: int = 1,
    num_workers: int = 1,
    graph_options: NerlGraphOptions | None = None,
) -> NerlScoreResult:
    """Compute detailed NERL output for one segmentation/skeleton pair.

    Args:
        segmentation: Decoded instance volume, normalized by ``prepare_nerl_segmentation``.
        skeleton_value: Graph source accepted by ``load_nerl_graph`` (graph object,
            ``.npz`` path, or NetworkX pickle path).
        skeleton_mask_value: Optional mask forwarded to em_erl's ``compute_segment_lut``.
        resolution: Optional per-axis resolution used when loading the graph and,
            for non-pickle graphs, when computing node positions.
        merge_threshold: Forwarded to ``compute_erl_score``.
        chunk_num: Forwarded to ``compute_segment_lut``.
        num_workers: Forwarded to ``compute_segment_lut``. When > 1, ndarray inputs
            are first written to temporary HDF5 files, which are removed afterwards.
        graph_options: Conversion options for NetworkX pickles.

    Returns:
        A ``NerlScoreResult`` with ``nerl = pred_erl / gt_erl`` (NaN when ``gt_erl <= 0``).
    """
    _, compute_erl_score, compute_segment_lut = import_em_erl()
    erl_graph, voxel_coords = load_nerl_graph(skeleton_value, graph_options, resolution=resolution)
    if voxel_coords:
        # Graphs flagged voxel_coords=True by load_nerl_graph are indexed with
        # their stored coordinates directly (float -> int64 truncates toward zero).
        node_positions = np.asarray(erl_graph.node_coords_zyx, dtype=np.int64)
    else:
        node_positions = erl_graph.get_nodes_position(resolution)

    num_workers = int(num_workers)
    segment = prepare_nerl_segmentation(segmentation)
    seg_arg, mask_arg, tempfiles = _materialize_for_parallel(
        segment, skeleton_mask_value, num_workers
    )
    try:
        node_segment_lut, mask_segment_id = compute_segment_lut(
            seg_arg,
            node_positions,
            mask=mask_arg,
            chunk_num=int(chunk_num),
            data_type=segment.dtype,
            num_workers=num_workers,
        )
    finally:
        for path in tempfiles:
            try:
                Path(path).unlink()
            except FileNotFoundError:
                pass
            except OSError as exc:
                # Temp volumes may sit in RAM-backed /dev/shm, so report leaks.
                logger.warning("Failed to remove NERL temp file %s: %s", path, exc)

    score = compute_erl_score(
        erl_graph,
        node_segment_lut,
        mask_segment_id,
        merge_threshold=int(merge_threshold),
    )
    score.compute_erl()
    pred_erl, gt_erl, num_skeletons, per_gt_erl = extract_nerl_score_outputs(score)
    nerl = pred_erl / gt_erl if gt_erl > 0 else float("nan")
    return NerlScoreResult(
        nerl=float(nerl),
        pred_erl=float(pred_erl),
        gt_erl=float(gt_erl),
        num_skeletons=num_skeletons,
        per_gt_erl=per_gt_erl,
        graph=erl_graph,
    )
def compute_nerl_score(
    segmentation: np.ndarray,
    skeleton_value: Any,
    *,
    skeleton_mask_value: Any = None,
    resolution: Any = None,
    merge_threshold: int = 1,
    chunk_num: int = 1,
    num_workers: int = 1,
    graph_options: NerlGraphOptions | None = None,
) -> tuple[float, float, float]:
    """Return ``(nerl, pred_erl, gt_erl)`` for one segmentation/skeleton pair.

    Thin convenience wrapper over ``compute_nerl_score_details`` that discards
    the per-skeleton values and the graph; all arguments are forwarded unchanged.
    """
    details = compute_nerl_score_details(
        segmentation,
        skeleton_value,
        graph_options=graph_options,
        num_workers=num_workers,
        chunk_num=chunk_num,
        merge_threshold=merge_threshold,
        resolution=resolution,
        skeleton_mask_value=skeleton_mask_value,
    )
    nerl, pred_erl, gt_erl = details.nerl, details.pred_erl, details.gt_erl
    return nerl, pred_erl, gt_erl
# Public API of this module (alphabetical).
__all__ = [
    "NerlGraphOptions",
    "NerlScoreResult",
    "compute_nerl_score",
    "compute_nerl_score_details",
    "extract_nerl_score_outputs",
    "import_em_erl",
    "load_nerl_graph",
    "networkx_skeleton_to_erl_graph",
    "prepare_nerl_segmentation",
    "reorder_coordinate_axes",
]