Shortcuts

Source code for connectomics.metrics.segmentation_numpy

from __future__ import annotations

import numpy as np
import scipy.sparse as sparse
from scipy.optimize import linear_sum_assignment
from skimage.segmentation import relabel_sequential

from connectomics.utils.label_overlap import compute_label_overlap

# Registry mapping a criterion name (e.g. "iou") to the function that converts
# a label-overlap matrix into a score matrix.  Entries are registered right
# after each criterion function is defined further down in this module.
matching_criteria = dict()

# Names exported by a star-import of this module.
__all__ = [
    "adapted_rand",
    "voi",
    "instance_matching",
    "instance_matching_simple",
    "matching_criteria",
]


def adapted_rand_oracle(seg, gt, gt_ids=None):
    """Efficiently compute per-GT-segment oracle ARE by incremental update.

    For each GT segment, computes the adapted Rand error (ARE) that would
    result from perfectly predicting that segment
    (``pred[gt == g_id] = unique_id``) while keeping everything else
    unchanged.

    The contingency table is built once; for each GT segment only the
    entries of its own row (and the column sums those entries touch) are
    updated, so the per-segment cost is proportional to the number of
    non-zeros in that row instead of the volume size.

    Parameters
    ----------
    seg : np.ndarray
        Predicted segmentation (non-negative integer labels).
    gt : np.ndarray, same number of elements as seg
        Ground-truth segmentation (non-negative integer labels).
    gt_ids : array-like, optional
        GT segment IDs to evaluate.  If None, uses all non-zero GT IDs.
        IDs outside ``[1, max(gt)]`` and IDs with no voxels are skipped.

    Returns
    -------
    results : list of (gt_id, are_oracle, delta_are)
        Sorted by delta_are descending (most impactful first).
        ``delta_are = are_baseline - are_oracle`` (positive = improvement).
    are_base : float
        The baseline ARE of ``seg`` against ``gt`` before any oracle fix.

    Raises
    ------
    ValueError
        If ``seg`` and ``gt`` do not contain the same number of elements.
    """
    segA = np.ravel(gt)
    segB = np.ravel(seg)
    n = segA.size
    if segB.size != n:
        raise ValueError(
            f"seg and gt must have the same number of elements. "
            f"Got seg.size={segB.size}, gt.size={n}"
        )

    n_labels_A = int(np.amax(segA)) + 1
    n_labels_B = int(np.amax(segB)) + 1

    # Build contingency table: p_ij[gt_label, pred_label] = count
    # (duplicate (row, col) pairs are summed by the CSR constructor).
    ones_data = np.ones(n, int)
    p_ij = sparse.csr_matrix(
        (ones_data, (segA, segB)), shape=(n_labels_A, n_labels_B)
    )

    # Baseline quantities (exclude GT background row 0)
    a = p_ij[1:n_labels_A, :]             # all pred cols including bg
    b = p_ij[1:n_labels_A, 1:n_labels_B]  # exclude pred bg col
    c = np.asarray(p_ij[1:n_labels_A, 0].todense()).ravel()  # pred bg col

    a_i = np.asarray(a.sum(1)).ravel()    # GT row sums
    b_i = np.asarray(b.sum(0)).ravel()    # pred col sums (excl bg)

    sum_c = np.sum(c)
    sumA = np.sum(a_i * a_i)
    sumB = np.sum(b_i * b_i) + sum_c / n
    sumAB = np.sum(b.multiply(b)) + sum_c / n

    prec_base = sumAB / sumB if sumB > 0 else 0
    rec_base = sumAB / sumA if sumA > 0 else 0
    pr_base = prec_base + rec_base
    f_base = 2.0 * prec_base * rec_base / pr_base if pr_base > 0 else 0
    are_base = 1.0 - f_base

    if gt_ids is None:
        gt_ids = np.arange(1, n_labels_A)
    else:
        gt_ids = np.asarray(gt_ids)

    # Raw CSR buffers of ``b``: row r owns data/indices[indptr[r]:indptr[r+1]].
    b_indptr, b_indices, b_data = b.indptr, b.indices, b.data

    results = []
    for g_id in gt_ids:
        g_id = int(g_id)
        if g_id < 1 or g_id >= n_labels_A:
            continue
        row_idx = g_id - 1  # index into a/b/c (which start from GT label 1)

        g_size = int(a_i[row_idx])  # total voxels in this GT segment
        if g_size == 0:
            continue

        # Current overlaps of this GT segment with foreground pred labels.
        start, stop = b_indptr[row_idx], b_indptr[row_idx + 1]
        overlaps = b_data[start:stop]
        cols = b_indices[start:stop]

        bg_term = c[row_idx] / n   # this row's share of the pred-bg term
        g_sq = g_size * g_size

        # --- After oracle fix: row becomes a single entry g_size in a brand
        # new pred column.  sumA is unchanged (row sum is still g_size).

        # sumAB: drop the row's old squared overlaps and bg share, add the
        # single new entry squared.
        new_sumAB = sumAB - np.sum(overlaps * overlaps) - bg_term + g_sq

        # sumB: every pred column this row touched shrinks by the overlap,
        # the new column contributes g_size^2, and the bg share vanishes.
        old_bi = b_i[cols]
        shrunk_bi = old_bi - overlaps
        new_sumB = (
            sumB
            - bg_term
            + np.sum(shrunk_bi * shrunk_bi)
            - np.sum(old_bi * old_bi)
            + g_sq
        )

        prec = new_sumAB / new_sumB if new_sumB > 0 else 0
        rec = new_sumAB / sumA if sumA > 0 else 0
        f = 2.0 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
        are_oracle = 1.0 - f
        delta = are_base - are_oracle

        results.append((g_id, are_oracle, delta))

    # Largest improvement first; the sort is stable for ties.
    results.sort(key=lambda x: -x[2])
    return results, are_base


def adapted_rand(seg, gt, all_stats=False):
    r"""Compute Adapted Rand error as defined by the SNEMI3D contest [1]

    Formula is given as 1 - the maximal F-score of the Rand index
    (excluding the zero component of the original labels). Adapted
    from the SNEMI3D MATLAB script, hence the strange style.

    Parameters
    ----------
    seg : np.ndarray
        the segmentation to score, where each value is the label at that point
    gt : np.ndarray, same shape as seg
        the groundtruth to score against, where each value is a label
    all_stats : boolean, optional
        whether to also return precision and recall as a 3-tuple with rand_error

    Returns
    -------
    are : float
        The adapted Rand error; equal to $1 - \frac{2pr}{p + r}$,
        where $p$ and $r$ are the precision and recall described below.
        When the ground truth has no foreground voxels the F-score is
        taken to be 0, so the error is 1.0 (instead of NaN).
    prec : float, optional
        The adapted Rand precision. (Only returned when `all_stats` is ``True``.)
    rec : float, optional
        The adapted Rand recall.  (Only returned when `all_stats` is ``True``.)

    Raises
    ------
    ValueError
        If ``seg`` and ``gt`` have different shapes.

    References
    ----------
    [1]: http://brainiac2.mit.edu/SNEMI3D/evaluation
    """
    # Validate shapes match
    if seg.shape != gt.shape:
        raise ValueError(
            f"seg and gt must have the same shape. "
            f"Got seg.shape={seg.shape}, gt.shape={gt.shape}"
        )

    # segA is truth, segB is query
    segA = np.ravel(gt)
    segB = np.ravel(seg)
    n = segA.size

    n_labels_A = int(np.amax(segA)) + 1
    n_labels_B = int(np.amax(segB)) + 1

    # Contingency table: p_ij[gt_label, seg_label] = number of voxels.
    ones_data = np.ones(n, int)
    p_ij = sparse.csr_matrix(
        (ones_data, (segA, segB)), shape=(n_labels_A, n_labels_B)
    )

    # GT background (row 0) is always excluded.
    a = p_ij[1:n_labels_A, :]               # every seg column, incl. background
    b = p_ij[1:n_labels_A, 1:n_labels_B]    # seg background column dropped
    c = p_ij[1:n_labels_A, 0].todense()     # seg background column only
    d = b.multiply(b)

    a_i = np.array(a.sum(1))
    b_i = np.array(b.sum(0))

    # Voxels that are foreground in gt but background in seg enter both
    # sumB and sumAB through this shared term.
    bg_term = np.sum(c) / n

    sumA = np.sum(a_i * a_i)
    sumB = np.sum(b_i * b_i) + bg_term
    sumAB = np.sum(d) + bg_term

    # Guard the degenerate case (no GT foreground => all sums are zero) the
    # same way adapted_rand_oracle does, rather than producing NaN.
    # Local names avoid shadowing the module-level precision()/recall().
    prec = sumAB / sumB if sumB > 0 else 0.0
    rec = sumAB / sumA if sumA > 0 else 0.0
    pr_sum = prec + rec
    f_score = 2.0 * prec * rec / pr_sum if pr_sum > 0 else 0.0

    are = 1.0 - f_score

    if all_stats:
        return (are, prec, rec)
    return are
# Evaluation code courtesy of Juan Nunez-Iglesias, taken from
# https://github.com/janelia-flyem/gala/blob/master/gala/evaluate.py
def voi(reconstruction, groundtruth, ignore_reconstruction=None, ignore_groundtruth=None):
    """Return the conditional entropies of the variation of information metric. [1]

    Let X be a reconstruction, and Y a ground truth labelling. The variation of
    information between the two is the sum of two conditional entropies:

        VI(X, Y) = H(X|Y) + H(Y|X).

    The first one, H(X|Y), is a measure of oversegmentation, the second one,
    H(Y|X), a measure of undersegmentation. These measures are referred to as
    the variation of information split or merge error, respectively.

    Parameters
    ----------
    reconstruction : np.ndarray, int type, arbitrary shape
        A candidate segmentation.
    groundtruth : np.ndarray, int type, same shape as `reconstruction`
        The ground truth segmentation.
    ignore_reconstruction, ignore_groundtruth : list of int, optional
        Any points having a label in this list are ignored in the evaluation.
        By default, only the label 0 in the ground truth will be ignored.

    Returns
    -------
    (split, merge) : float
        The variation of information split and merge error, i.e., H(X|Y) and H(Y|X)

    References
    ----------
    [1] Meila, M. (2007). Comparing clusterings - an information based
    distance. Journal of Multivariate Analysis 98, 873-895.
    """
    if ignore_reconstruction is None:
        ignore_reconstruction = []
    if ignore_groundtruth is None:
        ignore_groundtruth = [0]
    # split_vi returns [H(Y|X), H(X|Y)], i.e. (merge, split); swap on return.
    merge, split = split_vi(
        reconstruction, groundtruth, ignore_reconstruction, ignore_groundtruth
    )
    return (split, merge)


def split_vi(x, y=None, ignore_x=None, ignore_y=None):
    """Return the symmetric conditional entropies associated with the VI.

    The variation of information is defined as VI(X,Y) = H(X|Y) + H(Y|X).
    If Y is the ground-truth segmentation, then H(Y|X) can be interpreted
    as the amount of under-segmentation of Y and H(X|Y) is then the amount
    of over-segmentation.  In other words, a perfect over-segmentation
    will have H(Y|X)=0 and a perfect under-segmentation will have H(X|Y)=0.

    If y is None, x is assumed to be a contingency table.

    Parameters
    ----------
    x : np.ndarray
        Label field (int type) or contingency table (float). `x` is
        interpreted as a contingency table (summing to 1.0) if and only if `y`
        is not provided.
    y : np.ndarray of int, same shape as x, optional
        A label field to compare to `x`.
    ignore_x, ignore_y : list of int, optional
        Any points having a label in this list are ignored in the evaluation.
        Ignore 0-labeled points by default.

    Returns
    -------
    sv : np.ndarray of float, shape (2,)
        The conditional entropies of Y|X and X|Y.

    See Also
    --------
    voi, vi_tables
    """
    if ignore_x is None:
        ignore_x = [0]
    if ignore_y is None:
        ignore_y = [0]
    _, _, _, hxgy, hygx, _, _ = vi_tables(x, y, ignore_x, ignore_y)
    # false merges, false splits
    return np.array([hygx.sum(), hxgy.sum()])


def vi_tables(x, y=None, ignore_x=None, ignore_y=None):
    """Return probability tables used for calculating VI.

    If y is None, x is assumed to be a contingency table.

    Parameters
    ----------
    x, y : np.ndarray
        Either x and y are provided as equal-shaped np.ndarray label fields
        (int type), or y is not provided and x is a contingency table
        (sparse.csc_matrix) that may or may not sum to 1.
    ignore_x, ignore_y : list of int, optional
        Rows and columns (respectively) to ignore in the contingency table.
        These are labels that are not counted when evaluating VI.

    Returns
    -------
    pxy : sparse.csc_matrix of float
        The normalized contingency table.
    px, py, hxgy, hygx, lpygx, lpxgy : np.ndarray of float
        The proportions of each label in `x` and `y` (`px`, `py`), the
        per-segment conditional entropies of `x` given `y` and vice-versa, the
        per-segment conditional probability p log p.
    """
    if ignore_x is None:
        ignore_x = [0]
    if ignore_y is None:
        ignore_y = [0]
    if y is not None:
        pxy = contingency_table(x, y, ignore_x, ignore_y)
    else:
        cont = x
        total = float(cont.sum())
        # normalize, since it is an identity op if already done
        pxy = cont / total

    # Calculate probabilities
    px = np.array(pxy.sum(axis=1)).ravel()
    py = np.array(pxy.sum(axis=0)).ravel()
    # Remove zero rows/cols
    nzx = px.nonzero()[0]
    nzy = py.nonzero()[0]
    nzpx = px[nzx]
    nzpy = py[nzy]
    nzpxy = pxy[nzx, :][:, nzy]

    # Calculate log conditional probabilities and entropies
    lpygx = np.zeros(np.shape(px))
    lpygx[nzx] = xlogx(divide_rows(nzpxy, nzpx)).sum(axis=1).ravel()
    # \sum_x{p_{y|x} \log{p_{y|x}}}
    hygx = -(px * lpygx)  # \sum_x{p_x H(Y|X=x)} = H(Y|X)

    lpxgy = np.zeros(np.shape(py))
    lpxgy[nzy] = xlogx(divide_columns(nzpxy, nzpy)).sum(axis=0).ravel()
    hxgy = -(py * lpxgy)

    return [pxy] + list(map(np.asarray, [px, py, hxgy, hygx, lpygx, lpxgy]))


def contingency_table(seg, gt, ignore_seg=None, ignore_gt=None, norm=True):
    """Return the contingency table for all regions in matched segmentations.

    Parameters
    ----------
    seg : np.ndarray, int type, arbitrary shape
        A candidate segmentation.
    gt : np.ndarray, int type, same shape as `seg`
        The ground truth segmentation.
    ignore_seg : list of int, optional
        Values to ignore in `seg`. Voxels in `seg` having a value in this list
        will not contribute to the contingency table. (default: [0])
    ignore_gt : list of int, optional
        Values to ignore in `gt`. Voxels in `gt` having a value in this list
        will not contribute to the contingency table. (default: [0])
    norm : bool, optional
        Whether to normalize the table so that it sums to 1.

    Returns
    -------
    cont : scipy.sparse.csc_matrix
        A contingency table. `cont[i, j]` will equal the number of voxels
        labeled `i` in `seg` and `j` in `gt`. (Or the proportion of such voxels
        if `norm=True`.)
    """
    if ignore_seg is None:
        ignore_seg = [0]
    if ignore_gt is None:
        ignore_gt = [0]
    segr = seg.ravel()
    gtr = gt.ravel()
    # One unit of weight per voxel; ignored voxels get weight 0 so they stay
    # in the index arrays (keeping the table shape) without contributing.
    data = np.ones(len(gtr))
    ignored = np.isin(segr, ignore_seg) | np.isin(gtr, ignore_gt)
    data[ignored] = 0
    cont = sparse.coo_matrix((data, (segr, gtr))).tocsc()
    if norm:
        cont /= float(cont.sum())
    return cont


def divide_columns(matrix, row, in_place=False):
    """Divide each column of `matrix` by the corresponding element in `row`.

    The result is as follows: out[i, j] = matrix[i, j] / row[j]

    Parameters
    ----------
    matrix : np.ndarray, scipy.sparse.csc_matrix or csr_matrix, shape (M, N)
        The input matrix.
    row : a 1D np.ndarray, shape (N,)
        The row dividing `matrix`.
    in_place : bool (optional, default False)
        Do the computation in-place.

    Returns
    -------
    out : same type as `matrix`
        The result of the column-wise division.
    """
    if in_place:
        out = matrix
    else:
        out = matrix.copy()
    if isinstance(out, (sparse.csc_matrix, sparse.csr_matrix)):
        # In CSR layout ``indices`` holds the column of every stored value,
        # so the matching divisor is a plain gather from ``row``.
        if isinstance(out, sparse.csc_matrix):
            convert_to_csc = True
            out = out.tocsr()
        else:
            convert_to_csc = False
        row_repeated = np.take(row, out.indices)
        nz = out.data.nonzero()
        out.data[nz] /= row_repeated[nz]
        if convert_to_csc:
            out = out.tocsc()
    else:
        out /= row[np.newaxis, :]
    return out


def divide_rows(matrix, column, in_place=False):
    """Divide each row of `matrix` by the corresponding element in `column`.

    The result is as follows: out[i, j] = matrix[i, j] / column[i]

    Parameters
    ----------
    matrix : np.ndarray, scipy.sparse.csc_matrix or csr_matrix, shape (M, N)
        The input matrix.
    column : a 1D np.ndarray, shape (M,)
        The column dividing `matrix`.
    in_place : bool (optional, default False)
        Do the computation in-place.

    Returns
    -------
    out : same type as `matrix`
        The result of the row-wise division.
    """
    if in_place:
        out = matrix
    else:
        out = matrix.copy()
    if isinstance(out, (sparse.csc_matrix, sparse.csr_matrix)):
        # In CSC layout ``indices`` holds the row of every stored value,
        # so the matching divisor is a plain gather from ``column``.
        if isinstance(out, sparse.csr_matrix):
            convert_to_csr = True
            out = out.tocsc()
        else:
            convert_to_csr = False
        column_repeated = np.take(column, out.indices)
        nz = out.data.nonzero()
        out.data[nz] /= column_repeated[nz]
        if convert_to_csr:
            out = out.tocsr()
    else:
        out /= column[:, np.newaxis]
    return out


def xlogx(x, out=None, in_place=False):
    """Compute x * log_2(x).

    We define 0 * log_2(0) = 0

    Parameters
    ----------
    x : np.ndarray or scipy.sparse.csc_matrix or csr_matrix
        The input array.
    out : same type as x (optional)
        If provided, use this array/matrix for the result.
        NOTE(review): the values of `out` itself are transformed; `x` is not
        copied into it, so the caller must pre-populate `out`.
    in_place : bool (optional, default False)
        Operate directly on x.

    Returns
    -------
    y : same type as x
        Result of x * log_2(x).
    """
    if in_place:
        y = x
    elif out is None:
        y = x.copy()
    else:
        y = out
    # For sparse input only the stored values need transforming.
    if isinstance(y, (sparse.csc_matrix, sparse.csr_matrix)):
        z = y.data
    else:
        z = y
    # Zeros are skipped, which implements 0 * log2(0) = 0.
    nz = z.nonzero()
    z[nz] *= np.log2(z[nz])
    return y


# Code modified from https://github.com/stardist/stardist
# Copied from https://github.com/CSBDeep/CSBDeep/blob/master/csbdeep/utils/utils.py
def _raise(e):
    """Raise `e` if it is an exception, otherwise raise ``ValueError(e)``."""
    if isinstance(e, BaseException):
        raise e
    else:
        raise ValueError(e)


def label_are_sequential(y):
    """returns true if y has only sequential labels from 1..."""
    labels = np.unique(y)
    return (set(labels) - {0}) == set(range(1, 1 + labels.max()))


def is_array_of_integers(y):
    """Return True if `y` is a numpy array with an integer dtype."""
    return isinstance(y, np.ndarray) and np.issubdtype(y.dtype, np.integer)


def _check_label_array(y, name=None, check_sequential=False):
    """Validate that `y` is an integer label array.

    Parameters
    ----------
    y : object
        Candidate label array.
    name : str, optional
        Name used in the error message (defaults to "labels").
    check_sequential : bool, optional
        If True, require labels to be exactly 1..max (0 allowed as well);
        otherwise only require all values to be non-negative.

    Returns
    -------
    bool
        Always True when validation passes (empty arrays pass trivially).

    Raises
    ------
    ValueError
        If `y` is not an integer numpy array or violates the label rule.
    """
    def _fail():
        raise ValueError(
            "{label} must be an array of {integers}.".format(
                label="labels" if name is None else name,
                integers=("sequential " if check_sequential else "")
                + "non-negative integers",
            )
        )

    if not is_array_of_integers(y):
        _fail()
    # ``size`` (not ``len``) so arrays that are empty along any axis are
    # accepted instead of failing inside min()/max().
    if y.size == 0:
        return True
    if check_sequential:
        if not label_are_sequential(y):
            _fail()
    elif not y.min() >= 0:
        _fail()
    return True


def label_overlap(x, y, check=True):
    """Return the label-overlap matrix of `x` and `y`.

    When `check` is True both inputs must be sequentially labelled integer
    arrays of identical shape, otherwise a ValueError is raised.  The actual
    computation is delegated to ``compute_label_overlap``.
    """
    if check:
        _check_label_array(x, "x", True)
        _check_label_array(y, "y", True)
        if x.shape != y.shape:
            raise ValueError("x and y must have the same shape")
    return compute_label_overlap(x, y)


def _safe_divide(x, y, eps=1e-10):
    """computes a safe divide which returns 0 if y is zero"""
    if np.isscalar(x) and np.isscalar(y):
        return x / y if np.abs(y) > eps else 0.0
    else:
        # Entries where |y| <= eps keep the zero they were initialised with.
        out = np.zeros(np.broadcast(x, y).shape, np.float32)
        np.divide(x, y, out=out, where=np.abs(y) > eps)
        return out


def intersection_over_union(overlap):
    """Turn an overlap matrix into pairwise intersection-over-union scores."""
    _check_label_array(overlap, "overlap")
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return _safe_divide(overlap, (n_pixels_pred + n_pixels_true - overlap))


matching_criteria["iou"] = intersection_over_union


def intersection_over_true(overlap):
    """Turn an overlap matrix into intersection / row-sum scores."""
    _check_label_array(overlap, "overlap")
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return _safe_divide(overlap, n_pixels_true)


matching_criteria["iot"] = intersection_over_true


def intersection_over_pred(overlap):
    """Turn an overlap matrix into intersection / column-sum scores."""
    _check_label_array(overlap, "overlap")
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    return _safe_divide(overlap, n_pixels_pred)


matching_criteria["iop"] = intersection_over_pred


def precision(tp, fp, fn):
    """Return tp / (tp + fp), or 0 when there are no true positives."""
    return tp / (tp + fp) if tp > 0 else 0


def recall(tp, fp, fn):
    """Return tp / (tp + fn), or 0 when there are no true positives."""
    return tp / (tp + fn) if tp > 0 else 0


def accuracy(tp, fp, fn):
    """Return tp / (tp + fp + fn), or 0 when there are no true positives."""
    # also known as "average precision" (?)
    # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation
    return tp / (tp + fp + fn) if tp > 0 else 0


def f1(tp, fp, fn):
    """Return 2*tp / (2*tp + fp + fn), or 0 when there are no true positives."""
    # also known as "dice coefficient"
    return (2 * tp) / (2 * tp + fp + fn) if tp > 0 else 0
def instance_matching(y_true, y_pred, thresh=0.5, criterion="iou", report_matches=False):
    """Calculate detection/instance segmentation metrics between ground truth and predictions.

    Currently, the following metrics are implemented:

    'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion',
    'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score',
    'panoptic_quality'

    Corresponding objects of y_true and y_pred are counted as true positives
    (tp), false positives (fp), and false negatives (fn) when their
    intersection over union (IoU) >= thresh (for criterion='iou', which can
    be changed)

    * mean_matched_score is the mean IoUs of matched true positives
    * mean_true_score is the mean IoUs of matched true positives but
      normalized by the total number of GT objects
    * panoptic_quality defined as in Eq. 1 of Kirillov et al.
      "Panoptic Segmentation", CVPR 2019

    Parameters
    ----------
    y_true: ndarray
        ground truth label image (integer valued)
    y_pred: ndarray
        predicted label image (integer valued)
    thresh: float or iterable of float
        threshold(s) for matching criterion (default 0.5); ``None`` is
        treated as 0
    criterion: string
        matching criterion (default IoU); must be a key of
        ``matching_criteria``
    report_matches: bool
        if True, additionally calculate matched_pairs and matched_scores
        (returns gt-pred pairs even when scores are below 'thresh')

    Returns
    -------
    dict or tuple of dict
        A dictionary holding the metrics listed above when `thresh` is a
        scalar, or a tuple with one such dictionary per threshold when
        `thresh` is an iterable.

    Raises
    ------
    ValueError
        If the inputs are not non-negative integer arrays of equal shape,
        if `criterion` is unknown, or if the criterion yields scores outside
        [0, 1].

    Examples
    --------
    >>> y_true = np.zeros((100,100), np.uint16)
    >>> y_true[10:20,10:20] = 1
    >>> y_pred = np.roll(y_true,5,axis = 0)
    >>> stats = instance_matching(y_true, y_pred)
    >>> stats["tp"], stats["fp"], stats["fn"]
    (0, 1, 1)
    """
    _check_label_array(y_true, "y_true")
    _check_label_array(y_pred, "y_pred")
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes"
        )
    if criterion not in matching_criteria:
        raise ValueError("Matching criterion '%s' not supported." % criterion)

    if thresh is None:
        thresh = 0
    # Materialise an iterable of thresholds so it is not a one-shot iterator.
    thresh = float(thresh) if np.isscalar(thresh) else tuple(map(float, thresh))

    # Relabel to 1..N; the reverse maps recover the caller's original ids.
    y_true, _, map_rev_true = relabel_sequential(y_true)
    y_pred, _, map_rev_pred = relabel_sequential(y_pred)
    map_rev_true = np.array(map_rev_true)
    map_rev_pred = np.array(map_rev_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    if not (0 <= np.min(scores) <= np.max(scores) <= 1):
        raise ValueError(
            f"Scores must be in [0, 1], got range [{np.min(scores)}, {np.max(scores)}]"
        )

    # ignoring background
    scores = scores[1:, 1:]
    n_true, n_pred = scores.shape
    n_matched = min(n_true, n_pred)

    def _single(thr):
        """Compute the metrics dictionary for one threshold value."""
        not_trivial = n_matched > 0 and np.any(scores >= thr)
        if not_trivial:
            # compute optimal matching with scores as tie-breaker
            costs = -(scores >= thr).astype(float) - scores / (2 * n_matched)
            true_ind, pred_ind = linear_sum_assignment(costs)
            assert n_matched == len(true_ind) == len(pred_ind)
            match_ok = scores[true_ind, pred_ind] >= thr
            tp = np.count_nonzero(match_ok)
        else:
            tp = 0
        fp = n_pred - tp
        fn = n_true - tp

        # the score sum over all matched objects (tp)
        sum_matched_score = (
            np.sum(scores[true_ind, pred_ind][match_ok]) if not_trivial else 0.0
        )
        # the score average over all matched objects (tp)
        mean_matched_score = _safe_divide(sum_matched_score, tp)
        # the score average over all gt/true objects
        mean_true_score = _safe_divide(sum_matched_score, n_true)
        panoptic_quality = _safe_divide(sum_matched_score, tp + fp / 2 + fn / 2)

        stats_dict = dict(
            criterion=criterion,
            thresh=thr,
            fp=fp,
            tp=tp,
            fn=fn,
            precision=precision(tp, fp, fn),
            recall=recall(tp, fp, fn),
            accuracy=accuracy(tp, fp, fn),
            f1=f1(tp, fp, fn),
            n_true=n_true,
            n_pred=n_pred,
            mean_true_score=mean_true_score,
            mean_matched_score=mean_matched_score,
            panoptic_quality=panoptic_quality,
        )
        if bool(report_matches):
            if not_trivial:
                stats_dict.update(
                    # int() to be json serializable
                    matched_pairs=tuple(
                        (int(map_rev_true[i]), int(map_rev_pred[j]))
                        for i, j in zip(1 + true_ind, 1 + pred_ind)
                    ),
                    matched_scores=tuple(scores[true_ind, pred_ind]),
                    matched_tps=tuple(map(int, np.flatnonzero(match_ok))),
                    pred_ids=tuple(map_rev_pred),
                    gt_ids=tuple(map_rev_true),
                )
            else:
                stats_dict.update(
                    matched_pairs=(),
                    matched_scores=(),
                    matched_tps=(),
                    pred_ids=(),
                    gt_ids=(),
                )
        return stats_dict

    return _single(thresh) if np.isscalar(thresh) else tuple(map(_single, thresh))
def instance_matching_simple(y_true, y_pred, thresh=0.5, criterion="iou"):
    """Calculate relaxed instance segmentation metrics without Hungarian matching.

    WARNING: This is a RELAXED metric for debugging/analysis only, NOT for
    benchmark ranking.

    Unlike instance_matching(), this does NOT use optimal bipartite matching
    (Hungarian algorithm). Instead, it simply counts all (GT, Pred) pairs
    with IoU >= threshold as true positives.  Because one object may take
    part in several qualifying pairs, ``tp`` can exceed ``n_true`` or
    ``n_pred``, in which case ``fn`` / ``fp`` become negative.

    This metric is useful for:

    - Quick debugging and sanity checks
    - Understanding raw overlap statistics
    - Comparing with strict Hungarian-based metrics

    Metrics computed:
    'tp', 'fp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion',
    'thresh', 'n_true', 'n_pred'

    Parameters
    ----------
    y_true: ndarray
        ground truth label image (integer valued)
    y_pred: ndarray
        predicted label image (integer valued)
    thresh: float
        threshold for matching criterion (default 0.5)
    criterion: string
        matching criterion (default 'iou'); must be a key of
        ``matching_criteria``

    Returns
    -------
    Dictionary with metrics (tp, fp, fn, precision, recall, accuracy, f1, etc.)

    Raises
    ------
    ValueError
        If the inputs are not non-negative integer arrays of equal shape,
        if `criterion` is unknown, or if the criterion yields scores outside
        [0, 1].

    Examples
    --------
    >>> y_true = np.zeros((100,100), np.uint16)
    >>> y_true[10:20,10:20] = 1
    >>> y_pred = np.roll(y_true, 5, axis=0)
    >>> stats = instance_matching_simple(y_true, y_pred)
    >>> print(f"Accuracy: {stats['accuracy']:.3f}")
    """
    _check_label_array(y_true, "y_true")
    _check_label_array(y_pred, "y_pred")
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes"
        )
    if criterion not in matching_criteria:
        raise ValueError("Matching criterion '%s' not supported." % criterion)

    thresh = float(thresh)

    # Relabel to 1..N; the reverse maps are not needed here.
    y_true, _, _ = relabel_sequential(y_true)
    y_pred, _, _ = relabel_sequential(y_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    if not (0 <= np.min(scores) <= np.max(scores) <= 1):
        raise ValueError(
            f"Scores must be in [0, 1], got range [{np.min(scores)}, {np.max(scores)}]"
        )

    # ignoring background
    scores = scores[1:, 1:]
    n_true, n_pred = scores.shape

    # Simple counting: any pair with IoU >= thresh counts as TP
    # No Hungarian matching - just count all pairs above threshold
    tp = np.sum(scores >= thresh)
    fp = n_pred - tp
    fn = n_true - tp

    return dict(
        criterion=criterion,
        thresh=thresh,
        fp=fp,
        tp=tp,
        fn=fn,
        precision=precision(tp, fp, fn),
        recall=recall(tp, fp, fn),
        accuracy=accuracy(tp, fp, fn),
        f1=f1(tp, fp, fn),
        n_true=n_true,
        n_pred=n_pred,
    )