Source code for connectomics.metrics.metrics_seg
"""
Segmentation metrics entrypoints.
This module re-exports numpy/scipy implementations from `segmentation_numpy`
and provides lightweight torchmetrics-compatible wrappers for online evaluation.
"""
from __future__ import annotations
import torch
import torchmetrics
from .segmentation_numpy import (
adapted_rand,
instance_matching,
instance_matching_simple,
matching_criteria,
voi,
)
__all__ = [
"adapted_rand",
"instance_matching",
"instance_matching_simple",
"matching_criteria",
"AdaptedRandError",
"VariationOfInformation",
"InstanceAccuracy",
"InstanceAccuracySimple",
]
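
# Illustrative usage only (kept as a comment so nothing runs at import time).
# The arrays below are invented for demonstration; both the functional numpy
# API and the stateful torchmetrics wrappers expect integer-labeled
# segmentations:
#
#     >>> import numpy as np
#     >>> seg = np.array([[1, 1, 0], [2, 2, 0]])  # predicted instance labels
#     >>> gt = np.array([[1, 1, 0], [2, 2, 0]])   # ground-truth instance labels
#     >>> adapted_rand(seg, gt)                   # one-shot functional call
#     >>> metric = AdaptedRandError()             # accumulating wrapper (below)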
class AdaptedRandError(torchmetrics.Metric):
    """
    Torchmetrics-style wrapper around the numpy-based adapted Rand implementation.

    This wrapper lets us accumulate scores during a Lightning ``test_step`` without
    manual numpy<->torch conversions in the training loop.

    Args:
        return_all_stats: If True, ``compute`` returns a dict that also contains
            precision and recall instead of a single error tensor.
        dist_sync_on_step: Whether to sync state across distributed processes on
            each step.
    """
full_state_update: bool = False
def __init__(self, return_all_stats: bool = False, dist_sync_on_step: bool = False) -> None:
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.return_all_stats = return_all_stats
self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
if return_all_stats:
self.add_state("total_precision", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total_recall", default=torch.tensor(0.0), dist_reduce_fx="sum")
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate adapted Rand statistics for one prediction/target pair."""
        # Move to CPU and numpy for the underlying implementation.
preds_np = preds.detach().cpu().numpy()
target_np = target.detach().cpu().numpy()
if self.return_all_stats:
are, prec, rec = adapted_rand(preds_np, target_np, all_stats=True)
self.total += float(are)
self.total_precision += float(prec)
self.total_recall += float(rec)
else:
score = float(adapted_rand(preds_np, target_np, all_stats=False))
self.total += score
self.count += 1
    def compute(self) -> torch.Tensor | dict[str, torch.Tensor]:
        """Return the mean adapted Rand error (plus precision/recall when ``return_all_stats=True``)."""
if self.count == 0:
if self.return_all_stats:
return {
"adapted_rand_error": torch.tensor(0.0, device=self.total.device),
"adapted_rand_precision": torch.tensor(0.0, device=self.total.device),
"adapted_rand_recall": torch.tensor(0.0, device=self.total.device),
}
return torch.tensor(0.0, device=self.total.device)
if self.return_all_stats:
return {
"adapted_rand_error": self.total / self.count,
"adapted_rand_precision": self.total_precision / self.count,
"adapted_rand_recall": self.total_recall / self.count,
}
return self.total / self.count
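
# A minimal sketch of the intended Lightning integration (hypothetical module;
# ``predict_instances`` and the ``"image"``/``"label"`` batch keys are invented
# for illustration). ``update`` consumes torch tensors directly, so the step
# needs no numpy conversion:
#
#     class SegEvalModule(pl.LightningModule):
#         def __init__(self):
#             super().__init__()
#             self.are = AdaptedRandError(return_all_stats=True)
#
#         def test_step(self, batch, batch_idx):
#             preds = self.predict_instances(batch["image"])  # hypothetical helper
#             self.are.update(preds, batch["label"])
#
#         def on_test_epoch_end(self):
#             self.log_dict(self.are.compute())  # error/precision/recall dict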
class InstanceAccuracy(torchmetrics.Metric):
    """
    Torchmetrics-style wrapper around ``instance_matching`` for instance-level accuracy.

    Instance accuracy measures the fraction of correctly detected instances:

        accuracy = TP / (TP + FP + FN)

    where:
        - TP (true positives): GT instances correctly matched to predictions
        - FP (false positives): predicted instances not matched to any GT instance
        - FN (false negatives): GT instances not matched to any prediction

    Matching is optimal one-to-one (Hungarian) matching at an IoU threshold
    (default 0.5). Higher values are better (1.0 = perfect detection).

    This wrapper lets us accumulate scores during a Lightning ``test_step`` without
    manual numpy<->torch conversions in the training loop.

    Args:
        thresh: Matching threshold on the criterion (default 0.5).
        criterion: Matching criterion passed to ``instance_matching`` (default "iou").
        dist_sync_on_step: Whether to sync state across distributed processes on
            each step.
    """
full_state_update: bool = False
def __init__(
self, thresh: float = 0.5, criterion: str = "iou", dist_sync_on_step: bool = False
) -> None:
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.thresh = thresh
self.criterion = criterion
self.add_state("tp_total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("fp_total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("fn_total", default=torch.tensor(0), dist_reduce_fx="sum")
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate TP/FP/FN counts for one prediction/target pair."""
        # Move to CPU and numpy for the underlying implementation.
preds_np = preds.detach().cpu().numpy()
target_np = target.detach().cpu().numpy()
stats = instance_matching(target_np, preds_np, thresh=self.thresh, criterion=self.criterion)
self.tp_total += int(stats["tp"])
self.fp_total += int(stats["fp"])
self.fn_total += int(stats["fn"])
    def compute(self) -> torch.Tensor:
"""Return instance-level accuracy: TP / (TP + FP + FN)."""
denom = self.tp_total + self.fp_total + self.fn_total
if denom == 0:
return torch.tensor(0.0, device=self.tp_total.device)
return self.tp_total.float() / denom.float()
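
# Worked example for the accuracy formula above (counts invented): with 3 GT
# instances and 4 predictions, of which 2 are matched one-to-one at IoU >= 0.5,
# TP = 2, FP = 4 - 2 = 2, FN = 3 - 2 = 1, so accuracy = 2 / (2 + 2 + 1) = 0.4:
#
#     >>> metric = InstanceAccuracy(thresh=0.5)
#     >>> metric.update(preds, target)  # integer label volumes as above
#     >>> metric.compute()              # tensor(0.4000) for these counts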
class InstanceAccuracySimple(torchmetrics.Metric):
    """
    Torchmetrics-style wrapper for relaxed instance-level accuracy (no Hungarian matching).

    WARNING: This is a RELAXED metric for debugging/analysis only, NOT for
    benchmark ranking. Unlike ``InstanceAccuracy``, it does NOT use optimal
    bipartite matching, so a single GT instance can be counted against several
    overlapping predictions.

    Simple counting approach:
        - tp: number of (GT, Pred) pairs with IoU >= threshold
        - fp = n_pred - tp
        - fn = n_true - tp
        - accuracy = tp / (tp + fp + fn)

    This metric is useful for:
        - quick debugging and sanity checks
        - understanding raw overlap statistics
        - comparison against the strict Hungarian-based metrics

    Higher values are better (1.0 = perfect detection).

    This wrapper lets us accumulate scores during a Lightning ``test_step`` without
    manual numpy<->torch conversions in the training loop.

    Args:
        thresh: Matching threshold on the criterion (default 0.5).
        criterion: Matching criterion passed to ``instance_matching_simple``
            (default "iou").
        dist_sync_on_step: Whether to sync state across distributed processes on
            each step.
    """
full_state_update: bool = False
def __init__(
self, thresh: float = 0.5, criterion: str = "iou", dist_sync_on_step: bool = False
) -> None:
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.thresh = thresh
self.criterion = criterion
self.add_state("tp_total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("fp_total", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("fn_total", default=torch.tensor(0), dist_reduce_fx="sum")
    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate relaxed TP/FP/FN counts for one prediction/target pair."""
        # Move to CPU and numpy for the underlying implementation.
preds_np = preds.detach().cpu().numpy()
target_np = target.detach().cpu().numpy()
stats = instance_matching_simple(
target_np, preds_np, thresh=self.thresh, criterion=self.criterion
)
self.tp_total += int(stats["tp"])
self.fp_total += int(stats["fp"])
self.fn_total += int(stats["fn"])
    def compute(self) -> torch.Tensor:
"""Return relaxed instance-level accuracy: TP / (TP + FP + FN)."""
denom = self.tp_total + self.fp_total + self.fn_total
if denom == 0:
return torch.tensor(0.0, device=self.tp_total.device)
return self.tp_total.float() / denom.float()
    def compute_precision(self) -> torch.Tensor:
"""Return instance-level precision: TP / (TP + FP)."""
denom = self.tp_total + self.fp_total
if denom == 0:
return torch.tensor(0.0, device=self.tp_total.device)
return self.tp_total.float() / denom.float()
    def compute_recall(self) -> torch.Tensor:
"""Return instance-level recall: TP / (TP + FN)."""
denom = self.tp_total + self.fn_total
if denom == 0:
return torch.tensor(0.0, device=self.tp_total.device)
return self.tp_total.float() / denom.float()
    def compute_f1(self) -> torch.Tensor:
"""Return instance-level F1: 2*TP / (2*TP + FP + FN)."""
denom = 2 * self.tp_total + self.fp_total + self.fn_total
if denom == 0:
return torch.tensor(0.0, device=self.tp_total.device)
return (2 * self.tp_total).float() / denom.float()
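
# Worked example for the relaxed counting above (counts invented): with
# n_true = 2, n_pred = 3, and 2 (GT, Pred) pairs passing the threshold,
# tp = 2, fp = 3 - 2 = 1, fn = 2 - 2 = 0, so accuracy = 2/3, precision = 2/3,
# recall = 1.0, and f1 = 2*2 / (2*2 + 1 + 0) = 0.8:
#
#     >>> metric = InstanceAccuracySimple(thresh=0.5)
#     >>> metric.update(preds, target)
#     >>> metric.compute_f1()  # tensor(0.8000) for these counts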