-
Notifications
You must be signed in to change notification settings - Fork 55
feat: add mixscale continuous perturbation scoring #945
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -390,6 +390,188 @@ def mixscape( | |
| if copy: | ||
| return adata | ||
|
|
||
| def mixscale( | ||
| self, | ||
| adata: AnnData, | ||
| pert_key: str, | ||
| control: str, | ||
| *, | ||
| new_class_name: str = "mixscale_score", | ||
| layer: str | None = None, | ||
| min_de_genes: int = 5, | ||
| max_de_genes: int = 100, | ||
| logfc_threshold: float = 0.25, | ||
| de_layer: str | None = None, | ||
| test_method: str = "wilcoxon", | ||
| scale: bool = True, | ||
| split_by: str | None = None, | ||
| pval_cutoff: float = 5e-2, | ||
| perturbation_type: str = "KO", | ||
| copy: bool = False, | ||
| ): | ||
| """Calculate continuous perturbation scores using the Mixscale method. | ||
|
|
||
| Unlike :meth:`mixscape` which performs binary KO/NP classification via | ||
| Gaussian Mixture Models, this method assigns a continuous perturbation | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please format these docstrings like the others are formatted. This is LLM generated. |
||
| efficiency score to each cell. The score is the scalar projection of | ||
| each cell's perturbation signature onto the estimated perturbation | ||
| direction vector, standardized relative to non-targeting controls. | ||
|
|
||
| This is particularly useful for CRISPRi/CRISPRa screens where cells | ||
| exhibit a gradient of perturbation responses rather than binary | ||
| knockouts. | ||
|
|
||
| The implementation follows Jiang, Dalgarno et al., "Systematic | ||
| reconstruction of molecular pathway signatures using scalable | ||
| single-cell perturbation screens", Nature Cell Biology (2025). | ||
|
|
||
| Args: | ||
| adata: The annotated data object. | ||
| pert_key: The column of `.obs` with target gene labels. | ||
| control: Control category from the `pert_key` column. | ||
| new_class_name: Name of the score column to be stored in `.obs`. | ||
| layer: Key from `adata.layers` whose value will be used for scoring. | ||
| Default is using `.layers["X_pert"]`. | ||
| min_de_genes: Required number of DE genes for scoring a perturbation. | ||
| Perturbations with fewer DE genes are skipped. | ||
| max_de_genes: Maximum number of DE genes to use for scoring. | ||
| logfc_threshold: Minimum log fold-change threshold for DE gene selection. | ||
| de_layer: Layer to use for identifying differentially expressed genes. | ||
| If `None`, `adata.X` is used. | ||
| test_method: Method to use for differential expression testing. | ||
| scale: Whether to scale the perturbation data before computing scores. | ||
| split_by: Provide `.obs` column with experimental condition/cell type | ||
| annotation, if perturbations are condition/cell type-specific. | ||
| pval_cutoff: P-value cut-off for selection of significantly DE genes. | ||
| perturbation_type: Type of CRISPR perturbation for labeling. | ||
| copy: Determines whether a copy of the `adata` is returned. | ||
|
|
||
| Returns: | ||
| If `copy=True`, returns the copy of `adata` with the scores in `.obs`. | ||
| Otherwise, writes the scores directly to `.obs` of the provided `adata`. | ||
|
|
||
| The following fields are added to `adata.obs`: | ||
|
|
||
| - `adata.obs[new_class_name]`: Continuous perturbation score per cell. | ||
| Higher absolute values indicate stronger perturbation effect. | ||
| Non-targeting control cells receive a score of 0. | ||
| Scores are z-score standardized relative to the control distribution. | ||
|
|
||
| Examples: | ||
| Compute continuous perturbation scores: | ||
|
|
||
| >>> import pertpy as pt | ||
| >>> mdata = pt.dt.papalexi_2021() | ||
| >>> ms_pt = pt.tl.Mixscape() | ||
| >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") | ||
| >>> ms_pt.mixscale(mdata["rna"], "gene_target", "NT", layer="X_pert") | ||
| """ | ||
| if copy: | ||
| adata = adata.copy() | ||
|
|
||
| if split_by is None: | ||
| split_masks = [np.full(adata.n_obs, True, dtype=bool)] | ||
| categories = ["all"] | ||
| else: | ||
| split_obs = adata.obs[split_by] | ||
| categories = split_obs.unique() | ||
| split_masks = [split_obs == category for category in categories] | ||
|
|
||
| # Reuse the existing DE gene detection pipeline | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this is LLM noise. Please remove it |
||
| perturbation_markers = self._get_perturbation_markers( | ||
| adata=adata, | ||
| split_masks=split_masks, | ||
| categories=categories, | ||
| pert_key=pert_key, | ||
| control=control, | ||
| layer=de_layer, | ||
| pval_cutoff=pval_cutoff, | ||
| min_de_genes=min_de_genes, | ||
| logfc_threshold=logfc_threshold, | ||
| test_method=test_method, | ||
| ) | ||
|
|
||
| # Get perturbation signature matrix | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A bit of a noisy comment. |
||
| if layer is not None: | ||
| X = adata.layers[layer] | ||
| else: | ||
| try: | ||
| X = adata.layers["X_pert"] | ||
| except KeyError: | ||
| raise KeyError("No 'X_pert' found in .layers! Please run perturbation_signature first.") from None | ||
|
|
||
| # Initialize scores to 0 (NT control default) | ||
| adata.obs[new_class_name] = 0.0 | ||
|
|
||
| for split, split_mask in enumerate(split_masks): | ||
| category = categories[split] | ||
| gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control])) | ||
| nt_cells = (adata.obs[pert_key] == control) & split_mask | ||
|
|
||
| for gene in gene_targets: | ||
| guide_cells = (adata.obs[pert_key] == gene) & split_mask | ||
| all_cells = guide_cells | nt_cells | ||
|
|
||
| if len(perturbation_markers[(category, gene)]) == 0: | ||
| continue | ||
|
|
||
| de_genes = perturbation_markers[(category, gene)] | ||
| # Limit to max_de_genes | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LLM noise? |
||
| if len(de_genes) > max_de_genes: | ||
| de_genes = de_genes[:max_de_genes] | ||
|
|
||
| de_genes_indices = np.where(np.isin(adata.var_names, list(de_genes)))[0] | ||
|
|
||
| if len(de_genes_indices) == 0: | ||
| continue | ||
|
|
||
| # Subset to DE genes for all relevant cells | ||
| dat = X[np.asarray(all_cells)][:, de_genes_indices] | ||
| if scale: | ||
| with warnings.catch_warnings(): | ||
| warnings.filterwarnings( | ||
| "ignore", | ||
| message="zero-centering a sparse array/matrix densifies it.", | ||
| ) | ||
| dat = sc.pp.scale(dat) | ||
|
|
||
| # Compute indices within the subsetted data | ||
| nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index) | ||
| guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index) | ||
|
|
||
| # Compute perturbation direction vector | ||
| # (mean of perturbed cells minus mean of control cells) | ||
| guide_cells_mean = np.mean(dat[guide_cells_dat_idx], axis=0) | ||
| nt_cells_mean = np.mean(dat[nt_cells_dat_idx], axis=0) | ||
| vec = guide_cells_mean - nt_cells_mean | ||
|
|
||
| # Scalar projection onto the perturbation direction | ||
| vec_norm_sq = np.dot(vec, vec) | ||
| if vec_norm_sq == 0: | ||
| continue | ||
|
|
||
| pvec = dat.dot(vec) / vec_norm_sq if isinstance(dat, spmatrix) else np.dot(dat, vec) / vec_norm_sq | ||
| pvec = np.asarray(pvec).flatten() | ||
|
|
||
| # Extract scores for guide and NT cells | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LLM noise? |
||
| guide_scores = pvec[guide_cells_dat_idx] | ||
| nt_scores = pvec[nt_cells_dat_idx] | ||
|
|
||
| # Z-score standardization relative to NT controls | ||
| nt_mean = np.mean(nt_scores) | ||
| nt_std = np.std(nt_scores) | ||
| if nt_std == 0: | ||
| nt_std = 1.0 | ||
|
|
||
| standardized_scores = (guide_scores - nt_mean) / nt_std | ||
|
|
||
| # Store scores for perturbed cells | ||
| guide_cell_indices = guide_cells[guide_cells].index | ||
| adata.obs.loc[guide_cell_indices, new_class_name] = standardized_scores | ||
|
|
||
| if copy: | ||
| return adata | ||
|
|
||
| def lda( | ||
| self, | ||
| adata: AnnData, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| """Tests for Mixscape.mixscale continuous perturbation scoring.""" | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| import scanpy as sc | ||
| from anndata import AnnData | ||
| from scipy.sparse import csr_matrix | ||
|
|
||
| import pertpy as pt | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def synthetic_perturbation_adata(): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you reuse the fixture that we're using to test mixscape? |
||
| """Create synthetic perturbation data with known strong/weak effects.""" | ||
| np.random.seed(42) | ||
| n_genes = 200 | ||
|
|
||
| # 100 NT controls, 100 strong KO, 100 weak KO for GeneA | ||
| # 50 cells for GeneB (moderate effect) | ||
| n_cells = 350 | ||
|
|
||
| X = np.random.randn(n_cells, n_genes).astype(np.float32) | ||
|
|
||
| # GeneA strong KO: large effect on first 20 genes | ||
| X[100:200, :20] -= 3.0 | ||
| # GeneA weak KO: small effect on first 20 genes | ||
| X[200:300, :20] -= 1.0 | ||
| # GeneB moderate: moderate effect on genes 20-40 | ||
| X[300:350, 20:40] -= 2.0 | ||
|
|
||
| adata = AnnData(X=X) | ||
| adata.var_names = [f"Gene_{i}" for i in range(n_genes)] | ||
| adata.obs_names = [f"Cell_{i}" for i in range(n_cells)] | ||
|
|
||
| labels = ["NT"] * 100 + ["GeneA"] * 100 + ["GeneA"] * 100 + ["GeneB"] * 50 | ||
| adata.obs["gene_target"] = labels | ||
| adata.obs["perturbation"] = ["NT" if x == "NT" else "targeting" for x in labels] | ||
|
|
||
| sc.pp.pca(adata) | ||
|
|
||
| return adata | ||
|
|
||
|
|
||
| class TestMixscale: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're not really using test classes in pertpy |
||
| """Tests for the mixscale method.""" | ||
|
|
||
| def test_basic_scoring(self, synthetic_perturbation_adata): | ||
| """Test that mixscale runs and produces scores.""" | ||
| adata = synthetic_perturbation_adata | ||
| ms = pt.tl.Mixscape() | ||
| ms.perturbation_signature(adata, "gene_target", "NT") | ||
| ms.mixscale(adata, "gene_target", "NT", layer="X_pert") | ||
|
|
||
| assert "mixscale_score" in adata.obs.columns | ||
| assert adata.obs["mixscale_score"].dtype == float | ||
|
|
||
| def test_control_cells_score_zero(self, synthetic_perturbation_adata): | ||
| """Control cells should have score 0.""" | ||
| adata = synthetic_perturbation_adata | ||
| ms = pt.tl.Mixscape() | ||
| ms.perturbation_signature(adata, "gene_target", "NT") | ||
| ms.mixscale(adata, "gene_target", "NT", layer="X_pert") | ||
|
|
||
| nt_scores = adata.obs.loc[adata.obs["gene_target"] == "NT", "mixscale_score"] | ||
| assert (nt_scores == 0).all() | ||
|
|
||
| def test_perturbed_cells_nonzero(self, synthetic_perturbation_adata): | ||
| """Perturbed cells should have non-zero scores.""" | ||
| adata = synthetic_perturbation_adata | ||
| ms = pt.tl.Mixscape() | ||
| ms.perturbation_signature(adata, "gene_target", "NT") | ||
| ms.mixscale(adata, "gene_target", "NT", layer="X_pert") | ||
|
|
||
| ko_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneA", "mixscale_score"] | ||
| assert ko_scores.abs().mean() > 0 | ||
|
|
||
| def test_strong_vs_weak_perturbation(self, synthetic_perturbation_adata): | ||
| """Strongly perturbed cells should have higher absolute scores.""" | ||
| adata = synthetic_perturbation_adata | ||
| ms = pt.tl.Mixscape() | ||
| ms.perturbation_signature(adata, "gene_target", "NT") | ||
| ms.mixscale(adata, "gene_target", "NT", layer="X_pert") | ||
|
|
||
| scores = adata.obs["mixscale_score"].values | ||
| # Cells 100-199 are strong KO, 200-299 are weak KO | ||
| strong_mean = np.abs(scores[100:200]).mean() | ||
| weak_mean = np.abs(scores[200:300]).mean() | ||
|
|
||
| assert strong_mean > weak_mean, ( | ||
| f"Strong KO mean ({strong_mean:.2f}) should exceed weak KO mean ({weak_mean:.2f})" | ||
| ) | ||
|
|
||
| def test_custom_column_name(self, synthetic_perturbation_adata): | ||
| """Test custom output column name.""" | ||
| adata = synthetic_perturbation_adata | ||
| ms = pt.tl.Mixscape() | ||
| ms.perturbation_signature(adata, "gene_target", "NT") | ||
| ms.mixscale( | ||
| adata, | ||
| "gene_target", | ||
| "NT", | ||
| layer="X_pert", | ||
| new_class_name="my_score", | ||
| ) | ||
|
|
||
| assert "my_score" in adata.obs.columns | ||
| assert "mixscale_score" not in adata.obs.columns | ||
|
|
||
| def test_copy_mode(self, synthetic_perturbation_adata): | ||
| """Test that copy=True returns a new object.""" | ||
| adata = synthetic_perturbation_adata | ||
| ms = pt.tl.Mixscape() | ||
| ms.perturbation_signature(adata, "gene_target", "NT") | ||
| result = ms.mixscale( | ||
| adata, | ||
| "gene_target", | ||
| "NT", | ||
| layer="X_pert", | ||
| copy=True, | ||
| ) | ||
|
|
||
| assert result is not None | ||
| assert result is not adata | ||
| assert "mixscale_score" in result.obs.columns | ||
|
|
||
| def test_no_perturbation_signature_raises(self, synthetic_perturbation_adata): | ||
| """Should raise KeyError if perturbation_signature hasn't been run.""" | ||
| adata = synthetic_perturbation_adata | ||
| ms = pt.tl.Mixscape() | ||
|
|
||
| with pytest.raises(KeyError, match="X_pert"): | ||
| ms.mixscale(adata, "gene_target", "NT") | ||
|
|
||
| def test_multiple_perturbations(self, synthetic_perturbation_adata): | ||
| """Test scoring with multiple perturbation groups.""" | ||
| adata = synthetic_perturbation_adata | ||
| ms = pt.tl.Mixscape() | ||
| ms.perturbation_signature(adata, "gene_target", "NT") | ||
| ms.mixscale(adata, "gene_target", "NT", layer="X_pert") | ||
|
|
||
| # Both GeneA and GeneB should have scores | ||
| gene_a_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneA", "mixscale_score"] | ||
| gene_b_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneB", "mixscale_score"] | ||
|
|
||
| assert gene_a_scores.abs().mean() > 0 | ||
| assert gene_b_scores.abs().mean() > 0 | ||
|
|
||
| def test_sparse_input(self, synthetic_perturbation_adata): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A bit of a useless test. But ideally this were parametrized for the different array types instead of being its own test. |
||
| """Test that mixscale works with sparse matrices.""" | ||
| adata = synthetic_perturbation_adata | ||
| adata.X = csr_matrix(adata.X) | ||
| ms = pt.tl.Mixscape() | ||
| ms.perturbation_signature(adata, "gene_target", "NT") | ||
| ms.mixscale(adata, "gene_target", "NT", layer="X_pert") | ||
|
|
||
| assert "mixscale_score" in adata.obs.columns | ||
| assert not np.isnan(adata.obs.loc[adata.obs["gene_target"] != "NT", "mixscale_score"]).any() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This sphinx ref doesn't work. Please fix it.