Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/scembed/methods/gpu_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,7 @@ def __init__(
batch_size: int | None = None,
gene_likelihood: str | None = None,
check_val_every_n_epoch: int | None = None,
sample_key: str | None = None,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Behavior break vs. main. Previously sample_key was implicitly self.batch_key; now it defaults to None, which propagates through to scvi.external.SCVIVA.setup_anndata(sample_key=None) and raises KeyError: None (this is what's failing in CI for all three test_scviva_method parametrizations and test_embedding_retrieval[method_config4]).

Two reasonable options:

  1. Preserve old default — keep the kwarg, but fall back to batch_key when None:

    self.sample_key = sample_key if sample_key is not None else batch_key

    This way existing users keep working; new users can override.

  2. Make sample_key required-ish — keep None default but raise a clear error in fit() / setup() if it's still None at use time, and update the existing test_scviva_method to pass sample_key="batch" explicitly. Document the change as breaking in the PR body / CHANGELOG.

Either is fine, but the current state silently breaks every existing caller.

**kwargs,
):
"""
Expand Down Expand Up @@ -979,6 +980,8 @@ def __init__(
Likelihood, nb, zinb or poisson, see scvi docs.
check_val_every_n_epoch
Check validation loss every n epochs.
sample_key
Key in adata.obs indicating different slices/sections.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring should also state what None means and how this differs from batch_key. Suggested:

sample_key
    Key in ``adata.obs`` indicating different physical slices/sections (analogous to
    ``slice_key`` in ResolVI). Distinct from ``batch_key``, which can encode any
    experimental batch dimension. If ``None``, ... [whatever the chosen semantic is].

"""
super().__init__(
adata,
Expand All @@ -998,6 +1001,7 @@ def __init__(
batch_size=batch_size,
gene_likelihood=gene_likelihood,
check_val_every_n_epoch=check_val_every_n_epoch,
sample_key=sample_key,
**kwargs,
)

Expand All @@ -1018,6 +1022,7 @@ def __init__(
self.batch_size = batch_size
self.gene_likelihood = gene_likelihood
self.check_val_every_n_epoch = check_val_every_n_epoch
self.sample_key = sample_key

# Initialize models
self.embedding_model = None
Expand All @@ -1035,9 +1040,11 @@ def setup(self, expression_embedding_key: str, force_recompute: bool = False) ->
adata_prepared = self._prepare_hvg()

# Prepare preprocessing parameters, filtering out None values for k_nn only
# preprocessing_params are passed to preprocessing_anndata and setup_anndata
# as described in the tutorial

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per REVIEW_GUIDE.md (changed-path test lookup): methods/gpu_methods.pytests/methods/test_gpu_methods.py. The existing test_scviva_method doesn't cover the new sample_key parameter at all (and currently fails because of the new default — see line 937).

Once the default behavior is settled, please add a parametrize row exercising sample_key distinct from batch_key so the new functionality has a regression test. The existing fixture spatial_data should already have a usable obs column; reuse it rather than building a parallel fixture.

preprocessing_params = {
"adata": adata_prepared,
"sample_key": self.batch_key,
"sample_key": self.sample_key,
"labels_key": self.cell_type_key,
"cell_coordinates_key": self.spatial_key,
"expression_embedding_key": expression_embedding_key,
Expand Down Expand Up @@ -1125,7 +1132,7 @@ def fit(self):
adata_hvg,
layer=self.counts_layer,
batch_key=self.batch_key,
sample_key=self.batch_key,
sample_key=self.sample_key, # like slice_key in ResolVI

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two small things on this line:

  1. The # like slice_key in ResolVI comment is more useful in the docstring than at the call site — that's where users will look. Consider promoting it.
  2. With the current None default, this call hands None to setup_anndata, which is exactly what blows up in CI. Whatever fix is chosen for the default (see comment on line 937), the same value should also flow into preprocessing_anndata at line 1046 — they need to stay in sync.

labels_key=self.cell_type_key,
cell_coordinates_key=self.spatial_key,
expression_embedding_key=expression_embedding_key,
Expand Down
Loading