Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/scib_metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
silhouette_batch,
silhouette_label,
bras,
spatial_mrre,
spatial_knn_overlap,
spatial_distance_correlation,
spatial_morans_i,
)
from ._settings import settings

Expand All @@ -35,6 +39,10 @@
"kbet",
"kbet_per_label",
"graph_connectivity",
"spatial_mrre",
"spatial_knn_overlap",
"spatial_distance_correlation",
"spatial_morans_i",
"settings",
]

Expand Down
4 changes: 2 additions & 2 deletions src/scib_metrics/benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._core import BatchCorrection, Benchmarker, BioConservation
from ._core import BatchCorrection, Benchmarker, BioConservation, SpatialConservation

__all__ = ["Benchmarker", "BioConservation", "BatchCorrection"]
__all__ = ["Benchmarker", "BioConservation", "BatchCorrection", "SpatialConservation"]
130 changes: 118 additions & 12 deletions src/scib_metrics/benchmark/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@
Kwargs = dict[str, Any]
MetricType = bool | Kwargs

# Sentinel used to detect when spatial_conservation_metrics was not explicitly
# passed so that it can be auto-derived from spatial_key.
_SPATIAL_UNSET: object = object()

_LABELS = "labels"
_BATCH = "batch"
_X_PRE = "X_pre"
_SPATIAL = "spatial"
_METRIC_TYPE = "Metric Type"
_AGGREGATE_SCORE = "Aggregate score"

Expand All @@ -46,6 +51,10 @@
"bras": "BRAS",
"graph_connectivity": "Graph connectivity",
"pcr_comparison": "PCR comparison",
"spatial_mrre": "MRRE",
"spatial_knn_overlap": "kNN overlap",
"spatial_distance_correlation": "Distance corr.",
"spatial_morans_i": "Moran's I",
}


Expand Down Expand Up @@ -81,6 +90,37 @@ class BatchCorrection:
pcr_comparison: MetricType = True


@dataclass(frozen=True)
class SpatialConservation:
"""Specification of spatial conservation metrics to run in the pipeline.

These embedding-based metrics compare each model's latent representation
against the physical spot coordinates to quantify how well spatial
structure is preserved. All scores are in ``[0, 1]`` with higher better,
and they vary across embeddings.

Metrics can be included using a boolean flag. Custom keyword args can be
used by passing a dictionary here. Keyword args should not set data-related
parameters such as ``X_embedding`` or ``spatial_coords``.

Direction of each metric (all higher = better):

* ``spatial_mrre`` — 1 minus normalised mean relative rank error of spatial
neighbours in the latent space.
* ``spatial_knn_overlap`` — mean Jaccard overlap of spatial vs. latent
k-NN sets per spot.
* ``spatial_distance_correlation`` — Spearman rank correlation of
pairwise spatial vs. latent distances, rescaled to ``[0, 1]``.
* ``spatial_morans_i`` — mean Moran's I of latent dimensions using the
spatial weight graph, rescaled to ``[0, 1]``.
"""

spatial_mrre: MetricType = True
spatial_knn_overlap: MetricType = True
spatial_distance_correlation: MetricType = True
spatial_morans_i: MetricType = True


class MetricAnnDataAPI(Enum):
"""Specification of the AnnData API for a metric."""

Expand All @@ -94,6 +134,11 @@ class MetricAnnDataAPI(Enum):
pcr_comparison = lambda ad, fn: fn(ad.obsm[_X_PRE], ad.X, ad.obs[_BATCH], categorical=True)
ilisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_BATCH])
kbet_per_label = lambda ad, fn: fn(ad.uns["50_neighbor_res"], ad.obs[_BATCH], ad.obs[_LABELS])
# Spatial conservation — embedding-based (latent embedding + spatial coords)
spatial_mrre = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL])
spatial_knn_overlap = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL])
spatial_distance_correlation = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL])
spatial_morans_i = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL])


class Benchmarker:
Expand All @@ -113,6 +158,18 @@ class Benchmarker:
Specification of which bio conservation metrics to run in the pipeline.
batch_correction_metrics
Specification of which batch correction metrics to run in the pipeline.
spatial_conservation_metrics
Specification of which spatial conservation metrics to run in the pipeline.
Requires ``spatial_key`` to be set. MRRE, kNN overlap, distance correlation
and Moran's I compare each embedding against ``adata.obsm[spatial_key]``
to quantify spatial structure preservation.
spatial_key
Key in ``adata.obsm`` that contains the 2-D spatial coordinates (x, y).
Typically ``"spatial"`` (Squidpy / Scanpy convention).
When set, spatial conservation metrics are **automatically enabled**
using the default :class:`SpatialConservation` configuration unless
``spatial_conservation_metrics`` is explicitly passed as ``None``
to opt out. Spatial metrics are never computed when this is ``None``.
pre_integrated_embedding_obsm_key
Obsm key containing a non-integrated embedding of the data. If `None`, the embedding will be computed
in the prepare step. See the notes below for more information.
Expand Down Expand Up @@ -142,6 +199,9 @@ def __init__(
embedding_obsm_keys: list[str],
bio_conservation_metrics: BioConservation | None = BioConservation(),
batch_correction_metrics: BatchCorrection | None = BatchCorrection(),
spatial_conservation_metrics: SpatialConservation | None = _SPATIAL_UNSET, # type: ignore[assignment]
spatial_key: str | None = None,
spatial_conservation_weight: float = 0.0,
pre_integrated_embedding_obsm_key: str | None = None,
Comment thread
ori-kron-wis marked this conversation as resolved.
Outdated
n_jobs: int = 1,
progress_bar: bool = True,
Expand All @@ -152,6 +212,13 @@ def __init__(
self._pre_integrated_embedding_obsm_key = pre_integrated_embedding_obsm_key
self._bio_conservation_metrics = bio_conservation_metrics
self._batch_correction_metrics = batch_correction_metrics
# Auto-enable spatial metrics when spatial_key is provided and the
# caller has not explicitly set spatial_conservation_metrics.
if spatial_conservation_metrics is _SPATIAL_UNSET:
spatial_conservation_metrics = SpatialConservation() if spatial_key is not None else None
self._spatial_conservation_metrics = spatial_conservation_metrics
self._spatial_key = spatial_key
self._spatial_conservation_weight = spatial_conservation_weight
self._results = pd.DataFrame(columns=list(self._embedding_obsm_keys) + [_METRIC_TYPE])
self._emb_adatas = {}
self._neighbor_values = (15, 50, 90)
Expand All @@ -164,14 +231,32 @@ def __init__(
self._compute_neighbors = True
self._solver = solver

if self._bio_conservation_metrics is None and self._batch_correction_metrics is None:
raise ValueError("Either batch or bio metrics must be defined.")
if (
self._bio_conservation_metrics is None
and self._batch_correction_metrics is None
and self._spatial_conservation_metrics is None
):
raise ValueError("At least one of batch, bio, or spatial metrics must be defined.")

if self._spatial_conservation_metrics is not None and self._spatial_key is None:
raise ValueError(
"spatial_key must be provided when spatial_conservation_metrics is set. "
"Typically this is 'spatial' (adata.obsm['spatial'])."
)

if self._spatial_key is not None and self._spatial_key not in self._adata.obsm:
raise ValueError(
f"spatial_key '{self._spatial_key}' not found in adata.obsm. "
f"Available keys: {list(self._adata.obsm.keys())}"
)

self._metric_collection_dict = {}
if self._bio_conservation_metrics is not None:
self._metric_collection_dict.update({"Bio conservation": self._bio_conservation_metrics})
if self._batch_correction_metrics is not None:
self._metric_collection_dict.update({"Batch correction": self._batch_correction_metrics})
if self._spatial_conservation_metrics is not None:
self._metric_collection_dict.update({"Spatial conservation": self._spatial_conservation_metrics})

def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResults] | None = None) -> None:
"""Prepare the data for benchmarking.
Expand All @@ -198,6 +283,8 @@ def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResult
self._emb_adatas[emb_key].obs[_BATCH] = np.asarray(self._adata.obs[self._batch_key].values)
self._emb_adatas[emb_key].obs[_LABELS] = np.asarray(self._adata.obs[self._label_key].values)
self._emb_adatas[emb_key].obsm[_X_PRE] = self._adata.obsm[self._pre_integrated_embedding_obsm_key]
if self._spatial_key is not None:
self._emb_adatas[emb_key].obsm[_SPATIAL] = np.asarray(self._adata.obsm[self._spatial_key])

# Compute neighbors
if self._compute_neighbors:
Expand Down Expand Up @@ -296,13 +383,23 @@ def get_results(self, min_max_scale: bool = False, clean_names: bool = True) ->
df = df.transpose()
df[_METRIC_TYPE] = self._results[_METRIC_TYPE].values

# Compute scores
# Compute per-category aggregate scores
per_class_score = df.groupby(_METRIC_TYPE).mean().transpose()
# This is the default scIB weighting from the manuscript

# Build Total score. Weights follow the original scIB manuscript
# (0.4 batch + 0.6 bio). Spatial conservation is added with a
# placeholder weight that defaults to 0.0 so it does not affect the
# total until explicitly enabled by the user.
if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None:
per_class_score["Total"] = (
0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"]
)
total = 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"]
if (
self._spatial_conservation_metrics is not None
and self._spatial_conservation_weight > 0.0
and "Spatial conservation" in per_class_score.columns
):
total = total + self._spatial_conservation_weight * per_class_score["Spatial conservation"]
per_class_score["Total"] = total

df = pd.concat([df.transpose(), per_class_score], axis=1)
df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE
return df
Expand All @@ -319,6 +416,12 @@ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, sav
save_dir
The directory to save the plot to. If `None`, the plot is not saved.
"""

def _fmt(x: float) -> str:
"""Format to 2 d.p., mapping -0.00 → 0.00."""
v = round(float(x), 2)
return "0.00" if v == 0.0 else f"{v:.2f}"

num_embeds = len(self._embedding_obsm_keys)
cmap_fn = lambda col_data: normed_cmap(col_data, cmap=mpl.cm.PRGn, num_stds=2.5)
df = self.get_results(min_max_scale=min_max_scale)
Expand All @@ -329,8 +432,11 @@ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, sav
sort_col = "Total"
elif self._batch_correction_metrics is not None:
sort_col = "Batch correction"
else:
elif self._bio_conservation_metrics is not None:
sort_col = "Bio conservation"
else:
# Only spatial conservation — no sensible ranking across embeddings
sort_col = plot_df.columns[0]
plot_df = plot_df.sort_values(by=sort_col, ascending=False).astype(np.float64)
plot_df["Method"] = plot_df.index

Expand All @@ -345,30 +451,30 @@ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, sav
ColumnDefinition(
col,
title=col.replace(" ", "\n", 1),
width=1,
width=1.5,
textprops={
"ha": "center",
"bbox": {"boxstyle": "circle", "pad": 0.25},
},
cmap=cmap_fn(plot_df[col]),
group=df.loc[_METRIC_TYPE, col],
formatter="{:.2f}",
formatter=_fmt,
)
for i, col in enumerate(other_cols)
]
# Bars for the aggregate scores
column_definitions += [
ColumnDefinition(
col,
width=1,
width=1.5,
title=col.replace(" ", "\n", 1),
plot_fn=bar,
plot_kw={
"cmap": mpl.cm.YlGnBu,
"plot_bg_bar": False,
"annotate": True,
"height": 0.9,
"formatter": "{:.2f}",
"formatter": _fmt,
},
group=df.loc[_METRIC_TYPE, col],
border="left" if i == 0 else None,
Expand Down
10 changes: 10 additions & 0 deletions src/scib_metrics/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
from ._nmi_ari import nmi_ari_cluster_labels_kmeans, nmi_ari_cluster_labels_leiden
from ._pcr_comparison import pcr_comparison
from ._silhouette import bras, silhouette_batch, silhouette_label
from ._spatial import (
spatial_distance_correlation,
spatial_knn_overlap,
spatial_morans_i,
spatial_mrre,
)

__all__ = [
"isolated_labels",
Expand All @@ -20,4 +26,8 @@
"kbet",
"kbet_per_label",
"graph_connectivity",
"spatial_mrre",
"spatial_knn_overlap",
"spatial_distance_correlation",
"spatial_morans_i",
]
Loading