diff --git a/src/scib_metrics/__init__.py b/src/scib_metrics/__init__.py index cd12719a..f55fadf4 100644 --- a/src/scib_metrics/__init__.py +++ b/src/scib_metrics/__init__.py @@ -16,6 +16,13 @@ silhouette_batch, silhouette_label, bras, + spatial_mrre, + spatial_knn_overlap, + spatial_distance_correlation, + spatial_morans_i, + spatial_niche_knn_overlap, + spatial_pas, + spatial_chaos, ) from ._settings import settings @@ -35,6 +42,13 @@ "kbet", "kbet_per_label", "graph_connectivity", + "spatial_mrre", + "spatial_knn_overlap", + "spatial_distance_correlation", + "spatial_morans_i", + "spatial_niche_knn_overlap", + "spatial_pas", + "spatial_chaos", "settings", ] diff --git a/src/scib_metrics/benchmark/__init__.py b/src/scib_metrics/benchmark/__init__.py index be5bb7df..8a036473 100644 --- a/src/scib_metrics/benchmark/__init__.py +++ b/src/scib_metrics/benchmark/__init__.py @@ -1,3 +1,19 @@ -from ._core import BatchCorrection, Benchmarker, BioConservation +from ._core import ( + BatchCorrection, + Benchmarker, + BioConservation, + CoordinatePreservation, + DomainBoundary, + NichePreservation, + SpatialConservation, +) -__all__ = ["Benchmarker", "BioConservation", "BatchCorrection"] +__all__ = [ + "Benchmarker", + "BioConservation", + "BatchCorrection", + "CoordinatePreservation", + "NichePreservation", + "DomainBoundary", + "SpatialConservation", # backward-compat alias for CoordinatePreservation +] diff --git a/src/scib_metrics/benchmark/_core.py b/src/scib_metrics/benchmark/_core.py index 351153df..4a3ed2d7 100644 --- a/src/scib_metrics/benchmark/_core.py +++ b/src/scib_metrics/benchmark/_core.py @@ -25,11 +25,18 @@ Kwargs = dict[str, Any] MetricType = bool | Kwargs +# Sentinel used to detect when spatial_conservation_metrics was not explicitly +# passed so that it can be auto-derived from spatial_key. 
+_SPATIAL_UNSET: object = object() + _LABELS = "labels" _BATCH = "batch" _X_PRE = "X_pre" +_SPATIAL = "spatial" +_X_EXPR = "X_expr_pre" # pre-integrated PCA stored for niche metrics _METRIC_TYPE = "Metric Type" _AGGREGATE_SCORE = "Aggregate score" +_REGION_SCORE = "Region conservation" # Mapping of metric fn names to clean DataFrame column names metric_name_cleaner = { @@ -46,6 +53,16 @@ "bras": "BRAS", "graph_connectivity": "Graph connectivity", "pcr_comparison": "PCR comparison", + # Coordinate preservation + "spatial_mrre": "MRRE", + "spatial_knn_overlap": "kNN overlap", + "spatial_distance_correlation": "Distance corr.", + "spatial_morans_i": "Moran's I", + # Niche preservation + "spatial_niche_knn_overlap": "Niche kNN", + # Domain boundary + "spatial_pas": "PAS", + "spatial_chaos": "CHAOS", } @@ -81,6 +98,91 @@ class BatchCorrection: pcr_comparison: MetricType = True +@dataclass(frozen=True) +class CoordinatePreservation: + """Coordinate-preservation metrics: does the latent reproduce XY geometry. + + These metrics compare each model's latent representation directly against + the physical spot coordinates. They are most meaningful for spatial graph + autoencoders (STAGATE-style) whose latent is explicitly trained as a + surrogate for tissue coordinates. + + For models like scVIVA, resolVI, gimVI, or DestVI — where the latent + captures expression state, niche structure, denoising, or deconvolution — + these metrics measure a property the model was not trained to optimise. + Use :class:`NichePreservation` and :class:`DomainBoundary` instead to + assess those models on their intended objectives. + + All scores are in ``[0, 1]`` with higher better. + + Metrics (all higher = better): + + * ``spatial_mrre`` — 1 minus normalised mean relative rank error of + spatial neighbours in the latent space. + * ``spatial_knn_overlap`` — chance-normalised overlap of spatial vs. + latent k-NN sets per spot. 
+ * ``spatial_distance_correlation`` — Spearman correlation of pairwise + spatial vs. latent distances, rescaled to ``[0, 1]``. + * ``spatial_morans_i`` — mean Moran's I of latent dimensions using the + spatial weight graph, rescaled to ``[0, 1]``. + """ + + spatial_mrre: MetricType = True + spatial_knn_overlap: MetricType = True + spatial_distance_correlation: MetricType = True + spatial_morans_i: MetricType = True + + +# Backward-compatible alias — ``SpatialConservation`` maps to the coordinate- +# preservation axis so that existing code continues to work unchanged. +SpatialConservation = CoordinatePreservation + + +@dataclass(frozen=True) +class NichePreservation: + """Niche-preservation metrics: does the latent capture microenvironment. + + Asks whether cells that share a similar local microenvironment (similar + average expression of spatial neighbours) are also close in latent space. + This is a direct measure of what niche-aware models such as scVIVA are + trained to do, and provides a complementary axis to raw coordinate + preservation. + + Niche features are computed as the mean pre-integrated embedding (PCA) + of spatial neighbours, keeping them low-dimensional and independent of + the model being evaluated. + + Metrics (all higher = better): + + * ``spatial_niche_knn_overlap`` — chance-normalised overlap of niche- + feature k-NN vs. latent k-NN per spot. + """ + + spatial_niche_knn_overlap: MetricType = True + + +@dataclass(frozen=True) +class DomainBoundary: + """Domain-boundary metrics: do latent clusters align with tissue domains. + + Clusters the latent embedding with k-means and measures how spatially + coherent the resulting domains are. High scores indicate that + latent-derived clusters occupy compact, contiguous patches of tissue + rather than scattered fragments — directly relevant for models aimed at + spatial domain identification. 
+ + Metrics (all higher = better): + + * ``spatial_pas`` — 1 minus Proportion of Abnormal Spots; fraction of + spatial neighbours in the same cluster (higher = more coherent). + * ``spatial_chaos`` — 1 minus normalised mean intra-cluster spatial + distance (higher = more spatially compact clusters). + """ + + spatial_pas: MetricType = True + spatial_chaos: MetricType = True + + class MetricAnnDataAPI(Enum): """Specification of the AnnData API for a metric.""" @@ -94,6 +196,16 @@ class MetricAnnDataAPI(Enum): pcr_comparison = lambda ad, fn: fn(ad.obsm[_X_PRE], ad.X, ad.obs[_BATCH], categorical=True) ilisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_BATCH]) kbet_per_label = lambda ad, fn: fn(ad.uns["50_neighbor_res"], ad.obs[_BATCH], ad.obs[_LABELS]) + # Coordinate preservation — latent embedding vs physical XY coordinates + spatial_mrre = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL]) + spatial_knn_overlap = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL]) + spatial_distance_correlation = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL]) + spatial_morans_i = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL]) + # Niche preservation — latent kNN vs niche-feature kNN (pre-integrated PCA) + spatial_niche_knn_overlap = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL], ad.obsm.get(_X_EXPR)) + # Domain boundary — derived from k-means clustering of latent embedding + spatial_pas = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL]) + spatial_chaos = lambda ad, fn: fn(ad.X, ad.obsm[_SPATIAL]) class Benchmarker: @@ -113,6 +225,27 @@ class Benchmarker: Specification of which bio conservation metrics to run in the pipeline. batch_correction_metrics Specification of which batch correction metrics to run in the pipeline. + spatial_conservation_metrics + Specification of which spatial conservation metrics to run in the pipeline. + Requires ``spatial_key`` to be set. 
MRRE, kNN overlap, distance correlation + and Moran's I compare each embedding against ``adata.obsm[spatial_key]`` + to quantify spatial structure preservation. + spatial_key + Key in ``adata.obsm`` that contains the 2-D spatial coordinates (x, y). + Typically ``"spatial"`` (Squidpy / Scanpy convention). + When set, spatial conservation metrics are **automatically enabled** + using the default :class:`SpatialConservation` configuration unless + ``spatial_conservation_metrics`` is explicitly passed as ``None`` + to opt out. Spatial metrics are never computed when this is ``None``. + region_key + Optional key in ``adata.obs`` containing spatial region labels (e.g. + brain region, tissue zone). When provided, bio-conservation and + batch-correction metrics are re-computed **for each unique value in** + ``label_key`` (typically cell type) using ``region_key`` as the label, + then averaged across label values. Results appear as + "Region bio conservation" and "Region batch correction" aggregate + columns. This quantifies how well embeddings preserve regional + identity *within* each cell type. pre_integrated_embedding_obsm_key Obsm key containing a non-integrated embedding of the data. If `None`, the embedding will be computed in the prepare step. See the notes below for more information. 
@@ -146,12 +279,27 @@ def __init__( n_jobs: int = 1, progress_bar: bool = True, solver: str = "arpack", + spatial_conservation_metrics: CoordinatePreservation | None = _SPATIAL_UNSET, # type: ignore[assignment] + spatial_key: str | None = None, + spatial_conservation_weight: float = 0.0, + niche_preservation: NichePreservation | None = None, + domain_boundary: DomainBoundary | None = None, + region_key: str | None = None, ): self._adata = adata self._embedding_obsm_keys = embedding_obsm_keys self._pre_integrated_embedding_obsm_key = pre_integrated_embedding_obsm_key self._bio_conservation_metrics = bio_conservation_metrics self._batch_correction_metrics = batch_correction_metrics + # Auto-enable coordinate preservation when spatial_key is provided and the + # caller has not explicitly set spatial_conservation_metrics. + if spatial_conservation_metrics is _SPATIAL_UNSET: + spatial_conservation_metrics = CoordinatePreservation() if spatial_key is not None else None + self._spatial_conservation_metrics = spatial_conservation_metrics + self._niche_preservation = niche_preservation + self._domain_boundary = domain_boundary + self._spatial_key = spatial_key + self._spatial_conservation_weight = spatial_conservation_weight self._results = pd.DataFrame(columns=list(self._embedding_obsm_keys) + [_METRIC_TYPE]) self._emb_adatas = {} self._neighbor_values = (15, 50, 90) @@ -163,15 +311,56 @@ def __init__( self._progress_bar = progress_bar self._compute_neighbors = True self._solver = solver + self._region_key = region_key + if self._region_key is not None and self._region_key not in self._adata.obs.columns: + raise ValueError( + f"region_key '{self._region_key}' not found in adata.obs. 
" + f"Available columns: {list(self._adata.obs.columns)}" + ) + self._region_avg: dict[str, dict[str, float]] = {} + self._region_emb_adatas: dict[str, dict[str, AnnData]] = {} + + _any_spatial = ( + self._spatial_conservation_metrics is not None + or self._niche_preservation is not None + or self._domain_boundary is not None + ) + if self._bio_conservation_metrics is None and self._batch_correction_metrics is None and not _any_spatial: + raise ValueError("At least one of batch, bio, or spatial metrics must be defined.") + + if self._spatial_conservation_metrics is not None and self._spatial_key is None: + raise ValueError( + "spatial_key must be provided when spatial_conservation_metrics is set. " + "Typically this is 'spatial' (adata.obsm['spatial'])." + ) + if self._niche_preservation is not None and self._spatial_key is None: + raise ValueError( + "spatial_key must be provided when niche_preservation is set. " + "Typically this is 'spatial' (adata.obsm['spatial'])." + ) + if self._domain_boundary is not None and self._spatial_key is None: + raise ValueError( + "spatial_key must be provided when domain_boundary is set. " + "Typically this is 'spatial' (adata.obsm['spatial'])." + ) - if self._bio_conservation_metrics is None and self._batch_correction_metrics is None: - raise ValueError("Either batch or bio metrics must be defined.") + if self._spatial_key is not None and self._spatial_key not in self._adata.obsm: + raise ValueError( + f"spatial_key '{self._spatial_key}' not found in adata.obsm. 
" + f"Available keys: {list(self._adata.obsm.keys())}" + ) self._metric_collection_dict = {} if self._bio_conservation_metrics is not None: self._metric_collection_dict.update({"Bio conservation": self._bio_conservation_metrics}) if self._batch_correction_metrics is not None: self._metric_collection_dict.update({"Batch correction": self._batch_correction_metrics}) + if self._spatial_conservation_metrics is not None: + self._metric_collection_dict.update({"Coordinate preservation": self._spatial_conservation_metrics}) + if self._niche_preservation is not None: + self._metric_collection_dict.update({"Niche preservation": self._niche_preservation}) + if self._domain_boundary is not None: + self._metric_collection_dict.update({"Domain boundary": self._domain_boundary}) def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResults] | None = None) -> None: """Prepare the data for benchmarking. @@ -198,6 +387,13 @@ def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResult self._emb_adatas[emb_key].obs[_BATCH] = np.asarray(self._adata.obs[self._batch_key].values) self._emb_adatas[emb_key].obs[_LABELS] = np.asarray(self._adata.obs[self._label_key].values) self._emb_adatas[emb_key].obsm[_X_PRE] = self._adata.obsm[self._pre_integrated_embedding_obsm_key] + if self._spatial_key is not None: + self._emb_adatas[emb_key].obsm[_SPATIAL] = np.asarray(self._adata.obsm[self._spatial_key]) + # Store pre-integrated embedding as niche feature proxy for + # spatial_niche_knn_overlap; set after PCA so the key exists. 
+ self._emb_adatas[emb_key].obsm[_X_EXPR] = np.asarray( + self._adata.obsm[self._pre_integrated_embedding_obsm_key] + ) # Compute neighbors if self._compute_neighbors: @@ -220,8 +416,112 @@ def prepare(self, neighbor_computer: Callable[[np.ndarray, int], NeighborsResult UserWarning, ) + if self._region_key is not None: + self._prepare_region_adatas(neighbor_computer) + self._prepared = True + def _prepare_region_adatas( + self, + neighbor_computer: Callable[[np.ndarray, int], NeighborsResults] | None = None, + ) -> None: + """Build per-label-value AnnData subsets with ``region_key`` as the label. + + For each unique value in ``label_key`` we subset the data, swap + ``_LABELS`` for the region column, and pre-compute neighbor graphs so + that the full bio/batch metric suite can run on each subset. + + Subsets with fewer than ``max(neighbor_values) + 1`` cells are skipped. + """ + min_cells = max(self._neighbor_values) + 1 + self._region_emb_adatas = {ek: {} for ek in self._embedding_obsm_keys} + + label_vals = sorted(self._adata.obs[self._label_key].unique()) + progress = label_vals + if self._progress_bar: + progress = tqdm(label_vals, desc="Region subsets") + + for label_val in progress: + mask = (self._adata.obs[self._label_key] == label_val).values + n_cells = int(mask.sum()) + if n_cells < min_cells: + warnings.warn( + f"Skipping region scoring for label '{label_val}': only {n_cells} cells (need >= {min_cells}).", + UserWarning, + stacklevel=2, + ) + continue + + sub = self._adata[mask] + k = min(max(self._neighbor_values), n_cells - 1) + + for emb_key in self._embedding_obsm_keys: + ad = AnnData(np.asarray(sub.obsm[emb_key]), obs=sub.obs.copy()) + ad.obs[_BATCH] = np.asarray(sub.obs[self._batch_key].values) + ad.obs[_LABELS] = np.asarray(sub.obs[self._region_key].values) + ad.obsm[_X_PRE] = np.asarray(sub.obsm[self._pre_integrated_embedding_obsm_key]) + if self._spatial_key is not None: + ad.obsm[_SPATIAL] = np.asarray(sub.obsm[self._spatial_key]) + 
ad.obsm[_X_EXPR] = np.asarray(sub.obsm[self._pre_integrated_embedding_obsm_key]) + + if neighbor_computer is not None: + neigh_result = neighbor_computer(ad.X, k) + else: + neigh_result = pynndescent(ad.X, n_neighbors=k, random_state=0, n_jobs=self._n_jobs) + for n in self._neighbor_values: + ad.uns[f"{n}_neighbor_res"] = neigh_result.subset_neighbors(n=min(n, k)) + + self._region_emb_adatas[emb_key][label_val] = ad + + def _benchmark_region(self) -> None: + """Run bio/batch metrics per label-value subset (region as label), then average.""" + region_metric_dict = { + k: v for k, v in self._metric_collection_dict.items() if k in ("Bio conservation", "Batch correction") + } + if not region_metric_dict or not self._region_emb_adatas: + return + + # metric_name → metric_type (handles nmi/ari split keys too) + metric_to_type: dict[str, str] = {} + for mt, mc in region_metric_dict.items(): + for mn in asdict(mc): + metric_to_type[mn] = mt + metric_to_type[f"{mn}_nmi"] = mt + metric_to_type[f"{mn}_ari"] = mt + + # scores[emb_key][metric_name] = [score_label1, score_label2, ...] 
+ scores: dict[str, dict[str, list[float]]] = {ek: {} for ek in self._embedding_obsm_keys} + + for emb_key, label_adatas in self._region_emb_adatas.items(): + for label_val, ad in label_adatas.items(): + for _metric_type, metric_collection in region_metric_dict.items(): + for metric_name, use_metric_or_kwargs in asdict(metric_collection).items(): + if not use_metric_or_kwargs: + continue + try: + metric_fn = getattr(scib_metrics, metric_name) + if isinstance(use_metric_or_kwargs, dict): + metric_fn = partial(metric_fn, **use_metric_or_kwargs) + metric_value = getattr(MetricAnnDataAPI, metric_name)(ad, metric_fn) + except (ValueError, KeyError, RuntimeError, AttributeError) as exc: + warnings.warn( + f"Region metric '{metric_name}' failed for label '{label_val}' " + f"(emb '{emb_key}'): {exc}", + UserWarning, + stacklevel=2, + ) + continue + + if isinstance(metric_value, dict): + for k, v in metric_value.items(): + scores[emb_key].setdefault(f"{metric_name}_{k}", []).append(float(v)) + else: + scores[emb_key].setdefault(metric_name, []).append(float(metric_value)) + + self._region_avg = { + ek: {mn: float(np.nanmean(vals)) for mn, vals in m.items() if vals} for ek, m in scores.items() + } + def benchmark(self) -> None: """Run the pipeline.""" if self._benchmarked: @@ -265,6 +565,9 @@ def benchmark(self) -> None: self._results.loc[metric_name, _METRIC_TYPE] = metric_type pbar.update(1) if pbar is not None else None + if self._region_key: + self._benchmark_region() + self._benchmarked = True def get_results(self, min_max_scale: bool = False, clean_names: bool = True) -> pd.DataFrame: @@ -296,15 +599,53 @@ def get_results(self, min_max_scale: bool = False, clean_names: bool = True) -> df = df.transpose() df[_METRIC_TYPE] = self._results[_METRIC_TYPE].values - # Compute scores + # Compute per-category aggregate scores per_class_score = df.groupby(_METRIC_TYPE).mean().transpose() - # This is the default scIB weighting from the manuscript + + # Add region-aware aggregate 
columns when region_key was used + if self._region_key and self._region_avg: + # Build metric_name → metric_type map (including nmi/ari split keys) + _metric_to_type: dict[str, str] = {} + for _mt, _mc in self._metric_collection_dict.items(): + if _mt in ("Bio conservation", "Batch correction"): + for _mn in asdict(_mc): + _metric_to_type[_mn] = _mt + _metric_to_type[f"{_mn}_nmi"] = _mt + _metric_to_type[f"{_mn}_ari"] = _mt + + for emb_key in self._embedding_obsm_keys: + ravg = self._region_avg.get(emb_key, {}) + _type_vals: dict[str, list[float]] = {} + for mn, v in ravg.items(): + col = f"Region {_metric_to_type[mn]}" if mn in _metric_to_type else None + if col: + _type_vals.setdefault(col, []).append(v) + for col, vals in _type_vals.items(): + per_class_score.loc[emb_key, col] = float(np.mean(vals)) + + # Build Total score. Weights follow the original scIB manuscript + # (0.4 batch + 0.6 bio). Spatial axes are averaged across all enabled + # spatial groups and added with spatial_conservation_weight (default 0.0 + # so they do not affect Total unless explicitly enabled). 
if self._batch_correction_metrics is not None and self._bio_conservation_metrics is not None: - per_class_score["Total"] = ( - 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"] - ) + total = 0.4 * per_class_score["Batch correction"] + 0.6 * per_class_score["Bio conservation"] + if self._spatial_conservation_weight > 0.0: + _spatial_groups = [ + g + for g in ("Coordinate preservation", "Niche preservation", "Domain boundary") + if g in per_class_score.columns + ] + if _spatial_groups: + spatial_mean = sum(per_class_score[g] for g in _spatial_groups) / len(_spatial_groups) + total = total + self._spatial_conservation_weight * spatial_mean + per_class_score["Total"] = total + df = pd.concat([df.transpose(), per_class_score], axis=1) df.loc[_METRIC_TYPE, per_class_score.columns] = _AGGREGATE_SCORE + # Region aggregate columns get their own group header in the plot + _region_agg_cols = [c for c in per_class_score.columns if str(c).startswith("Region ")] + if _region_agg_cols: + df.loc[_METRIC_TYPE, _region_agg_cols] = _REGION_SCORE return df def plot_results_table(self, min_max_scale: bool = False, show: bool = True, save_dir: str | None = None) -> Table: @@ -319,6 +660,12 @@ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, sav save_dir The directory to save the plot to. If `None`, the plot is not saved. 
""" + + def _fmt(x: float) -> str: + """Format to 2 d.p., mapping -0.00 → 0.00.""" + v = round(float(x), 2) + return "0.00" if v == 0.0 else f"{v:.2f}" + num_embeds = len(self._embedding_obsm_keys) cmap_fn = lambda col_data: normed_cmap(col_data, cmap=mpl.cm.PRGn, num_stds=2.5) df = self.get_results(min_max_scale=min_max_scale) @@ -329,14 +676,19 @@ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, sav sort_col = "Total" elif self._batch_correction_metrics is not None: sort_col = "Batch correction" - else: + elif self._bio_conservation_metrics is not None: sort_col = "Bio conservation" + else: + # Only spatial conservation — no sensible ranking across embeddings + sort_col = plot_df.columns[0] plot_df = plot_df.sort_values(by=sort_col, ascending=False).astype(np.float64) plot_df["Method"] = plot_df.index # Split columns by metric type, using df as it doesn't have the new method col - score_cols = df.columns[df.loc[_METRIC_TYPE] == _AGGREGATE_SCORE] - other_cols = df.columns[df.loc[_METRIC_TYPE] != _AGGREGATE_SCORE] + # Both aggregate and region-conservation columns are shown as bar charts + _bar_types = {_AGGREGATE_SCORE, _REGION_SCORE} + score_cols = df.columns[df.loc[_METRIC_TYPE].isin(_bar_types)] + other_cols = df.columns[~df.loc[_METRIC_TYPE].isin(_bar_types)] column_definitions = [ ColumnDefinition("Method", width=1.5, textprops={"ha": "left", "weight": "bold"}), ] @@ -345,14 +697,14 @@ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, sav ColumnDefinition( col, title=col.replace(" ", "\n", 1), - width=1, + width=1.5, textprops={ "ha": "center", "bbox": {"boxstyle": "circle", "pad": 0.25}, }, cmap=cmap_fn(plot_df[col]), group=df.loc[_METRIC_TYPE, col], - formatter="{:.2f}", + formatter=_fmt, ) for i, col in enumerate(other_cols) ] @@ -360,7 +712,7 @@ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, sav column_definitions += [ ColumnDefinition( col, - width=1, + width=1.5, 
title=col.replace(" ", "\n", 1), plot_fn=bar, plot_kw={ @@ -368,7 +720,7 @@ def plot_results_table(self, min_max_scale: bool = False, show: bool = True, sav "plot_bg_bar": False, "annotate": True, "height": 0.9, - "formatter": "{:.2f}", + "formatter": _fmt, }, group=df.loc[_METRIC_TYPE, col], border="left" if i == 0 else None, diff --git a/src/scib_metrics/metrics/__init__.py b/src/scib_metrics/metrics/__init__.py index 2071c89f..38d52ca5 100644 --- a/src/scib_metrics/metrics/__init__.py +++ b/src/scib_metrics/metrics/__init__.py @@ -5,6 +5,15 @@ from ._nmi_ari import nmi_ari_cluster_labels_kmeans, nmi_ari_cluster_labels_leiden from ._pcr_comparison import pcr_comparison from ._silhouette import bras, silhouette_batch, silhouette_label +from ._spatial import ( + spatial_chaos, + spatial_distance_correlation, + spatial_knn_overlap, + spatial_morans_i, + spatial_mrre, + spatial_niche_knn_overlap, + spatial_pas, +) __all__ = [ "isolated_labels", @@ -20,4 +29,11 @@ "kbet", "kbet_per_label", "graph_connectivity", + "spatial_mrre", + "spatial_knn_overlap", + "spatial_distance_correlation", + "spatial_morans_i", + "spatial_niche_knn_overlap", + "spatial_pas", + "spatial_chaos", ] diff --git a/src/scib_metrics/metrics/_spatial.py b/src/scib_metrics/metrics/_spatial.py new file mode 100644 index 00000000..b9843169 --- /dev/null +++ b/src/scib_metrics/metrics/_spatial.py @@ -0,0 +1,559 @@ +"""Spatial transcriptomics metrics for scib-metrics. + +Metrics are organised into three conceptual axes: + +**Coordinate preservation** (``spatial_mrre``, ``spatial_knn_overlap``, +``spatial_distance_correlation``, ``spatial_morans_i``) — asks whether the +latent embedding reproduces the physical XY geometry of spots. Appropriate +for spatial graph autoencoders (STAGATE-style) where the latent is explicitly +trained to be a surrogate for tissue coordinates. 
+ +**Niche preservation** (``spatial_niche_knn_overlap``) — asks whether cells +that share a similar local microenvironment (similar average expression of +spatial neighbours) are also close in latent space. This is the primary +objective of models like scVIVA and other niche-aware methods. + +**Domain boundary faithfulness** (``spatial_pas``, ``spatial_chaos``) — asks +whether clusters derived from the latent space are spatially coherent. +PAS (Proportion of Abnormal Spots) and CHAOS measure how well latent-derived +domains align with tissue boundaries, which is relevant for models aimed at +spatial domain identification. + +All functions return a float in **[0, 1]** where **higher is always better**. + +References +---------- +Hu et al. (2024) Benchmarking clustering, alignment, and integration + methods for spatial transcriptomics. PMC11312151. +Chen et al. (2025) A comprehensive benchmarking for spatially resolved + transcriptomics clustering methods. PMC12747554. +MuST (2023) Multi-modal spatial transcriptomics benchmark. +""" + +import warnings + +import numpy as np +from scipy.spatial.distance import pdist +from scipy.stats import ConstantInputWarning, spearmanr +from sklearn.cluster import KMeans +from sklearn.neighbors import NearestNeighbors + + +def spatial_mrre( + X_embedding: np.ndarray, + spatial_coords: np.ndarray, + k: int = 15, + max_cells: int = 2000, + seed: int = 42, +) -> float: + """Mean Relative Rank Error (MRRE), normalised to [0, 1]. + + For each spot, finds its ``k`` nearest spatial neighbours and compares + their rank ordering in spatial space to their rank ordering in the latent + embedding. The mean absolute rank difference, normalised by ``k``, + measures how much the local geometry is distorted. The score is + ``1 - MRRE/k`` so that perfect rank preservation yields 1. + + Parameters + ---------- + X_embedding + Array of shape ``(n_spots, n_dims)`` — latent representation. 
+ spatial_coords + Array of shape ``(n_spots, 2)`` with spatial coordinates (x, y). + k + Neighbourhood size. Default ``15``. + max_cells + Subsample to this many cells before computation (O(n·k) cost). + Default ``2000``. + seed + Random seed for subsampling. Default ``42``. + + Returns + ------- + float + Score in ``[0, 1]``. **Higher is better** (better rank preservation). + + References + ---------- + MuST benchmark (2023). Lähnemann et al. (2020). + """ + X_embedding = np.asarray(X_embedding, dtype=float) + spatial_coords = np.asarray(spatial_coords, dtype=float) + n = len(X_embedding) + + if n > max_cells: + rng = np.random.default_rng(seed) + idx = rng.choice(n, max_cells, replace=False) + X_embedding = X_embedding[idx] + spatial_coords = spatial_coords[idx] + n = max_cells + + k = min(k, n - 1) + if k == 0: + return 1.0 + kp1 = k + 1 + + nn_s = NearestNeighbors(n_neighbors=kp1, algorithm="kd_tree").fit(spatial_coords) + _, s_inds = nn_s.kneighbors(spatial_coords) + s_inds = s_inds[:, 1:] # (n, k) + + nn_l = NearestNeighbors(n_neighbors=kp1, algorithm="kd_tree").fit(X_embedding) + _, l_inds = nn_l.kneighbors(X_embedding) + l_inds = l_inds[:, 1:] # (n, k) + + total_error = 0.0 + for i in range(n): + rank_lookup = np.full(n, k, dtype=np.int32) + rank_lookup[l_inds[i]] = np.arange(k, dtype=np.int32) + lat_ranks = rank_lookup[s_inds[i]] + total_error += float(np.sum(np.abs(np.arange(k) - lat_ranks))) + + mrre = total_error / (n * k) + return float(max(0.0, 1.0 - mrre / k)) + + +def spatial_knn_overlap( + X_embedding: np.ndarray, + spatial_coords: np.ndarray, + k: int = 15, + max_cells: int = 2000, + seed: int = 42, +) -> float: + """k-NN overlap score between spatial and latent neighbourhoods. + + For each spot, computes the fraction of its ``k`` spatial nearest + neighbours that are also among its ``k`` latent nearest neighbours. + The mean over all spots gives an intuitive measure of local geometry + preservation. 
+ + Parameters + ---------- + X_embedding + Array of shape ``(n_spots, n_dims)`` — latent representation. + spatial_coords + Array of shape ``(n_spots, 2)`` with spatial coordinates (x, y). + k + Neighbourhood size. Default ``15``. + max_cells + Subsample to this many cells before computation. Default ``2000``. + seed + Random seed for subsampling. Default ``42``. + + Returns + ------- + float + Score in ``[0, 1]``. **Higher is better** (more neighbours shared). + + References + ---------- + MuST benchmark (2023). + """ + X_embedding = np.asarray(X_embedding, dtype=float) + spatial_coords = np.asarray(spatial_coords, dtype=float) + n = len(X_embedding) + + if n > max_cells: + rng = np.random.default_rng(seed) + idx = rng.choice(n, max_cells, replace=False) + X_embedding = X_embedding[idx] + spatial_coords = spatial_coords[idx] + n = max_cells + + k = min(k, n - 1) + if k == 0: + return 1.0 + kp1 = k + 1 + + nn_s = NearestNeighbors(n_neighbors=kp1, algorithm="kd_tree").fit(spatial_coords) + _, s_inds = nn_s.kneighbors(spatial_coords) + s_inds = s_inds[:, 1:] + + nn_l = NearestNeighbors(n_neighbors=kp1, algorithm="kd_tree").fit(X_embedding) + _, l_inds = nn_l.kneighbors(X_embedding) + l_inds = l_inds[:, 1:] + + overlaps = np.array([np.sum(np.isin(s_inds[i], l_inds[i])) / k for i in range(n)]) + raw = float(np.mean(overlaps)) + + # Normalise against the random-chance baseline: for n spots and k neighbours + # a random embedding shares k/(n-1) neighbours on average. Rescale so that + # random → 0 and perfect overlap → 1, then clip to [0, 1]. + chance = k / (n - 1) + if chance >= 1.0: + return raw + return float(np.clip((raw - chance) / (1.0 - chance), 0.0, 1.0)) + + +def spatial_distance_correlation( + X_embedding: np.ndarray, + spatial_coords: np.ndarray, + max_cells: int = 1000, + seed: int = 42, +) -> float: + """Spearman correlation of pairwise distances, rescaled to [0, 1]. 
+ + Computes pairwise Euclidean distances in spatial coordinate space and in + the latent embedding, then measures their Spearman rank correlation. + A high correlation indicates that the embedding preserves the global + spatial distance structure. + + The Spearman correlation (in ``[-1, 1]``) is rescaled to ``[0, 1]`` via + ``(r + 1) / 2``. + + Parameters + ---------- + X_embedding + Array of shape ``(n_spots, n_dims)`` — latent representation. + spatial_coords + Array of shape ``(n_spots, 2)`` with spatial coordinates (x, y). + max_cells + Subsample to this many cells before computation (O(n²) cost). + Default ``1000``. + seed + Random seed for subsampling. Default ``42``. + + Returns + ------- + float + Score in ``[0, 1]``. **Higher is better** (better global preservation). + """ + X_embedding = np.asarray(X_embedding, dtype=float) + spatial_coords = np.asarray(spatial_coords, dtype=float) + n = len(X_embedding) + + if n > max_cells: + rng = np.random.default_rng(seed) + idx = rng.choice(n, max_cells, replace=False) + X_embedding = X_embedding[idx] + spatial_coords = spatial_coords[idx] + + sp_dists = pdist(spatial_coords, metric="euclidean") + lat_dists = pdist(X_embedding, metric="euclidean") + + if len(sp_dists) < 2: + return 1.0 + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ConstantInputWarning) + corr, _ = spearmanr(sp_dists, lat_dists) + corr = 0.0 if np.isnan(corr) else float(corr) + return float((corr + 1.0) / 2.0) + + +def spatial_morans_i( + X_embedding: np.ndarray, + spatial_coords: np.ndarray, + n_neighbors: int = 6, +) -> float: + """Mean Moran's I of latent dimensions, rescaled to [0, 1]. + + Computes Moran's I spatial autocorrelation statistic for each latent + dimension using a row-standardised k-NN spatial weight matrix. Positive + Moran's I indicates that spatially adjacent spots have similar values in + that latent dimension (smooth spatial variation). 
def spatial_niche_knn_overlap(
    X_embedding: np.ndarray,
    spatial_coords: np.ndarray,
    X_expression: np.ndarray | None = None,
    k: int = 15,
    k_spatial: int = 6,
    max_cells: int = 2000,
    seed: int = 42,
) -> float:
    """k-NN overlap between latent space and spatial niche feature space.

    For each spot, the *niche feature* is the mean embedding (or expression)
    of its ``k_spatial`` spatial neighbours. This vector encodes the local
    microenvironment rather than the spot's own position. The score is the
    chance-normalised fraction of latent k-NN that are also niche-feature
    k-NN — analogous to :func:`spatial_knn_overlap` but using niche
    descriptors as the reference instead of raw XY coordinates.

    Unlike coordinate-preservation metrics, this rewards models that learn
    to represent *who shares the same microenvironment*, which is the
    primary objective of niche-aware methods such as scVIVA.

    Parameters
    ----------
    X_embedding
        Array of shape ``(n_spots, n_dims)`` — latent representation.
    spatial_coords
        Array of shape ``(n_spots, 2)`` with spatial coordinates (x, y).
    X_expression
        Array of shape ``(n_spots, n_features)`` used to build niche
        descriptors. Typically the pre-integrated PCA embedding so that
        niche features remain low-dimensional. If ``None``, falls back to
        averaging spatial coordinates of neighbours, which degrades the
        metric to a coarser spatial-proximity measure.
    k
        Neighbourhood size for latent and niche kNN. Default ``15``.
    k_spatial
        Number of spatial neighbours used to aggregate the niche descriptor.
        Default ``6``.
    max_cells
        Subsample to this many cells before computation. Default ``2000``.
    seed
        Random seed for subsampling. Default ``42``.

    Returns
    -------
    float
        Score in ``[0, 1]``. **Higher is better** (latent neighbours match
        niche-feature neighbours more often than chance).

    References
    ----------
    Inspired by scVIVA (Boyeau et al.) and COVET/ENVI niche representations.
    """
    X_embedding = np.asarray(X_embedding, dtype=float)
    spatial_coords = np.asarray(spatial_coords, dtype=float)
    # Convert once up front; previously the subsample branch converted
    # X_expression and the niche-feature branch re-converted it.
    if X_expression is not None:
        X_expression = np.asarray(X_expression, dtype=float)
    n = len(X_embedding)

    # A single spot has no neighbours to compare; trivially perfect,
    # consistent with the k == 0 early return below.
    if n < 2:
        return 1.0

    if n > max_cells:
        rng = np.random.default_rng(seed)
        idx = rng.choice(n, max_cells, replace=False)
        X_embedding = X_embedding[idx]
        spatial_coords = spatial_coords[idx]
        if X_expression is not None:
            X_expression = X_expression[idx]
        n = max_cells

    # ── build niche descriptors ──────────────────────────────────────────────
    k_sp = min(k_spatial, n - 1)
    nn_sp = NearestNeighbors(n_neighbors=k_sp + 1, algorithm="kd_tree").fit(spatial_coords)
    _, sp_inds = nn_sp.kneighbors(spatial_coords)
    sp_inds = sp_inds[:, 1:]  # exclude self

    if X_expression is not None:
        niche_features = X_expression[sp_inds].mean(axis=1)  # (n, n_features)
    else:
        niche_features = spatial_coords[sp_inds].mean(axis=1)  # (n, 2)

    # ── kNN in latent and niche spaces ───────────────────────────────────────
    k = min(k, n - 1)
    if k == 0:
        return 1.0
    kp1 = k + 1

    nn_lat = NearestNeighbors(n_neighbors=kp1, algorithm="auto").fit(X_embedding)
    _, lat_inds = nn_lat.kneighbors(X_embedding)
    lat_inds = lat_inds[:, 1:]

    nn_niche = NearestNeighbors(n_neighbors=kp1, algorithm="auto").fit(niche_features)
    _, niche_inds = nn_niche.kneighbors(niche_features)
    niche_inds = niche_inds[:, 1:]

    # ── chance-normalised overlap ────────────────────────────────────────────
    # A random embedding shares k/(n-1) neighbours in expectation; rescale so
    # that random → 0 and perfect overlap → 1, then clip to [0, 1].
    overlaps = np.array([np.sum(np.isin(niche_inds[i], lat_inds[i])) / k for i in range(n)])
    raw = float(np.mean(overlaps))
    chance = k / (n - 1)
    if chance >= 1.0:
        return raw
    return float(np.clip((raw - chance) / (1.0 - chance), 0.0, 1.0))
def _cluster_embedding(X_embedding: np.ndarray, n_clusters: int, seed: int) -> np.ndarray:
    """Partition ``X_embedding`` into ``n_clusters`` k-means clusters.

    Returns an ``int32`` array with one cluster label per row.
    """
    assignments = KMeans(n_clusters=n_clusters, random_state=seed, n_init=10).fit_predict(X_embedding)
    return assignments.astype(np.int32)
def spatial_pas(
    X_embedding: np.ndarray,
    spatial_coords: np.ndarray,
    n_clusters: int = 10,
    k_spatial: int = 6,
    max_cells: int = 5000,
    seed: int = 42,
) -> float:
    """Proportion of Abnormal Spots (PAS), inverted to [0, 1] higher=better.

    Clusters the latent embedding with k-means, then measures how often
    spatial neighbours belong to the same cluster. For each spot the
    *abnormality* is the fraction of its ``k_spatial`` spatial neighbours
    assigned to a different cluster; PAS is the mean over all spots. The
    score is ``1 - PAS`` so that perfect spatial coherence of clusters → 1.

    A high score indicates that latent-derived domains are spatially
    contiguous and align with tissue boundaries.

    Parameters
    ----------
    X_embedding
        Array of shape ``(n_spots, n_dims)`` — latent representation.
    spatial_coords
        Array of shape ``(n_spots, 2)`` with spatial coordinates (x, y).
    n_clusters
        Number of k-means clusters to derive from the latent embedding.
        Default ``10``.
    k_spatial
        Number of spatial neighbours used to assess cluster consistency.
        Default ``6``.
    max_cells
        Subsample to this many cells before computation. Default ``5000``.
    seed
        Random seed. Default ``42``.

    Returns
    -------
    float
        Score in ``[0, 1]``. **Higher is better** (more spatially coherent
        clusters).

    References
    ----------
    Hu et al. (2024) Benchmarking clustering methods for spatial
    transcriptomics. PMC11312151.
    """
    X_embedding = np.asarray(X_embedding, dtype=float)
    spatial_coords = np.asarray(spatial_coords, dtype=float)
    n = len(X_embedding)

    if n > max_cells:
        rng = np.random.default_rng(seed)
        idx = rng.choice(n, max_cells, replace=False)
        X_embedding = X_embedding[idx]
        spatial_coords = spatial_coords[idx]
        n = max_cells

    # Degenerate input guard BEFORE the expensive KMeans fit: with a single
    # spot there are no neighbours to disagree with (previously n == 1 also
    # reached KMeans with n_clusters == 0, which raises).
    k = min(k_spatial, n - 1)
    if k == 0:
        return 1.0

    n_clusters = min(n_clusters, n - 1)
    cluster_labels = _cluster_embedding(X_embedding, n_clusters, seed)

    nn = NearestNeighbors(n_neighbors=k + 1, algorithm="kd_tree").fit(spatial_coords)
    _, inds = nn.kneighbors(spatial_coords)
    inds = inds[:, 1:]  # exclude self

    pas_per_spot = np.array([np.mean(cluster_labels[inds[i]] != cluster_labels[i]) for i in range(n)])
    pas = float(np.mean(pas_per_spot))
    return float(1.0 - pas)
def spatial_chaos(
    X_embedding: np.ndarray,
    spatial_coords: np.ndarray,
    n_clusters: int = 10,
    max_cells: int = 2000,
    seed: int = 42,
) -> float:
    """CHAOS score (spatial cluster compactness), inverted to [0, 1] higher=better.

    The latent embedding is partitioned with k-means, and each cluster's mean
    pairwise spatial distance is compared against the global mean pairwise
    spatial distance. CHAOS is the mean normalised intra-cluster distance; the
    returned score is ``1 - CHAOS`` clipped to ``[0, 1]``, so latent-derived
    clusters occupying small, contiguous tissue patches score close to 1.

    Parameters
    ----------
    X_embedding
        Array of shape ``(n_spots, n_dims)`` — latent representation.
    spatial_coords
        Array of shape ``(n_spots, 2)`` with spatial coordinates (x, y).
    n_clusters
        Number of k-means clusters to derive from the latent embedding.
        Default ``10``.
    max_cells
        Subsample to this many cells before computation (O(n²) cost for
        pairwise distances). Default ``2000``.
    seed
        Random seed. Default ``42``.

    Returns
    -------
    float
        Score in ``[0, 1]``. **Higher is better** (more spatially compact
        clusters).

    References
    ----------
    Hu et al. (2024) Benchmarking clustering methods for spatial
    transcriptomics. PMC11312151.
    Chen et al. (2025) A comprehensive benchmarking for spatially resolved
    transcriptomics clustering methods. PMC12747554.
    """
    emb = np.asarray(X_embedding, dtype=float)
    xy = np.asarray(spatial_coords, dtype=float)
    n = len(emb)

    if n > max_cells:
        keep = np.random.default_rng(seed).choice(n, max_cells, replace=False)
        emb = emb[keep]
        xy = xy[keep]
        n = max_cells

    n_clusters = min(n_clusters, n - 1)
    labels = _cluster_embedding(emb, n_clusters, seed)

    all_pair_dists = pdist(xy)
    overall_mean = float(np.mean(all_pair_dists)) if len(all_pair_dists) > 0 else 0.0
    if overall_mean == 0.0:
        # All spots coincide (or a single spot): normalisation is impossible;
        # treat as perfectly compact.
        return 1.0

    intra_means = []
    for cluster_id in range(n_clusters):
        members = xy[labels == cluster_id]
        if len(members) < 2:
            # No pairwise distances within an empty/singleton cluster:
            # counted as perfectly compact (0).
            intra_means.append(0.0)
        else:
            intra_means.append(float(np.mean(pdist(members))))

    chaos = float(np.mean(intra_means)) / overall_mean
    return float(np.clip(1.0 - chaos, 0.0, 1.0))
+ + Returns + ------- + float + Score in ``[0, 1]``. **Higher is better** (more spatially compact + clusters). + + References + ---------- + Hu et al. (2024) Benchmarking clustering methods for spatial + transcriptomics. PMC11312151. + Chen et al. (2025) A comprehensive benchmarking for spatially resolved + transcriptomics clustering methods. PMC12747554. + """ + X_embedding = np.asarray(X_embedding, dtype=float) + spatial_coords = np.asarray(spatial_coords, dtype=float) + n = len(X_embedding) + + if n > max_cells: + rng = np.random.default_rng(seed) + idx = rng.choice(n, max_cells, replace=False) + X_embedding = X_embedding[idx] + spatial_coords = spatial_coords[idx] + n = max_cells + + n_clusters = min(n_clusters, n - 1) + cluster_labels = _cluster_embedding(X_embedding, n_clusters, seed) + + global_dists = pdist(spatial_coords) + global_mean = float(np.mean(global_dists)) if len(global_dists) > 0 else 0.0 + if global_mean == 0.0: + return 1.0 + + per_cluster = [] + for c in range(n_clusters): + mask = cluster_labels == c + if mask.sum() < 2: + per_cluster.append(0.0) + continue + d = pdist(spatial_coords[mask]) + per_cluster.append(float(np.mean(d))) + + chaos = float(np.mean(per_cluster)) / global_mean + return float(np.clip(1.0 - chaos, 0.0, 1.0)) diff --git a/tests/test_spatial_metrics.py b/tests/test_spatial_metrics.py new file mode 100644 index 00000000..154199ac --- /dev/null +++ b/tests/test_spatial_metrics.py @@ -0,0 +1,938 @@ +"""Tests for spatial transcriptomics metrics.""" + +import anndata +import numpy as np +import pandas as pd +import pytest + +import scib_metrics +from scib_metrics.benchmark import ( + BatchCorrection, + Benchmarker, + BioConservation, + CoordinatePreservation, + DomainBoundary, + NichePreservation, + SpatialConservation, +) +from tests.utils.data import dummy_benchmarker_adata, dummy_spatial_benchmarker_adata + +# ── helpers ────────────────────────────────────────────────────────────────── + + +def _compact_data(seed=0): + """4 
tight spatial blobs with matching labels.""" + rng = np.random.default_rng(seed) + centers = np.array([[0, 0], [50, 0], [0, 50], [50, 50]], dtype=float) + labels = np.repeat(np.arange(4), 50) + coords = centers[labels] + rng.normal(scale=2.0, size=(200, 2)) + return coords, labels + + +# ── MRRE ────────────────────────────────────────────────────────────────────── + + +def test_spatial_mrre_returns_float_in_range(): + coords, _ = _compact_data() + rng = np.random.default_rng(3) + emb = coords + rng.normal(scale=1.0, size=coords.shape) + score = scib_metrics.spatial_mrre(emb, coords) + assert isinstance(score, float) + assert 0.0 <= score <= 1.0 + + +def test_spatial_mrre_identity_is_best(): + coords, _ = _compact_data() + assert scib_metrics.spatial_mrre(coords, coords) > scib_metrics.spatial_mrre( + np.random.default_rng(3).normal(size=coords.shape), coords + ) + + +def test_spatial_mrre_random_worse_than_correlated(): + coords, _ = _compact_data() + rng = np.random.default_rng(4) + corr = scib_metrics.spatial_mrre(coords + rng.normal(scale=1.0, size=coords.shape), coords) + rand = scib_metrics.spatial_mrre(rng.normal(size=(len(coords), 10)), coords) + assert corr > rand + + +# ── kNN overlap ─────────────────────────────────────────────────────────────── + + +def test_spatial_knn_overlap_returns_float_in_range(): + coords, _ = _compact_data() + rng = np.random.default_rng(4) + score = scib_metrics.spatial_knn_overlap(coords + rng.normal(scale=1.0, size=coords.shape), coords) + assert isinstance(score, float) + assert 0.0 <= score <= 1.0 + + +def test_spatial_knn_overlap_identity_is_one(): + coords, _ = _compact_data() + assert scib_metrics.spatial_knn_overlap(coords, coords) == pytest.approx(1.0) + + +def test_spatial_knn_overlap_random_lower(): + coords, _ = _compact_data() + rng = np.random.default_rng(5) + corr = scib_metrics.spatial_knn_overlap(coords + rng.normal(scale=0.5, size=coords.shape), coords) + rand = 
scib_metrics.spatial_knn_overlap(rng.normal(size=(len(coords), 10)), coords) + assert corr > rand + + +# ── Distance correlation ────────────────────────────────────────────────────── + + +def test_spatial_distance_correlation_returns_float_in_range(): + coords, _ = _compact_data() + rng = np.random.default_rng(5) + score = scib_metrics.spatial_distance_correlation(coords + rng.normal(scale=1.0, size=coords.shape), coords) + assert isinstance(score, float) + assert 0.0 <= score <= 1.0 + + +def test_spatial_distance_correlation_identity_is_one(): + coords, _ = _compact_data() + assert scib_metrics.spatial_distance_correlation(coords, coords) == pytest.approx(1.0) + + +def test_spatial_distance_correlation_correlated_higher(): + coords, _ = _compact_data() + rng = np.random.default_rng(5) + corr = scib_metrics.spatial_distance_correlation(coords + rng.normal(scale=1.0, size=coords.shape), coords) + rand = scib_metrics.spatial_distance_correlation(rng.normal(size=(len(coords), 10)), coords) + assert corr > rand + + +# ── Moran's I ───────────────────────────────────────────────────────────────── + + +def test_spatial_morans_i_returns_float_in_range(): + coords, _ = _compact_data() + rng = np.random.default_rng(6) + emb = coords + rng.normal(scale=1.0, size=coords.shape) + score = scib_metrics.spatial_morans_i(emb, coords) + assert isinstance(score, float) + assert 0.0 <= score <= 1.0 + + +def test_spatial_morans_i_spatially_smooth_higher(): + """A spatially smooth embedding should score higher than random noise.""" + coords, _ = _compact_data() + rng = np.random.default_rng(7) + smooth = coords + rng.normal(scale=0.5, size=coords.shape) + noisy = rng.normal(size=coords.shape) + assert scib_metrics.spatial_morans_i(smooth, coords) > scib_metrics.spatial_morans_i(noisy, coords) + + +# ── Benchmarker integration ─────────────────────────────────────────────────── + + +def test_benchmarker_with_spatial_conservation(): + adata, emb_keys, batch_key, labels_key = 
dummy_spatial_benchmarker_adata() + bm = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=None, + batch_correction_metrics=None, + spatial_conservation_metrics=SpatialConservation(), + spatial_key="spatial", + ) + bm.benchmark() + results = bm.get_results(clean_names=False) + assert isinstance(results, pd.DataFrame) + for col in ("spatial_mrre", "spatial_knn_overlap", "spatial_distance_correlation", "spatial_morans_i"): + assert col in results.columns, f"Missing column: {col}" + bm.plot_results_table() + + +def test_benchmarker_spatial_aggregate_present(): + """Coordinate preservation aggregate column should appear in results.""" + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + bm = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=None, + batch_correction_metrics=None, + spatial_key="spatial", + ) + bm.benchmark() + results = bm.get_results() + assert "Coordinate preservation" in results.columns + + +def test_benchmarker_spatial_with_bio_and_batch(): + """All three metric categories run together; spatial not in Total (weight=0).""" + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + bm = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=BioConservation( + isolated_labels=False, + nmi_ari_cluster_labels_leiden=False, + nmi_ari_cluster_labels_kmeans=True, + silhouette_label=True, + clisi_knn=False, + ), + batch_correction_metrics=BatchCorrection( + bras=True, + ilisi_knn=False, + kbet_per_label=False, + graph_connectivity=False, + pcr_comparison=False, + ), + spatial_key="spatial", + ) + bm.benchmark() + results = bm.get_results() + assert isinstance(results, pd.DataFrame) + assert "Total" in results.columns + assert "Coordinate preservation" in results.columns + bm.plot_results_table() + + +def test_benchmarker_spatial_weight_in_total(): + """Non-zero spatial_conservation_weight should shift the 
Total score.""" + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + bio = BioConservation( + isolated_labels=False, + nmi_ari_cluster_labels_leiden=False, + nmi_ari_cluster_labels_kmeans=True, + silhouette_label=True, + clisi_knn=False, + ) + batch = BatchCorrection( + bras=True, + ilisi_knn=False, + kbet_per_label=False, + graph_connectivity=False, + pcr_comparison=False, + ) + bm0 = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=bio, + batch_correction_metrics=batch, + spatial_key="spatial", + spatial_conservation_weight=0.0, + ) + bm0.benchmark() + + bm1 = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=bio, + batch_correction_metrics=batch, + spatial_key="spatial", + spatial_conservation_weight=0.2, + ) + bm1.benchmark() + + # Drop the "Metric Type" row (which contains string "Aggregate score") before casting + r0 = bm0.get_results().loc[emb_keys, "Total"].astype(float) + r1 = bm1.get_results().loc[emb_keys, "Total"].astype(float) + # With weight=0.2 the totals should differ from weight=0.0 + assert not np.allclose(r0.values, r1.values) + + +def test_benchmarker_embedding_spatial_scores_vary(): + """Embedding-based spatial metrics should differ across different embeddings.""" + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata(n_spots=200, seed=0) + bm = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=None, + batch_correction_metrics=None, + spatial_key="spatial", + ) + bm.benchmark() + results = bm.get_results(clean_names=False) + for col in ("spatial_mrre", "spatial_knn_overlap", "spatial_distance_correlation", "spatial_morans_i"): + vals = results.loc[results.index.isin(emb_keys), col].astype(float).values + # At least some variation expected (embeddings are independently random) + assert not np.allclose(vals, vals[0]), f"Metric '{col}' should vary across embeddings" + + +def 
test_benchmarker_spatial_scores_in_unit_interval(): + """All spatial metric values should be in [0, 1].""" + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + bm = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=None, + batch_correction_metrics=None, + spatial_key="spatial", + ) + bm.benchmark() + results = bm.get_results(clean_names=False) + for col in ("spatial_mrre", "spatial_knn_overlap", "spatial_distance_correlation", "spatial_morans_i"): + vals = results.loc[results.index.isin(emb_keys), col].astype(float).values + assert np.all(vals >= 0.0), f"'{col}' has values < 0" + assert np.all(vals <= 1.0), f"'{col}' has values > 1" + + +# ── auto-detection / flag behaviour ────────────────────────────────────────── + + +def test_benchmarker_no_spatial_key_no_spatial_metrics(): + ad, emb_keys, batch_key, labels_key = dummy_benchmarker_adata() + bm = Benchmarker(ad, batch_key, labels_key, emb_keys) + bm.benchmark() + results = bm.get_results(clean_names=False) + for col in ("spatial_mrre", "spatial_knn_overlap", "spatial_distance_correlation", "spatial_morans_i"): + assert col not in results.columns + + +def test_benchmarker_spatial_key_auto_enables_spatial_metrics(): + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + bm = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=None, + batch_correction_metrics=None, + spatial_key="spatial", + ) + bm.benchmark() + results = bm.get_results(clean_names=False) + for col in ("spatial_mrre", "spatial_knn_overlap", "spatial_distance_correlation", "spatial_morans_i"): + assert col in results.columns + + +def test_benchmarker_spatial_key_explicit_none_disables_spatial(): + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + bm = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=None, + batch_correction_metrics=BatchCorrection( + bras=True, + 
ilisi_knn=False, + kbet_per_label=False, + graph_connectivity=False, + pcr_comparison=False, + ), + spatial_conservation_metrics=None, + spatial_key="spatial", + ) + bm.benchmark() + results = bm.get_results(clean_names=False) + for col in ("spatial_mrre", "spatial_knn_overlap", "spatial_distance_correlation", "spatial_morans_i"): + assert col not in results.columns + + +def test_benchmarker_spatial_partial_config(): + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + bm = Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=None, + batch_correction_metrics=None, + spatial_conservation_metrics=SpatialConservation( + spatial_knn_overlap=False, + spatial_distance_correlation=False, + ), + spatial_key="spatial", + ) + bm.benchmark() + results = bm.get_results(clean_names=False) + assert "spatial_mrre" in results.columns + assert "spatial_morans_i" in results.columns + assert "spatial_knn_overlap" not in results.columns + assert "spatial_distance_correlation" not in results.columns + + +def test_benchmarker_spatial_missing_key_raises(): + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + with pytest.raises(ValueError, match="spatial_key must be provided"): + Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + bio_conservation_metrics=None, + batch_correction_metrics=None, + spatial_conservation_metrics=SpatialConservation(), + spatial_key=None, + ) + + +def test_benchmarker_spatial_wrong_key_raises(): + adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata() + with pytest.raises(ValueError, match="not found in adata.obsm"): + Benchmarker( + adata, + batch_key, + labels_key, + emb_keys, + spatial_conservation_metrics=SpatialConservation(), + spatial_key="nonexistent_key", + ) + + +# ── Metric differentiation: each metric uniquely sensitive to its distortion ── +# +# All four spatial metrics measure "does the embedding preserve spatial structure?" 
+# but from fundamentally different angles: +# +# • distance_correlation — ALL pairwise distance ranks (global structure). +# • kNN overlap — SET MEMBERSHIP of the local k-NN (local structure). +# • MRRE — RANK ORDER within the local k-NN (stricter than kNN). +# • Moran's I — SPATIAL AUTOCORRELATION of embedding values; +# does NOT compare distances — asks whether adjacent +# cells have similar embeddings regardless of scale. +# +# In real notebooks all four tend to agree because a model either captures +# spatial structure (all high) or doesn't (all moderate). The tests below +# use four synthetic distortion types that isolate each metric's blind spot, +# demonstrating when they MUST disagree. + + +def _four_cluster_coords(n_per: int = 60, sep: float = 50.0, spread: float = 1.0, seed: int = 42): + """4 tight, well-separated 2-D clusters (sep >> spread).""" + rng = np.random.default_rng(seed) + centers = np.array([[0, 0], [sep, 0], [0, sep], [sep, sep]], dtype=float) + labels = np.repeat(np.arange(4), n_per) + coords = centers[labels] + rng.normal(scale=spread, size=(4 * n_per, 2)) + return coords, labels + + +def test_distance_correlation_uniquely_captures_global_structure(): + """ + 'Global-only' embedding: cluster centroids are preserved in the embedding + but within-cluster positions are replaced by large independent noise. + + Between-cluster pairs dominate the pairwise distance matrix (75 % of pairs) + and are correctly ordered in both spaces → distance_correlation HIGH. + The k spatial neighbours of every cell are specific nearby cells within the + same cluster; the random intra-cluster noise picks different cells as + latent neighbours → kNN overlap and MRRE clearly LOWER. + + This shows that distance_correlation is uniquely sensitive to global + distance structure and can be high even when local neighbourhoods are wrong. 
+ """ + coords, labels = _four_cluster_coords() + rng = np.random.default_rng(1) + centers = np.array([coords[labels == l].mean(0) for l in range(4)]) + # Centroid preserved; within-cluster noise >> spread but << sep + emb = centers[labels] + rng.normal(scale=8.0, size=coords.shape) + + dist_corr = scib_metrics.spatial_distance_correlation(emb, coords) + knn = scib_metrics.spatial_knn_overlap(emb, coords) + mrre = scib_metrics.spatial_mrre(emb, coords) + + assert dist_corr > 0.80, f"distance_correlation should be high: {dist_corr:.3f}" + assert knn < 0.40, f"knn_overlap should be low: {knn:.3f}" + assert mrre < 0.65, f"mrre should be low: {mrre:.3f}" + assert dist_corr > knn + 0.50, ( + f"distance_correlation ({dist_corr:.3f}) should clearly exceed " + f"knn_overlap ({knn:.3f}) for a global-only embedding" + ) + + +def test_knn_overlap_uniquely_captures_local_neighbourhood(): + """ + 'Local-only' embedding: within-cluster relative positions are preserved + exactly, but each cluster is translated by a large random offset, completely + destroying between-cluster distances. + + The k spatial neighbours of every cell lie within its cluster; the embedding + shifts the whole cluster as a rigid body, so the same k cells remain the k + nearest in embedding space → kNN overlap = 1.0, MRRE = 1.0. + Between-cluster distances are random (offsets >> sep) + → distance_correlation clearly LOWER. + + This shows that kNN overlap and MRRE are uniquely sensitive to local + neighbourhood membership and can be perfect even when global structure fails. 
+ """ + coords, labels = _four_cluster_coords() + rng = np.random.default_rng(2) + centers = np.array([coords[labels == l].mean(0) for l in range(4)]) + # Subtract cluster centre (preserves relative positions), add random large offset + offsets = rng.normal(scale=500.0, size=(4, 2)) + emb = (coords - centers[labels]) + offsets[labels] + + knn = scib_metrics.spatial_knn_overlap(emb, coords) + mrre = scib_metrics.spatial_mrre(emb, coords) + dist_corr = scib_metrics.spatial_distance_correlation(emb, coords) + + assert knn > 0.90, f"knn_overlap should be near 1: {knn:.3f}" + assert mrre > 0.90, f"mrre should be near 1: {mrre:.3f}" + assert knn > dist_corr + 0.10, ( + f"knn_overlap ({knn:.3f}) should clearly exceed " + f"distance_correlation ({dist_corr:.3f}) for a local-only embedding" + ) + + +def test_morans_i_uniquely_captures_spatial_smoothness(): + """ + 'Cluster-smooth' embedding: every cell is mapped to a randomly chosen + cluster centroid (not its own cluster's position) plus tiny noise. + + Spatially adjacent cells all belong to the same physical cluster, so they + all receive the same (random) centroid embedding → embedding values are + perfectly constant within each neighbourhood → Moran's I = 1.0. + + The k latent neighbours of a cell are determined by tiny independent noise + rather than physical proximity → kNN overlap and MRRE clearly LOWER. + + This shows that Moran's I is uniquely sensitive to spatial smoothness of + the embedding signal, independently of whether exact neighbours are correct. 
+ """ + coords, labels = _four_cluster_coords() + rng = np.random.default_rng(99) + # Each cluster maps to a random point in embedding space + random_centroids = rng.normal(scale=50.0, size=(4, 2)) + emb = random_centroids[labels] + rng.normal(scale=0.01, size=coords.shape) + + morans = scib_metrics.spatial_morans_i(emb, coords) + knn = scib_metrics.spatial_knn_overlap(emb, coords) + mrre = scib_metrics.spatial_mrre(emb, coords) + + assert morans > 0.95, f"spatial_morans_i should be near 1: {morans:.3f}" + assert knn < 0.40, f"knn_overlap should be low: {knn:.3f}" + assert mrre < 0.65, f"mrre should be low: {mrre:.3f}" + assert morans > knn + 0.50, ( + f"spatial_morans_i ({morans:.3f}) should clearly exceed knn_overlap ({knn:.3f}) for a cluster-smooth embedding" + ) + + +def test_mrre_uniquely_sensitive_to_rank_order_within_knn(): + """ + 'Rank-reversed' embedding: for a 1-D spatial layout the k-NN SET is + identical in spatial and embedding spaces, but the rank ORDER within + that set is reversed. + + Construction: cells are placed on a 1-D line (y = 0). The embedding is + 2-D with emb_y alternating 0 / delta (delta = 3 > sqrt(3)). For every + even cell i, distance-2 neighbours (even, same emb_y) appear CLOSER in + the embedding than distance-1 neighbours (odd, large emb_y difference), + inverting the rank order while keeping the same 4-member k-NN set. + + Expected: + kNN overlap ≈ 1.0 (same 4 members as spatial k-NN) + MRRE ≈ 0.5 (rank order within k-NN inverted) + + This shows MRRE is strictly more demanding than kNN overlap: it penalises + incorrect ordering even when set membership is perfect. 
+    """
+    n = 80
+    k = 4
+    delta = 3.0  # > sqrt(3) ≈ 1.73: same-parity (distance-2) pairs become closer in the embedding than adjacent (distance-1) pairs
+    coords_1d = np.column_stack([np.arange(n, dtype=float), np.zeros(n)])
+    emb_y = np.where(np.arange(n) % 2 == 0, 0.0, delta)
+    emb = np.column_stack([np.arange(n, dtype=float), emb_y])
+
+    knn = scib_metrics.spatial_knn_overlap(emb, coords_1d, k=k)
+    mrre = scib_metrics.spatial_mrre(emb, coords_1d, k=k)
+
+    assert knn > 0.90, f"knn_overlap should be near 1 (same members): {knn:.3f}"
+    assert mrre < 0.65, f"mrre should be low (reversed ranks): {mrre:.3f}"
+    assert knn > mrre + 0.30, (
+        f"knn_overlap ({knn:.3f}) should clearly exceed mrre ({mrre:.3f}) when rank order within k-NN is reversed"
+    )
+
+
+# ── PAS ───────────────────────────────────────────────────────────────────────
+
+
+def test_spatial_pas_returns_float_in_range():
+    coords, _ = _compact_data()
+    rng = np.random.default_rng(10)
+    emb = coords + rng.normal(scale=1.0, size=coords.shape)
+    score = scib_metrics.spatial_pas(emb, coords, n_clusters=4)
+    assert isinstance(score, float)
+    assert 0.0 <= score <= 1.0
+
+
+def test_spatial_pas_spatially_coherent_clusters_score_higher():
+    """Clusters aligned with spatial blobs should score higher than random."""
+    coords, labels = _compact_data()
+    rng = np.random.default_rng(11)
+    # Good embedding: tight around spatial clusters → k-means recovers spatial blobs
+    good_emb = coords + rng.normal(scale=0.5, size=coords.shape)
+    # Bad embedding: pure noise → k-means finds random clusters
+    bad_emb = rng.normal(scale=50.0, size=coords.shape)
+    good = scib_metrics.spatial_pas(good_emb, coords, n_clusters=4, seed=0)
+    bad = scib_metrics.spatial_pas(bad_emb, coords, n_clusters=4, seed=0)
+    assert good > bad, f"good={good:.3f} should exceed bad={bad:.3f}"
+
+
+# ── CHAOS ─────────────────────────────────────────────────────────────────────
+
+
+def test_spatial_chaos_returns_float_in_range():
+    coords, _ = _compact_data()
+    rng = np.random.default_rng(12)
+    emb = coords + rng.normal(scale=1.0, size=coords.shape)
+    score = scib_metrics.spatial_chaos(emb, coords, n_clusters=4)
+    assert isinstance(score, float)
+    assert 0.0 <= score <= 1.0
+
+
+def test_spatial_chaos_compact_clusters_score_higher():
+    """Embedding whose clusters are spatially compact should score higher."""
+    coords, labels = _compact_data()
+    rng = np.random.default_rng(13)
+    good_emb = coords + rng.normal(scale=0.5, size=coords.shape)
+    bad_emb = rng.normal(scale=50.0, size=coords.shape)
+    good = scib_metrics.spatial_chaos(good_emb, coords, n_clusters=4, seed=0)
+    bad = scib_metrics.spatial_chaos(bad_emb, coords, n_clusters=4, seed=0)
+    assert good > bad, f"good={good:.3f} should exceed bad={bad:.3f}"
+
+
+# ── Niche kNN overlap ─────────────────────────────────────────────────────────
+
+
+def test_spatial_niche_knn_overlap_returns_float_in_range():
+    coords, _ = _compact_data()
+    rng = np.random.default_rng(14)
+    emb = coords + rng.normal(scale=1.0, size=coords.shape)
+    X_expr = coords + rng.normal(scale=0.5, size=coords.shape)
+    score = scib_metrics.spatial_niche_knn_overlap(emb, coords, X_expression=X_expr)
+    assert isinstance(score, float)
+    assert 0.0 <= score <= 1.0
+
+
+def test_spatial_niche_knn_overlap_fallback_no_expression():
+    """Should work without X_expression (falls back to spatial coord averaging)."""
+    coords, _ = _compact_data()
+    rng = np.random.default_rng(15)
+    emb = coords + rng.normal(scale=1.0, size=coords.shape)
+    score = scib_metrics.spatial_niche_knn_overlap(emb, coords, X_expression=None)
+    assert isinstance(score, float)
+    assert 0.0 <= score <= 1.0
+
+
+def test_spatial_niche_knn_overlap_niche_aware_higher():
+    """Embedding that captures niche structure should outscore random noise."""
+    coords, labels = _compact_data()
+    rng = np.random.default_rng(16)
+    X_expr = coords + rng.normal(scale=0.5, size=coords.shape)  # expr correlated with space
+    good_emb = coords + rng.normal(scale=0.5, size=coords.shape)
+    rand_emb = rng.normal(size=(len(coords), 10))
+    good = scib_metrics.spatial_niche_knn_overlap(good_emb, coords, X_expression=X_expr)
+    rand = scib_metrics.spatial_niche_knn_overlap(rand_emb, coords, X_expression=X_expr)
+    assert good > rand, f"good={good:.3f} should exceed rand={rand:.3f}"
+
+
+# ── New dataclasses / Benchmarker integration ─────────────────────────────────
+
+
+def test_niche_preservation_benchmarker():
+    """NichePreservation runs through the Benchmarker pipeline."""
+    adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key,
+        labels_key,
+        emb_keys,
+        bio_conservation_metrics=None,
+        batch_correction_metrics=None,
+        spatial_conservation_metrics=None,
+        niche_preservation=NichePreservation(),
+        spatial_key="spatial",
+    )
+    bm.benchmark()
+    results = bm.get_results(clean_names=False)
+    assert "spatial_niche_knn_overlap" in results.columns
+    vals = results.loc[results.index.isin(emb_keys), "spatial_niche_knn_overlap"].astype(float).values
+    assert np.all(vals >= 0.0) and np.all(vals <= 1.0)
+
+
+def test_domain_boundary_benchmarker():
+    """DomainBoundary runs through the Benchmarker pipeline."""
+    adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key,
+        labels_key,
+        emb_keys,
+        bio_conservation_metrics=None,
+        batch_correction_metrics=None,
+        spatial_conservation_metrics=None,
+        domain_boundary=DomainBoundary(),
+        spatial_key="spatial",
+    )
+    bm.benchmark()
+    results = bm.get_results(clean_names=False)
+    assert "spatial_pas" in results.columns
+    assert "spatial_chaos" in results.columns
+    for col in ("spatial_pas", "spatial_chaos"):
+        vals = results.loc[results.index.isin(emb_keys), col].astype(float).values
+        assert np.all(vals >= 0.0) and np.all(vals <= 1.0)
+    bm.plot_results_table()
+
+
+def test_all_three_spatial_axes_together():
+    """All three spatial axes run together and produce three aggregate columns."""
+    adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key,
+        labels_key,
+        emb_keys,
+        bio_conservation_metrics=None,
+        batch_correction_metrics=None,
+        spatial_conservation_metrics=CoordinatePreservation(),
+        niche_preservation=NichePreservation(),
+        domain_boundary=DomainBoundary(),
+        spatial_key="spatial",
+    )
+    bm.benchmark()
+    results = bm.get_results()
+    for col in ("Coordinate preservation", "Niche preservation", "Domain boundary"):
+        assert col in results.columns, f"Missing aggregate column: {col}"
+    bm.plot_results_table()
+
+
+def test_spatial_conservation_alias_still_works():
+    """SpatialConservation is still a valid alias for CoordinatePreservation."""
+    adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key,
+        labels_key,
+        emb_keys,
+        bio_conservation_metrics=None,
+        batch_correction_metrics=None,
+        spatial_conservation_metrics=SpatialConservation(),
+        spatial_key="spatial",
+    )
+    bm.benchmark()
+    results = bm.get_results()
+    assert "Coordinate preservation" in results.columns
+
+
+def test_niche_domain_missing_spatial_key_raises():
+    """niche_preservation and domain_boundary require spatial_key."""
+    adata, emb_keys, batch_key, labels_key = dummy_spatial_benchmarker_adata()
+    with pytest.raises(ValueError, match="spatial_key must be provided"):
+        Benchmarker(
+            adata,
+            batch_key,
+            labels_key,
+            emb_keys,
+            bio_conservation_metrics=None,
+            batch_correction_metrics=None,
+            spatial_conservation_metrics=None,
+            niche_preservation=NichePreservation(),
+            spatial_key=None,
+        )
+    with pytest.raises(ValueError, match="spatial_key must be provided"):
+        Benchmarker(
+            adata,
+            batch_key,
+            labels_key,
+            emb_keys,
+            bio_conservation_metrics=None,
+            batch_correction_metrics=None,
+            spatial_conservation_metrics=None,
+            domain_boundary=DomainBoundary(),
+            spatial_key=None,
+        )
+
+
+# ── region_key: per-label-value regional scoring ─────────────────────────────
+
+
+def _region_benchmarker_adata(n_per_celltype: int = 150, seed: int = 0):
+    """AnnData with cell_type and region columns for region_key tests.
+
+    3 cell types × 150 cells each = 450 total.
+    2 regions, 2 batches. NOTE(review): both X_good and X_rand are plain random draws — the names are nominal; the tests below only assert score ranges and column presence.
+    """
+    rng = np.random.default_rng(seed)
+    n = n_per_celltype * 3
+    cell_type = np.repeat(["A", "B", "C"], n_per_celltype)
+    region = np.tile(["R1", "R2"], n // 2)
+    batch = np.tile(["b1", "b2"], n // 2)
+    x_data = rng.normal(size=(n, 20))
+    ad = anndata.AnnData(X=x_data)
+    ad.obs["cell_type"] = cell_type
+    ad.obs["region"] = region
+    ad.obs["batch"] = batch
+    ad.obsm["X_good"] = rng.normal(size=(n, 10))
+    ad.obsm["X_rand"] = rng.normal(size=(n, 10))
+    return ad
+
+
+_FAST_BIO = BioConservation(
+    silhouette_label=True,
+    nmi_ari_cluster_labels_kmeans=False,
+    isolated_labels=False,
+    clisi_knn=False,
+)
+_FAST_BATCH = BatchCorrection(
+    bras=False,
+    ilisi_knn=False,
+    kbet_per_label=False,
+    graph_connectivity=False,
+    pcr_comparison=True,
+)
+
+
+def test_region_key_produces_region_conservation_columns():
+    """region_key should add 'Region Bio conservation' and 'Region Batch correction' columns."""
+    adata = _region_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key="batch",
+        label_key="cell_type",
+        embedding_obsm_keys=["X_good"],
+        bio_conservation_metrics=_FAST_BIO,
+        batch_correction_metrics=_FAST_BATCH,
+        region_key="region",
+        n_jobs=1,
+        progress_bar=False,
+    )
+    bm.benchmark()
+    results = bm.get_results()
+    assert "Region Bio conservation" in results.columns
+    assert "Region Batch correction" in results.columns
+
+
+def test_region_key_group_header_is_region_conservation():
+    """The Metric Type row for region columns should read 'Region conservation'."""
+    adata = _region_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key="batch",
+        label_key="cell_type",
+        embedding_obsm_keys=["X_good"],
+        bio_conservation_metrics=_FAST_BIO,
+        batch_correction_metrics=_FAST_BATCH,
+        region_key="region",
+        n_jobs=1,
+        progress_bar=False,
+    )
+    bm.benchmark()
+    results = bm.get_results()
+    assert results.loc["Metric Type", "Region Bio conservation"] == "Region conservation"
+    assert results.loc["Metric Type", "Region Batch correction"] == "Region conservation"
+
+
+def test_region_key_scores_in_unit_interval():
+    """Region aggregate scores should be in [0, 1]."""
+    adata = _region_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key="batch",
+        label_key="cell_type",
+        embedding_obsm_keys=["X_good", "X_rand"],
+        bio_conservation_metrics=_FAST_BIO,
+        batch_correction_metrics=_FAST_BATCH,
+        region_key="region",
+        n_jobs=1,
+        progress_bar=False,
+    )
+    bm.benchmark()
+    results = bm.get_results()
+    for col in ("Region Bio conservation", "Region Batch correction"):
+        vals = results.loc[["X_good", "X_rand"], col].astype(float).values
+        assert np.all(vals >= 0.0), f"'{col}' has values < 0: {vals}"
+        assert np.all(vals <= 1.0), f"'{col}' has values > 1: {vals}"
+
+
+def test_region_key_none_does_not_add_region_columns():
+    """Without region_key the results table should not contain region columns."""
+    adata = _region_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key="batch",
+        label_key="cell_type",
+        embedding_obsm_keys=["X_good"],
+        bio_conservation_metrics=_FAST_BIO,
+        batch_correction_metrics=_FAST_BATCH,
+        n_jobs=1,
+        progress_bar=False,
+    )
+    bm.benchmark()
+    results = bm.get_results()
+    assert "Region Bio conservation" not in results.columns
+    assert "Region Batch correction" not in results.columns
+
+
+def test_region_key_invalid_column_raises():
+    """Passing a non-existent region_key should raise ValueError at construction."""
+    adata = _region_benchmarker_adata()
+    with pytest.raises(ValueError, match="region_key 'nonexistent'"):
+        Benchmarker(
+            adata,
+            batch_key="batch",
+            label_key="cell_type",
+            embedding_obsm_keys=["X_good"],
+            bio_conservation_metrics=_FAST_BIO,
+            region_key="nonexistent",
+        )
+
+
+def test_region_key_small_subset_skipped_with_warning():
+    """Cell types with fewer than min_cells cells should be skipped with a warning."""
+    rng = np.random.default_rng(42)
+    # 200 cells for type A, only 50 for type B — presumably below the benchmarker's minimum (91 = max_neighbors + 1); confirm against _core
+    cell_type = np.array(["A"] * 200 + ["B"] * 50)
+    region = np.tile(["R1", "R2"], 125)
+    batch = np.tile(["b1", "b2"], 125)
+    x_data = rng.normal(size=(250, 20))
+    adata = anndata.AnnData(X=x_data)
+    adata.obs["cell_type"] = cell_type
+    adata.obs["region"] = region
+    adata.obs["batch"] = batch
+    adata.obsm["X_emb"] = rng.normal(size=(250, 10))
+
+    bm = Benchmarker(
+        adata,
+        batch_key="batch",
+        label_key="cell_type",
+        embedding_obsm_keys=["X_emb"],
+        bio_conservation_metrics=_FAST_BIO,
+        batch_correction_metrics=None,
+        region_key="region",
+        n_jobs=1,
+        progress_bar=False,
+    )
+    with pytest.warns(UserWarning, match="Skipping region scoring for label 'B'"):
+        bm.benchmark()
+    # Region Bio conservation should still appear (computed for type A)
+    results = bm.get_results()
+    assert "Region Bio conservation" in results.columns
+
+
+def test_region_key_only_bio_metrics_enabled():
+    """When batch_correction_metrics=None, only Region Bio conservation column appears."""
+    adata = _region_benchmarker_adata()
+    bm = Benchmarker(
+        adata,
+        batch_key="batch",
+        label_key="cell_type",
+        embedding_obsm_keys=["X_good"],
+        bio_conservation_metrics=_FAST_BIO,
+        batch_correction_metrics=None,
+        region_key="region",
+        n_jobs=1,
+        progress_bar=False,
+    )
+    bm.benchmark()
+    results = bm.get_results()
+    assert "Region Bio conservation" in results.columns
+    assert "Region Batch correction" not in results.columns
diff --git a/tests/utils/data.py b/tests/utils/data.py
index 85177416..47fb92d2 100644
--- a/tests/utils/data.py
+++ b/tests/utils/data.py
@@ -41,3 +41,26 @@ def dummy_benchmarker_adata():
         adata.obsm[key] = X
         embedding_keys.append(key)
     return adata, embedding_keys, labels_key, batch_key
+
+
+def dummy_spatial_benchmarker_adata(n_spots: int = 200, n_clusters: int = 4, seed: int = 0):
+    """AnnData with clustered spatial coordinates for the spatial Benchmarker path; returns (adata, embedding_keys, batch_key, labels_key) — note batch before labels, unlike dummy_benchmarker_adata."""
+    rng = np.random.default_rng(seed)
+    centers = rng.uniform(0, 100, size=(n_clusters, 2))
+    labels = rng.integers(0, n_clusters, size=n_spots)
+    spatial_coords = centers[labels] + rng.normal(scale=5.0, size=(n_spots, 2))
+    X = rng.normal(size=(n_spots, 10))
+    batch = rng.integers(0, 2, size=n_spots)
+
+    adata = anndata.AnnData(X)
+    adata.obs["labels"] = labels
+    adata.obs["batch"] = batch
+    adata.obsm["spatial"] = spatial_coords
+
+    embedding_keys = []
+    for i in range(3):
+        key = f"X_emb_{i}"
+        adata.obsm[key] = rng.normal(size=(n_spots, 10))
+        embedding_keys.append(key)
+
+    return adata, embedding_keys, "batch", "labels"