diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 5eb6acb7..674ce280 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -390,6 +390,188 @@ def mixscape( if copy: return adata + def mixscale( + self, + adata: AnnData, + pert_key: str, + control: str, + *, + new_class_name: str = "mixscale_score", + layer: str | None = None, + min_de_genes: int = 5, + max_de_genes: int = 100, + logfc_threshold: float = 0.25, + de_layer: str | None = None, + test_method: str = "wilcoxon", + scale: bool = True, + split_by: str | None = None, + pval_cutoff: float = 5e-2, + perturbation_type: str = "KO", + copy: bool = False, + ): + """Calculate continuous perturbation scores using the Mixscale method. + + Unlike :meth:`mixscape` which performs binary KO/NP classification via + Gaussian Mixture Models, this method assigns a continuous perturbation + efficiency score to each cell. The score is the scalar projection of + each cell's perturbation signature onto the estimated perturbation + direction vector, standardized relative to non-targeting controls. + + This is particularly useful for CRISPRi/CRISPRa screens where cells + exhibit a gradient of perturbation responses rather than binary + knockouts. + + The implementation follows Jiang, Dalgarno et al., "Systematic + reconstruction of molecular pathway signatures using scalable + single-cell perturbation screens", Nature Cell Biology (2025). + + Args: + adata: The annotated data object. + pert_key: The column of `.obs` with target gene labels. + control: Control category from the `pert_key` column. + new_class_name: Name of the score column to be stored in `.obs`. + layer: Key from `adata.layers` whose value will be used for scoring. + Default is using `.layers["X_pert"]`. + min_de_genes: Required number of DE genes for scoring a perturbation. + Perturbations with fewer DE genes are skipped. + max_de_genes: Maximum number of DE genes to use for scoring. + logfc_threshold: Minimum log fold-change threshold for DE gene selection. + de_layer: Layer to use for identifying differentially expressed genes. + If `None`, `adata.X` is used. + test_method: Method to use for differential expression testing. + scale: Whether to scale the perturbation data before computing scores. + split_by: Provide `.obs` column with experimental condition/cell type + annotation, if perturbations are condition/cell type-specific. + pval_cutoff: P-value cut-off for selection of significantly DE genes. + perturbation_type: Type of CRISPR perturbation for labeling. + copy: Determines whether a copy of the `adata` is returned. + + Returns: + If `copy=True`, returns the copy of `adata` with the scores in `.obs`. + Otherwise, writes the scores directly to `.obs` of the provided `adata`. + + The following fields are added to `adata.obs`: + + - `adata.obs[new_class_name]`: Continuous perturbation score per cell. + Higher absolute values indicate stronger perturbation effect. + Non-targeting control cells receive a score of 0. + Scores are z-score standardized relative to the control distribution. + + Examples: + Compute continuous perturbation scores: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscale(mdata["rna"], "gene_target", "NT", layer="X_pert") + """ + if copy: + adata = adata.copy() + + if split_by is None: + split_masks = [np.full(adata.n_obs, True, dtype=bool)] + categories = ["all"] + else: + split_obs = adata.obs[split_by] + categories = split_obs.unique() + split_masks = [split_obs == category for category in categories] + + # Reuse the existing DE gene detection pipeline + perturbation_markers = self._get_perturbation_markers( + adata=adata, + split_masks=split_masks, + categories=categories, + pert_key=pert_key, + control=control, + layer=de_layer, + pval_cutoff=pval_cutoff, + min_de_genes=min_de_genes, + logfc_threshold=logfc_threshold, + test_method=test_method, + ) + + # Get perturbation signature matrix + if layer is not None: + X = adata.layers[layer] + else: + try: + X = adata.layers["X_pert"] + except KeyError: + raise KeyError("No 'X_pert' found in .layers! Please run perturbation_signature first.") from None + + # Initialize scores to 0 (NT control default) + adata.obs[new_class_name] = 0.0 + + for split, split_mask in enumerate(split_masks): + category = categories[split] + gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control])) + nt_cells = (adata.obs[pert_key] == control) & split_mask + + for gene in gene_targets: + guide_cells = (adata.obs[pert_key] == gene) & split_mask + all_cells = guide_cells | nt_cells + + if len(perturbation_markers[(category, gene)]) == 0: + continue + + de_genes = perturbation_markers[(category, gene)] + # Limit to max_de_genes + if len(de_genes) > max_de_genes: + de_genes = de_genes[:max_de_genes] + + de_genes_indices = np.where(np.isin(adata.var_names, list(de_genes)))[0] + + if len(de_genes_indices) == 0: + continue + + # Subset to DE genes for all relevant cells + dat = X[np.asarray(all_cells)][:, de_genes_indices] + if scale: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="zero-centering a sparse array/matrix densifies it.", + ) + dat = sc.pp.scale(dat) + + # Compute indices within the subsetted data + nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index) + guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index) + + # Compute perturbation direction vector + # (mean of perturbed cells minus mean of control cells) + guide_cells_mean = np.mean(dat[guide_cells_dat_idx], axis=0) + nt_cells_mean = np.mean(dat[nt_cells_dat_idx], axis=0) + vec = guide_cells_mean - nt_cells_mean + + # Scalar projection onto the perturbation direction + vec_norm_sq = np.dot(vec, vec) + if vec_norm_sq == 0: + continue + + pvec = dat.dot(vec) / vec_norm_sq if isinstance(dat, spmatrix) else np.dot(dat, vec) / vec_norm_sq + pvec = np.asarray(pvec).flatten() + + # Extract scores for guide and NT cells + guide_scores = pvec[guide_cells_dat_idx] + nt_scores = pvec[nt_cells_dat_idx] + + # Z-score standardization relative to NT controls + nt_mean = np.mean(nt_scores) + nt_std = np.std(nt_scores) + if nt_std == 0: + nt_std = 1.0 + + standardized_scores = (guide_scores - nt_mean) / nt_std + + # Store scores for perturbed cells + guide_cell_indices = guide_cells[guide_cells].index + adata.obs.loc[guide_cell_indices, new_class_name] = standardized_scores + + if copy: + return adata + def lda( self, adata: AnnData, diff --git a/tests/tools/test_mixscale.py b/tests/tools/test_mixscale.py new file mode 100644 index 00000000..3f7a51cc --- /dev/null +++ b/tests/tools/test_mixscale.py @@ -0,0 +1,157 @@ +"""Tests for Mixscape.mixscale continuous perturbation scoring.""" + +import numpy as np +import pytest +import scanpy as sc +from anndata import AnnData +from scipy.sparse import csr_matrix + +import pertpy as pt + + +@pytest.fixture +def synthetic_perturbation_adata(): + """Create synthetic perturbation data with known strong/weak effects.""" + np.random.seed(42) + n_genes = 200 + + # 100 NT controls, 100 strong KO, 100 weak KO for GeneA + # 50 cells for GeneB (moderate effect) + n_cells = 350 + + X = np.random.randn(n_cells, n_genes).astype(np.float32) + + # GeneA strong KO: large effect on first 20 genes + X[100:200, :20] -= 3.0 + # GeneA weak KO: small effect on first 20 genes + X[200:300, :20] -= 1.0 + # GeneB moderate: moderate effect on genes 20-40 + X[300:350, 20:40] -= 2.0 + + adata = AnnData(X=X) + adata.var_names = [f"Gene_{i}" for i in range(n_genes)] + adata.obs_names = [f"Cell_{i}" for i in range(n_cells)] + + labels = ["NT"] * 100 + ["GeneA"] * 100 + ["GeneA"] * 100 + ["GeneB"] * 50 + adata.obs["gene_target"] = labels + adata.obs["perturbation"] = ["NT" if x == "NT" else "targeting" for x in labels] + + sc.pp.pca(adata) + + return adata + + +class TestMixscale: + """Tests for the mixscale method.""" + + def test_basic_scoring(self, synthetic_perturbation_adata): + """Test that mixscale runs and produces scores.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + assert "mixscale_score" in adata.obs.columns + assert adata.obs["mixscale_score"].dtype == float + + def test_control_cells_score_zero(self, synthetic_perturbation_adata): + """Control cells should have score 0.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + nt_scores = adata.obs.loc[adata.obs["gene_target"] == "NT", "mixscale_score"] + assert (nt_scores == 0).all() + + def test_perturbed_cells_nonzero(self, synthetic_perturbation_adata): + """Perturbed cells should have non-zero scores.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + ko_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneA", "mixscale_score"] + assert ko_scores.abs().mean() > 0 + + def test_strong_vs_weak_perturbation(self, synthetic_perturbation_adata): + """Strongly perturbed cells should have higher absolute scores.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + scores = adata.obs["mixscale_score"].values + # Cells 100-199 are strong KO, 200-299 are weak KO + strong_mean = np.abs(scores[100:200]).mean() + weak_mean = np.abs(scores[200:300]).mean() + + assert strong_mean > weak_mean, ( + f"Strong KO mean ({strong_mean:.2f}) should exceed weak KO mean ({weak_mean:.2f})" + ) + + def test_custom_column_name(self, synthetic_perturbation_adata): + """Test custom output column name.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale( + adata, + "gene_target", + "NT", + layer="X_pert", + new_class_name="my_score", + ) + + assert "my_score" in adata.obs.columns + assert "mixscale_score" not in adata.obs.columns + + def test_copy_mode(self, synthetic_perturbation_adata): + """Test that copy=True returns a new object.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + result = ms.mixscale( + adata, + "gene_target", + "NT", + layer="X_pert", + copy=True, + ) + + assert result is not None + assert result is not adata + assert "mixscale_score" in result.obs.columns + + def test_no_perturbation_signature_raises(self, synthetic_perturbation_adata): + """Should raise KeyError if perturbation_signature hasn't been run.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + + with pytest.raises(KeyError, match="X_pert"): + ms.mixscale(adata, "gene_target", "NT") + + def test_multiple_perturbations(self, synthetic_perturbation_adata): + """Test scoring with multiple perturbation groups.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + # Both GeneA and GeneB should have scores + gene_a_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneA", "mixscale_score"] + gene_b_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneB", "mixscale_score"] + + assert gene_a_scores.abs().mean() > 0 + assert gene_b_scores.abs().mean() > 0 + + def test_sparse_input(self, synthetic_perturbation_adata): + """Test that mixscale works with sparse matrices.""" + adata = synthetic_perturbation_adata + adata.X = csr_matrix(adata.X) + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + assert "mixscale_score" in adata.obs.columns + assert not np.isnan(adata.obs.loc[adata.obs["gene_target"] != "NT", "mixscale_score"]).any()