Shortcuts

Source code for connectomics.models.architectures.base

"""
Base model interface for all architectures.

Provides a standard interface that all models should implement,
with explicit support for deep supervision.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, Union

import torch
import torch.nn as nn


class ConnectomicsModel(nn.Module, ABC):
    """
    Base class for all connectomics models.

    Provides a common interface for:
    - Forward pass (single- or multi-scale outputs)
    - Deep supervision support
    - Model information

    All models in the architectures module should inherit from this class,
    or at least implement the same interface.

    Attributes:
        supports_deep_supervision: Whether ``forward`` returns a dictionary of
            multi-scale outputs rather than a single tensor. The base class
            sets it to ``False``; subclasses that support deep supervision
            override it after calling ``super().__init__()``.
        output_scales: Number of output scales produced by ``forward``.
            The base class sets it to ``1``.
    """

    def __init__(self):
        super().__init__()
        # Defaults describe a plain single-scale model; subclasses overwrite
        # these after super().__init__() when they differ.
        self.supports_deep_supervision = False
        self.output_scales = 1

    @abstractmethod
    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass through the model.

        Args:
            x: Input tensor of shape (B, C, D, H, W) for 3D or (B, C, H, W)
                for 2D.

        Returns:
            For single-scale models:
                torch.Tensor: Output tensor of shape (B, num_classes, D, H, W)
                (or (B, num_classes, H, W) for 2D inputs).

            For deep supervision models:
                Dict[str, torch.Tensor]: Dictionary with keys:
                    - 'output': Main output (full resolution)
                    - 'ds_1': Deep supervision output at scale 1 (1/2 resolution)
                    - 'ds_2': Deep supervision output at scale 2 (1/4 resolution)
                    - 'ds_3': Deep supervision output at scale 3 (1/8 resolution)
                    - 'ds_4': Deep supervision output at scale 4 (1/16 resolution)

        Raises:
            NotImplementedError: Always, in this base implementation; concrete
                subclasses must override this method.
        """
        raise NotImplementedError

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get model information and statistics.

        Returns:
            Dictionary with model metadata:
                - name: Model class name
                - deep_supervision: Whether the model supports deep supervision
                - output_scales: Number of output scales
                - parameters: Total number of parameters (trainable and frozen)
                - trainable_parameters: Number of parameters whose
                  ``requires_grad`` is ``True``
        """
        total_params = 0
        trainable_params = 0
        # One pass over the parameters yields both counts; previously the
        # parameter iterator was walked twice.
        for param in self.parameters():
            count = param.numel()
            total_params += count
            if param.requires_grad:
                trainable_params += count

        return {
            "name": self.__class__.__name__,
            "deep_supervision": self.supports_deep_supervision,
            "output_scales": self.output_scales,
            "parameters": total_params,
            "trainable_parameters": trainable_params,
        }

    def __repr__(self) -> str:
        """
        Return a one-line summary of the model.

        The summary contains the class name, the total parameter count
        (with thousands separators) and the deep-supervision flag. It does
        not list the model's submodules.
        """
        info = self.get_model_info()
        return (
            f"{info['name']}("
            f"parameters={info['parameters']:,}, "
            f"deep_supervision={info['deep_supervision']})"
        )
# Public API of this module: only the abstract base class is exported via
# ``from ... import *``.
__all__ = ["ConnectomicsModel"]