Shortcuts

Source code for connectomics.metrics.metrics_skel

"""
Skeleton-based evaluation metrics for curvilinear structures.

This module provides metrics for evaluating segmentation quality of curvilinear
structures (e.g., neurons, blood vessels) based on skeleton correctness,
completeness, and quality.

Based on:
    Mosinska et al., "Beyond the Pixel-Wise Loss for Topology-Aware Delineation"
    https://arxiv.org/abs/1712.02190

Metrics:
    - Correctness: TP / (TP + FP) on skeleton
    - Completeness: TP / (TP + FN) on skeleton
    - Quality: (Completeness * Correctness) /
      (Completeness + Correctness - Completeness * Correctness)
    - Foreground IoU: Intersection over Union of foreground regions

File-backed batch evaluation lives in ``connectomics.evaluation``.
"""

from __future__ import annotations

from typing import List, Tuple

import numpy as np

from skimage.morphology import dilation, skeletonize, square


def compute_skeleton_metrics(
    skeleton_output: List[np.ndarray],
    skeleton_gt: List[np.ndarray],
    skeleton_output_dil: List[np.ndarray],
    skeleton_gt_dil: List[np.ndarray],
) -> Tuple[float, float, float]:
    """Compute skeleton-based metrics for curvilinear structure evaluation.

    Args:
        skeleton_output: List of skeletonized predictions (binarized)
        skeleton_gt: List of skeletonized ground truth
        skeleton_output_dil: List of dilated skeletonized predictions
        skeleton_gt_dil: List of dilated skeletonized ground truth

    Returns:
        Tuple of (correctness, completeness, quality) metrics
            - Correctness: True positives / (True positives + False positives)
            - Completeness: True positives / (True positives + False negatives)
            - Quality: F-measure combining correctness and completeness

    Notes:
        - All inputs should be binary arrays (0 or 1)
        - Dilation factor should be consistent (typically 5x5 square)
        - Quality is 0 if denominator is 0 (no predictions or ground truth)
    """
    tpcor = 0  # True positives for correctness
    tpcom = 0  # True positives for completeness
    fn = 0  # False negatives
    fp = 0  # False positives

    for i in range(len(skeleton_output)):
        # Correctness: predicted skeleton matches dilated GT
        tpcor += ((skeleton_output[i] == skeleton_gt_dil[i]) & (skeleton_output[i] == 1)).sum()

        # Completeness: GT skeleton matches dilated prediction
        tpcom += ((skeleton_gt[i] == skeleton_output_dil[i]) & (skeleton_gt[i] == 1)).sum()

        # False negatives: GT not covered by dilated prediction
        fn += (skeleton_gt[i] == 1).sum() - (
            (skeleton_gt[i] == skeleton_output_dil[i]) & (skeleton_gt[i] == 1)
        ).sum()

        # False positives: prediction not matching dilated GT
        fp += (skeleton_output[i] == 1).sum() - (
            (skeleton_output[i] == skeleton_gt_dil[i]) & (skeleton_output[i] == 1)
        ).sum()

    # Calculate metrics with safety checks
    correctness = tpcor / (tpcor + fp) if (tpcor + fp) > 0 else 0.0
    completeness = tpcom / (tpcom + fn) if (tpcom + fn) > 0 else 0.0

    # Quality (F-measure)
    denominator = completeness + correctness - completeness * correctness
    quality = (completeness * correctness / denominator) if denominator > 0 else 0.0

    return correctness, completeness, quality


def compute_precision_recall(
    pred: np.ndarray,
    gt: np.ndarray,
    dilation_size: int = 5,
) -> Tuple[float, float, float]:
    """Compute precision and recall metrics for single prediction-GT pair.

    Args:
        pred: Predicted binary segmentation mask
        gt: Ground truth binary segmentation mask
        dilation_size: Size of square structuring element for dilation. Default: 5

    Returns:
        Tuple of (correctness, completeness, quality) metrics

    Notes:
        - Automatically skeletonizes input masks
        - Applies dilation with specified size (default 5x5 square)
        - Returns 0.0 for all metrics if inputs are empty
    """
    # Skeletonize both prediction and ground truth
    pred_skel = skeletonize(pred)
    gt_skel = skeletonize(gt)

    # Dilate skeletons for tolerance in matching
    pred_dil = dilation(pred_skel, square(dilation_size))
    gt_dil = dilation(gt_skel, square(dilation_size))

    return compute_skeleton_metrics([pred_skel], [gt_skel], [pred_dil], [gt_dil])


def compute_iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """Calculate foreground Intersection over Union (IoU).

    Args:
        pred: Predicted binary mask
        gt: Ground truth binary mask

    Returns:
        Foreground IoU score in range [0.0, 1.0]
            - 1.0: Perfect overlap
            - 0.0: No overlap or empty union

    Notes:
        - Both inputs should be binary (0 or 1)
        - Returns 0.0 if union is empty (both masks are empty)
    """
    inter = np.logical_and(pred, gt).astype(np.float32)
    union = np.logical_or(pred, gt).astype(np.float32)

    if union.sum() == 0:
        return 0.0

    return inter.sum() / union.sum()


def binarize_masks(
    pred: np.ndarray,
    gt: np.ndarray,
    threshold: int = 128,
) -> Tuple[np.ndarray, np.ndarray]:
    """Binarize prediction and ground truth masks.

    Args:
        pred: Prediction mask (0-255 range)
        gt: Ground truth mask (0-255 range)
        threshold: Threshold for binarizing prediction. Default: 128

    Returns:
        Tuple of (binarized_pred, binarized_gt) as uint8 arrays

    Notes:
        - Prediction: threshold at specified value (default 128)
        - Ground truth: exclude 0 (background) and 255 (ignore) pixels
    """
    pred_bin = (pred > threshold).astype(np.uint8)
    gt_bin = ((gt != 0) & (gt != 255)).astype(np.uint8)
    return pred_bin, gt_bin


[docs]def evaluate_image_pair( pred: np.ndarray, gt: np.ndarray, threshold: int = 128, dilation_size: int = 5, ) -> Tuple[float, float, float, float]: """Evaluate single prediction-ground truth pair. Args: pred: Prediction mask (0-255 range) gt: Ground truth mask (0-255 range) threshold: Threshold for binarizing prediction. Default: 128 dilation_size: Dilation size for skeleton matching. Default: 5 Returns: Tuple of (iou, correctness, completeness, quality) metrics - Returns (1.0, 1.0, 1.0, 1.0) if GT is empty - All values in range [0.0, 1.0] """ # Binarize masks pred_bin, gt_bin = binarize_masks(pred, gt, threshold) # Handle empty ground truth num_gt = gt_bin.sum() if num_gt == 0: return 1.0, 1.0, 1.0, 1.0 # Compute metrics iou = compute_iou(pred_bin, gt_bin) correctness, completeness, quality = compute_precision_recall(pred_bin, gt_bin, dilation_size) return iou, correctness, completeness, quality
__all__ = [ "compute_skeleton_metrics", "compute_precision_recall", "compute_iou", "binarize_masks", "evaluate_image_pair", ]