Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions pertpy/tools/_mixscape.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,188 @@ def mixscape(
if copy:
return adata

def mixscale(
    self,
    adata: AnnData,
    pert_key: str,
    control: str,
    *,
    new_class_name: str = "mixscale_score",
    layer: str | None = None,
    min_de_genes: int = 5,
    max_de_genes: int = 100,
    logfc_threshold: float = 0.25,
    de_layer: str | None = None,
    test_method: str = "wilcoxon",
    scale: bool = True,
    split_by: str | None = None,
    pval_cutoff: float = 5e-2,
    perturbation_type: str = "KO",
    copy: bool = False,
) -> AnnData | None:
    """Calculate continuous perturbation scores using the Mixscale method.

    Unlike :meth:`~pertpy.tools.Mixscape.mixscape`, which performs binary KO/NP classification via Gaussian
    Mixture Models, this assigns each perturbed cell a continuous score: the scalar projection of the cell's
    perturbation signature onto the perturbation direction (mean of perturbed cells minus mean of control cells
    over the DE genes), standardized against the control cells.
    This is useful for CRISPRi/CRISPRa screens where cells show a gradient of responses.

    The implementation follows Jiang, Dalgarno et al., "Systematic reconstruction of molecular pathway signatures
    using scalable single-cell perturbation screens", Nature Cell Biology (2025).

    Args:
        adata: The annotated data object.
        pert_key: The column of `.obs` with target gene labels.
        control: Control category from the `pert_key` column.
        new_class_name: Name of the score column to be stored in `.obs`.
        layer: Key from `adata.layers` whose value will be used for scoring.
            Default is using `.layers["X_pert"]`.
        min_de_genes: Required number of DE genes for scoring a perturbation.
        max_de_genes: Maximum number of DE genes to use for scoring.
        logfc_threshold: Minimum log fold-change threshold for DE gene selection.
        de_layer: Layer to use for identifying differentially expressed genes.
            If `None`, `adata.X` is used.
        test_method: Method to use for differential expression testing.
        scale: Whether to scale the perturbation data before computing scores.
        split_by: Provide `.obs` column with experimental condition/cell type annotation,
            if perturbations are condition/cell type-specific.
        pval_cutoff: P-value cut-off for selection of significantly DE genes.
        perturbation_type: Type of CRISPR perturbation. Currently not used by the scoring.
        copy: Determines whether a copy of the `adata` is returned.

    Returns:
        If `copy=True`, returns the copy of `adata` with the scores in `.obs`.
        Otherwise, writes the scores directly to `.obs` of the provided `adata` and returns `None`.

        `adata.obs[new_class_name]` holds the z-score of each perturbed cell's projection relative to the
        control cells of the same split. Control cells, and cells of perturbations without usable DE genes
        or with a zero perturbation direction, keep a score of 0.

    Raises:
        KeyError: If `layer` is `None` and `adata.layers` has no `"X_pert"` entry.

    Examples:
        >>> import pertpy as pt
        >>> mdata = pt.dt.papalexi_2021()
        >>> ms_pt = pt.tl.Mixscape()
        >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate")
        >>> ms_pt.mixscale(mdata["rna"], "gene_target", "NT", layer="X_pert")
    """
    if copy:
        adata = adata.copy()

    # Resolve the signature matrix first so a missing layer fails before the costly DE step.
    if layer is not None:
        X = adata.layers[layer]
    else:
        try:
            X = adata.layers["X_pert"]
        except KeyError:
            raise KeyError("No 'X_pert' found in .layers! Please run perturbation_signature first.") from None

    if split_by is None:
        split_masks = [np.full(adata.n_obs, True, dtype=bool)]
        categories = ["all"]
    else:
        split_obs = adata.obs[split_by]
        categories = split_obs.unique()
        split_masks = [split_obs == category for category in categories]

    perturbation_markers = self._get_perturbation_markers(
        adata=adata,
        split_masks=split_masks,
        categories=categories,
        pert_key=pert_key,
        control=control,
        layer=de_layer,
        pval_cutoff=pval_cutoff,
        min_de_genes=min_de_genes,
        logfc_threshold=logfc_threshold,
        test_method=test_method,
    )

    adata.obs[new_class_name] = 0.0

    pert_labels = adata.obs[pert_key]
    is_control = np.asarray(pert_labels == control)

    for category, split_mask in zip(categories, split_masks, strict=True):
        in_split = np.asarray(split_mask, dtype=bool)
        nt_cells = is_control & in_split
        gene_targets = set(pert_labels[in_split]).difference([control])

        for gene in gene_targets:
            de_genes = perturbation_markers[(category, gene)]
            if len(de_genes) == 0:
                continue

            de_genes = de_genes[:max_de_genes]
            de_genes_indices = np.flatnonzero(np.isin(adata.var_names, list(de_genes)))
            if len(de_genes_indices) == 0:
                continue

            guide_cells = np.asarray(pert_labels == gene) & in_split
            all_cells = guide_cells | nt_cells

            dat = X[all_cells][:, de_genes_indices]
            if scale:
                with warnings.catch_warnings():
                    warnings.filterwarnings(
                        "ignore",
                        message="zero-centering a sparse array/matrix densifies it.",
                    )
                    dat = sc.pp.scale(dat)

            # Boolean masks relative to the rows of `dat`; row order follows `adata.obs`,
            # so this does not rely on unique observation names.
            guide_in_dat = guide_cells[all_cells]
            nt_in_dat = nt_cells[all_cells]

            # Sparse inputs return 2-D means; flatten so the direction is always a 1-D vector.
            guide_cells_mean = np.asarray(dat[guide_in_dat].mean(axis=0)).ravel()
            nt_cells_mean = np.asarray(dat[nt_in_dat].mean(axis=0)).ravel()
            vec = guide_cells_mean - nt_cells_mean

            vec_norm_sq = float(np.dot(vec, vec))
            if vec_norm_sq == 0:
                continue

            # `@` works for dense arrays as well as scipy sparse matrices and sparse arrays.
            pvec = np.asarray(dat @ vec).ravel() / vec_norm_sq

            nt_scores = pvec[nt_in_dat]
            nt_std = nt_scores.std()
            if nt_std == 0:
                nt_std = 1.0

            adata.obs.loc[guide_cells, new_class_name] = (pvec[guide_in_dat] - nt_scores.mean()) / nt_std

    return adata if copy else None

def lda(
self,
adata: AnnData,
Expand Down
157 changes: 157 additions & 0 deletions tests/tools/test_mixscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""Tests for Mixscape.mixscale continuous perturbation scoring."""

import numpy as np
import pytest
import scanpy as sc
from anndata import AnnData
from scipy.sparse import csr_matrix

import pertpy as pt


@pytest.fixture
def synthetic_perturbation_adata():
    """Build a small AnnData with NT controls and two perturbations of known effect size.

    Layout of the 350 cells (200 genes each):
    cells 0-99 are NT controls, 100-199 GeneA with a strong shift, 200-299 GeneA with a weak shift
    (both on genes 0-19), and 300-349 GeneB with a moderate shift on genes 20-39.
    """
    # NOTE(review): maintainers asked whether the fixture used for the mixscape tests can be reused here.
    np.random.seed(42)

    n_genes = 200
    n_effect_genes = 20
    n_control, n_strong, n_weak, n_moderate = 100, 100, 100, 50
    n_cells = n_control + n_strong + n_weak + n_moderate

    strong_start = n_control
    weak_start = strong_start + n_strong
    moderate_start = weak_start + n_weak

    expression = np.random.randn(n_cells, n_genes).astype(np.float32)
    expression[strong_start:weak_start, :n_effect_genes] -= 3.0
    expression[weak_start:moderate_start, :n_effect_genes] -= 1.0
    expression[moderate_start:n_cells, n_effect_genes : 2 * n_effect_genes] -= 2.0

    adata = AnnData(X=expression)
    adata.var_names = [f"Gene_{gene_idx}" for gene_idx in range(n_genes)]
    adata.obs_names = [f"Cell_{cell_idx}" for cell_idx in range(n_cells)]

    targets = ["NT"] * n_control + ["GeneA"] * (n_strong + n_weak) + ["GeneB"] * n_moderate
    adata.obs["gene_target"] = targets
    adata.obs["perturbation"] = ["NT" if target == "NT" else "targeting" for target in targets]

    sc.pp.pca(adata)

    return adata


class TestMixscale:
    """Tests for the mixscale method."""

    # NOTE(review): maintainers prefer module-level test functions over test classes;
    # the class is kept here so the existing test identifiers stay stable.

    @staticmethod
    def _run_mixscale(adata, **kwargs):
        """Compute the perturbation signature on `adata`, then score it with mixscale.

        Returns whatever `mixscale` returns, so `copy=True` can be checked by the caller.
        """
        ms = pt.tl.Mixscape()
        ms.perturbation_signature(adata, "gene_target", "NT")
        return ms.mixscale(adata, "gene_target", "NT", layer="X_pert", **kwargs)

    def test_basic_scoring(self, synthetic_perturbation_adata):
        """Mixscale runs and writes a float score column."""
        adata = synthetic_perturbation_adata
        self._run_mixscale(adata)

        assert "mixscale_score" in adata.obs.columns
        assert adata.obs["mixscale_score"].dtype == float

    def test_control_cells_score_zero(self, synthetic_perturbation_adata):
        """Control cells should have score 0."""
        adata = synthetic_perturbation_adata
        self._run_mixscale(adata)

        nt_scores = adata.obs.loc[adata.obs["gene_target"] == "NT", "mixscale_score"]
        assert (nt_scores == 0).all()

    def test_perturbed_cells_nonzero(self, synthetic_perturbation_adata):
        """Perturbed cells should have non-zero scores."""
        adata = synthetic_perturbation_adata
        self._run_mixscale(adata)

        ko_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneA", "mixscale_score"]
        assert ko_scores.abs().mean() > 0

    def test_strong_vs_weak_perturbation(self, synthetic_perturbation_adata):
        """Strongly perturbed cells should have higher absolute scores."""
        adata = synthetic_perturbation_adata
        self._run_mixscale(adata)

        scores = adata.obs["mixscale_score"].values
        # The fixture places the strong GeneA cells at 100-199 and the weak ones at 200-299.
        strong_mean = np.abs(scores[100:200]).mean()
        weak_mean = np.abs(scores[200:300]).mean()

        assert strong_mean > weak_mean, (
            f"Strong KO mean ({strong_mean:.2f}) should exceed weak KO mean ({weak_mean:.2f})"
        )

    def test_custom_column_name(self, synthetic_perturbation_adata):
        """Scores are written to the requested column only."""
        adata = synthetic_perturbation_adata
        self._run_mixscale(adata, new_class_name="my_score")

        assert "my_score" in adata.obs.columns
        assert "mixscale_score" not in adata.obs.columns

    def test_copy_mode(self, synthetic_perturbation_adata):
        """With copy=True a new object is returned and the input is left untouched."""
        adata = synthetic_perturbation_adata
        result = self._run_mixscale(adata, copy=True)

        assert result is not None
        assert result is not adata
        assert "mixscale_score" in result.obs.columns
        assert "mixscale_score" not in adata.obs.columns

    def test_no_perturbation_signature_raises(self, synthetic_perturbation_adata):
        """Should raise KeyError if perturbation_signature hasn't been run."""
        adata = synthetic_perturbation_adata
        ms = pt.tl.Mixscape()

        with pytest.raises(KeyError, match="X_pert"):
            ms.mixscale(adata, "gene_target", "NT")

    def test_multiple_perturbations(self, synthetic_perturbation_adata):
        """Every perturbation group receives scores."""
        adata = synthetic_perturbation_adata
        self._run_mixscale(adata)

        gene_a_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneA", "mixscale_score"]
        gene_b_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneB", "mixscale_score"]

        assert gene_a_scores.abs().mean() > 0
        assert gene_b_scores.abs().mean() > 0

    @pytest.mark.parametrize("to_array_type", [np.asarray, csr_matrix], ids=["dense", "sparse"])
    def test_array_types(self, synthetic_perturbation_adata, to_array_type):
        """Mixscale yields finite scores for dense and sparse expression matrices."""
        adata = synthetic_perturbation_adata
        adata.X = to_array_type(adata.X)
        self._run_mixscale(adata)

        assert "mixscale_score" in adata.obs.columns
        assert not np.isnan(adata.obs.loc[adata.obs["gene_target"] != "NT", "mixscale_score"]).any()
Loading