diff --git a/CHANGELOG.md b/CHANGELOG.md index 756c28c..201e4c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,13 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.1.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html +## [0.4.0] (unreleased) + +### Changed + +- `update()` no longer automatically pulls obs/var columns from individual modalities by default. Set `mudata.set_options(pull_on_update=True)` + to restore the old behavior. Use `pull_obs/pull_var` and `push_obs/push_var` for more flexibility. + ## [0.3.4] ### Added @@ -144,6 +151,7 @@ To copy the annotations explicitly, you will need to use `pull_obs()` and/or `pu Initial `mudata` release with `MuData`, previously a part of the `muon` framework. +[0.4.0]: https://github.com/scverse/mudata/releases/tag/v0.4.0 [0.3.4]: https://github.com/scverse/mudata/releases/tag/v0.3.4 [0.3.3]: https://github.com/scverse/mudata/releases/tag/v0.3.3 [0.3.2]: https://github.com/scverse/mudata/releases/tag/v0.3.2 diff --git a/src/mudata/_core/config.py b/src/mudata/_core/config.py index e42fa4d..6190952 100644 --- a/src/mudata/_core/config.py +++ b/src/mudata/_core/config.py @@ -1,6 +1,6 @@ import logging as log -OPTIONS = {"display_style": "text", "display_html_expand": 0b010, "pull_on_update": None} +OPTIONS = {"display_style": "text", "display_html_expand": 0b010, "pull_on_update": False} _VALID_OPTIONS = { "display_style": lambda x: x in ("text", "html"), diff --git a/src/mudata/_core/io.py b/src/mudata/_core/io.py index 40f7146..eb94bd7 100644 --- a/src/mudata/_core/io.py +++ b/src/mudata/_core/io.py @@ -28,7 +28,6 @@ from anndata.compat import _read_attr from scipy import sparse -from .config import OPTIONS from .file_backing import AnnDataFileManager, MuDataFileManager from .mudata import ModDict, MuData @@ -46,22 +45,8 @@ def _is_openfile(obj) -> bool: def _write_h5mu(file: h5py.File, mdata: MuData, write_data=True, **kwargs): from ..
import __anndataversion__, __mudataversion__, __version__ - write_elem( - file, - "obs", - mdata.strings_to_categoricals( - mdata._shrink_attr("obs", inplace=False).copy() if OPTIONS["pull_on_update"] is None else mdata.obs.copy() - ), - dataset_kwargs=kwargs, - ) - write_elem( - file, - "var", - mdata.strings_to_categoricals( - mdata._shrink_attr("var", inplace=False).copy() if OPTIONS["pull_on_update"] is None else mdata.var.copy() - ), - dataset_kwargs=kwargs, - ) + write_elem(file, "obs", mdata.strings_to_categoricals(mdata.obs.copy()), dataset_kwargs=kwargs) + write_elem(file, "var", mdata.strings_to_categoricals(mdata.var.copy()), dataset_kwargs=kwargs) write_elem(file, "obsm", dict(mdata.obsm), dataset_kwargs=kwargs) write_elem(file, "varm", dict(mdata.varm), dataset_kwargs=kwargs) write_elem(file, "obsp", dict(mdata.obsp), dataset_kwargs=kwargs) @@ -157,26 +142,8 @@ def write_zarr( # zarr_format is not supported in this version of zarr file = zarr.open(store, mode="w") mdata = data - write_elem( - file, - "obs", - mdata.strings_to_categoricals( - mdata._shrink_attr("obs", inplace=False).copy() - if OPTIONS["pull_on_update"] is None - else mdata.obs.copy() - ), - dataset_kwargs=kwargs, - ) - write_elem( - file, - "var", - mdata.strings_to_categoricals( - mdata._shrink_attr("var", inplace=False).copy() - if OPTIONS["pull_on_update"] is None - else mdata.var.copy() - ), - dataset_kwargs=kwargs, - ) + write_elem(file, "obs", mdata.strings_to_categoricals(mdata.obs.copy()), dataset_kwargs=kwargs) + write_elem(file, "var", mdata.strings_to_categoricals(mdata.var.copy()), dataset_kwargs=kwargs) write_elem(file, "obsm", dict(mdata.obsm), dataset_kwargs=kwargs) write_elem(file, "varm", dict(mdata.varm), dataset_kwargs=kwargs) write_elem(file, "obsp", dict(mdata.obsp), dataset_kwargs=kwargs) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 61aa398..3f9cde3 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -550,21 
+550,6 @@ def _update_attr( - are there intersecting obs_names/var_names between modalities? - have obs_names/var_names of modalities changed? """ - if OPTIONS["pull_on_update"] is None: - warnings.warn( - "From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. " - "Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. " - "Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.", - FutureWarning, - stacklevel=2, - ) - - join_common = False - if "join_common" in kwargs: - join_common = kwargs.pop("join_common") - self._update_attr_legacy(attr, axis, join_common, **kwargs) - return - # No _attrhash when upon read # No _attrhash in mudata < 0.2.0 _attrhash = f"_{attr}hash" @@ -626,7 +611,7 @@ def calc_attrm_update(): 0 ] # renamed (since new_idx.shape[0] > 0 and kept_idx.shape[0] < data_global.shape[0]) or ( - axis == self.axis and axis != -1 and data_mod.shape[0] > data_global.shape[0] + axis != self.axis and axis != -1 and data_mod.shape[0] > data_global.shape[0] ) # new modality added and concacenated ) @@ -636,9 +621,7 @@ def calc_attrm_update(): # Main case: no duplicates and no intersection if the axis is not shared if not attr_duplicated: # Shared axis - data_mod = pd.concat( - dfs, join="outer", axis=1 if axis == (1 - self._axis) or self._axis == -1 else 0, sort=False - ) + data_mod = pd.concat(dfs, join="outer", axis=1 if axis == self._axis or self._axis == -1 else 0, sort=False) for mod in self._mod.keys(): fix_attrmap_col(data_mod, mod, rowcol) @@ -660,9 +643,7 @@ def calc_attrm_update(): # else: dfs = [_make_index_unique(df, force=True) for df in dfs] - data_mod = pd.concat( - dfs, join="outer", axis=1 if axis == (1 - self._axis) or self._axis == -1 else 0, sort=False - ) + data_mod = pd.concat(dfs, join="outer", axis=1 if axis == self._axis or self._axis == -1 else 0, sort=False) data_mod = _restore_index(data_mod) 
data_mod.index.set_names(rowcol, inplace=True) @@ -775,413 +756,6 @@ def calc_attrm_update(): if OPTIONS["pull_on_update"]: self._pull_attr(attr, **kwargs) - def _update_attr_legacy( - self, - attr: str, - axis: int, - join_common: bool = False, - **kwargs, # for _pull_attr() - ): - """ - Update global observations/variables with observations/variables for each modality. - - This method will be removed in the next versions. See _update_attr() instead. - """ - prev_index = getattr(self, attr).index - - # No _attrhash when upon read - # No _attrhash in mudata < 0.2.0 - _attrhash = f"_{attr}hash" - attr_changed = self._check_changed_attr_names(attr, columns=True) - - attr_duplicated = self._check_duplicated_attr_names(attr) - attr_intersecting = self._check_intersecting_attr_names(attr) - - if attr_duplicated: - warnings.warn( - f"{attr}_names are not unique. To make them unique, call `.{attr}_names_make_unique`.", stacklevel=2 - ) - if self._axis == -1: - warnings.warn( - f"Behaviour is not defined with axis=-1, {attr}_names need to be made unique first.", stacklevel=2 - ) - - if not any(attr_changed): - # Nothing to update - return - - # Check if the are same obs_names/var_names in different modalities - # If there are, join_common=True request can not be satisfied - if join_common: - if attr_intersecting: - warnings.warn( - f"Cannot join columns with the same name because {attr}_names are intersecting.", stacklevel=2 - ) - join_common = False - - # Figure out which global columns exist - columns_global = getattr(self, attr).columns[ - list( - map( - all, - zip( - *[ - [ - not col.startswith(mod + ":") - or col[col.startswith(mod + ":") and len(mod + ":") :] - not in getattr(self._mod[mod], attr).columns - for col in getattr(self, attr).columns - ] - for mod in self._mod - ], - strict=False, - ), - ) - ) - ] - - # Keep data from global .obs/.var columns - data_global = getattr(self, attr).loc[:, columns_global] - - # Generate unique colnames - (rowcol,) = 
self._find_unique_colnames(attr, 1) - - attrm = getattr(self, attr + "m") - attrp = getattr(self, attr + "p") - attrmap = getattr(self, f"_{attr}map") - - if join_common: - # If all modalities have a column with the same name, it is not global - columns_common = reduce( - lambda a, b: a.intersection(b), [getattr(self._mod[mod], attr).columns for mod in self._mod] - ) - data_global = data_global.loc[:, [c not in columns_common for c in data_global.columns]] - - # TODO: take advantage when attr_changed[0] == False — only new columns to be added - - # - # Join modality .obs/.var tables - # - # Main case: no duplicates and no intersection if the axis is not shared - # - if not attr_duplicated: - # Shared axis - if axis == (1 - self._axis) or self._axis == -1: - # We assume attr_intersecting and can't join_common - data_mod = try_convert_dataframe_to_numpy_dtypes( - pd.concat( - [ - getattr(a, attr) - .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") - .convert_dtypes() - for m, a in self._mod.items() - ], - join="outer", - axis=1, - sort=False, - ) - ) - else: - if join_common: - # We checked above that attr_names are guaranteed to be unique and thus are safe to be used for joins - data_mod = pd.concat( - [ - getattr(a, attr) - .drop(columns_common, axis=1) - .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") - .convert_dtypes() - for m, a in self._mod.items() - ], - join="outer", - axis=0, - sort=False, - ) - data_common = pd.concat( - [getattr(a, attr)[columns_common].convert_dtypes() for m, a in self._mod.items()], - join="outer", - axis=0, - sort=False, - ) - - data_mod = try_convert_dataframe_to_numpy_dtypes(data_mod.join(data_common, how="left", sort=False)) - data_common = try_convert_dataframe_to_numpy_dtypes(data_common) - - # this occurs when join_common=True and we already have a global data frame, e.g. 
after reading from H5MU - sharedcols = data_mod.columns.intersection(data_global.columns) - data_global.rename(columns={col: f"global:{col}" for col in sharedcols}, inplace=True) - else: - data_mod = try_convert_dataframe_to_numpy_dtypes( - pd.concat( - [ - getattr(a, attr) - .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":") - .convert_dtypes() - for m, a in self._mod.items() - ], - join="outer", - axis=0, - sort=False, - ) - ) - - for mod in self._mod.keys(): - colname = mod + ":" + rowcol - # use 0 as special value for missing - # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write - # also, this is compatible to Muon.jl - col = data_mod[colname] + 1 - col.replace(np.nan, 0, inplace=True) - data_mod[colname] = col.astype(np.uint32) - - if len(data_global.columns) > 0: - # TODO: if there were intersecting attrnames between modalities, - # this will increase the size of the index - # Should we use attrmap to figure the index out? - # - if not attr_intersecting: - data_mod = data_mod.join(data_global, how="left", sort=False) - else: - # In order to preserve the order of the index, instead, - # perform a join based on (index, cumcount) pairs. 
- col_index, col_cumcount = self._find_unique_colnames(attr, 2) - data_mod = data_mod.rename_axis(col_index, axis=0).reset_index() - data_mod[col_cumcount] = data_mod.groupby(col_index).cumcount() - data_global = data_global.rename_axis(col_index, axis=0).reset_index() - data_global[col_cumcount] = data_global.reset_index().groupby(col_index).cumcount() - data_mod = data_mod.merge(data_global, on=[col_index, col_cumcount], how="left", sort=False) - # Restore the index and remove the helper column - data_mod = data_mod.set_index(col_index).rename_axis(None, axis=0) - del data_mod[col_cumcount] - data_global = data_global.set_index(col_index).rename_axis(None, axis=0) - del data_global[col_cumcount] - - # - # General case: with duplicates and/or intersections - # - else: - if join_common: - dfs = [ - _make_index_unique( - getattr(a, attr) - .drop(columns_common, axis=1) - .assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}) - .add_prefix(m + ":"), - force=True, - ).convert_dtypes() - for m, a in self._mod.items() - ] - - # Here, attr_names are guaranteed to be unique and are safe to be used for joins - data_mod = pd.concat(dfs, join="outer", axis=axis, sort=False) - - data_common = pd.concat( - [ - _make_index_unique(getattr(a, attr)[columns_common], force=True).convert_dtypes() - for m, a in self._mod.items() - ], - join="outer", - axis=0, - sort=False, - ) - - data_mod = try_convert_dataframe_to_numpy_dtypes(data_mod.join(data_common, how="left", sort=False)) - data_common = try_convert_dataframe_to_numpy_dtypes(data_common) - else: - dfs = [ - _make_index_unique( - getattr(a, attr).assign(**{rowcol: np.arange(getattr(a, attr).shape[0])}).add_prefix(m + ":"), - force=True, - ) - for m, a in self._mod.items() - ] - data_mod = pd.concat(dfs, join="outer", axis=axis, sort=False) - - # pd.concat wrecks the ordering when doing an outer join with a MultiIndex and different data frame shapes - if axis == 1: - newidx = ( - reduce(lambda x, y: x.union(y, sort=False), 
(df.index for df in dfs)) - .to_frame() - .reset_index(level=1, drop=True) - ) - globalidx = data_global.index.get_level_values(0) - mask = globalidx.isin(newidx.iloc[:, 0]) - if len(mask) > 0: - negativemask = ~newidx.index.get_level_values(0).isin(globalidx) - newidx = pd.MultiIndex.from_frame( - pd.concat([newidx.loc[globalidx[mask], :], newidx.iloc[negativemask, :]], axis=0) - ) - data_mod = data_mod.reindex(newidx, copy=False) - - # this occurs when join_common=True and we already have a global data frame, e.g. after reading from HDF5 - if join_common: - sharedcols = data_mod.columns.intersection(data_global.columns) - data_global.rename(columns={col: f"global:{col}" for col in sharedcols}, inplace=True) - - data_mod = _restore_index(data_mod) - data_mod.index.set_names(rowcol, inplace=True) - data_global.index.set_names(rowcol, inplace=True) - for mod, amod in self._mod.items(): - colname = mod + ":" + rowcol - # use 0 as special value for missing - # we could use a pandas.array, which has missing values support, but then we get an Exception upon hdf5 write - # also, this is compatible to Muon.jl - col = data_mod.loc[:, colname] + 1 - col.replace(np.nan, 0, inplace=True) - col = col.astype(np.uint32) - data_mod.loc[:, colname] = col - data_mod.set_index(colname, append=True, inplace=True) - if mod in attrmap and np.sum(attrmap[mod] > 0) == getattr(amod, attr).shape[0]: - data_global.set_index(attrmap[mod].ravel(), append=True, inplace=True) - data_global.index.set_names(colname, level=-1, inplace=True) - - if len(data_global) > 0: - if not data_global.index.is_unique: - warnings.warn( - f"{attr}_names is not unique, global {attr} is present, and {attr}map is empty. 
The update() is not well-defined, verify if global {attr} map to the correct modality-specific {attr}.", - stacklevel=2, - ) - data_mod.reset_index(data_mod.index.names.difference(data_global.index.names), inplace=True) - data_mod = _make_index_unique(data_mod, force=True) - data_global = _make_index_unique(data_global, force=True) - data_mod = data_mod.join(data_global, how="left", sort=False) - data_mod.reset_index(level=list(range(1, data_mod.index.nlevels)), inplace=True) - data_mod.index.set_names(None, inplace=True) - - if join_common: - for col in sharedcols: - gcol = f"global:{col}" - if data_mod[col].equals(data_mod[gcol]): - data_mod.drop(columns=gcol, inplace=True) - else: - warnings.warn( - f"Column {col} was present in {attr} but is also a common column in all modalities, and their contents differ. {attr}.{col} was renamed to {attr}.{gcol}.", - stacklevel=2, - ) - - # get adata positions and remove columns from the data frame - mdict = {} - for m in self._mod.keys(): - colname = m + ":" + rowcol - mdict[m] = data_mod[colname].to_numpy() - data_mod.drop(colname, axis=1, inplace=True) - - # Add data from global .obs/.var columns # This might reduce the size of .obs/.var if observations/variables were removed - setattr( - # Original index is present in data_global - self, - "_" + attr, - data_mod, - ) - - # Update .obsm/.varm - # this needs to be after setting _obs/_var due to dimension checking in the aligned mapping - attrmap.clear() - attrmap.update(mdict) - for mod, mapping in mdict.items(): - attrm[mod] = mapping > 0 - - now_index = getattr(self, attr).index - - if len(prev_index) == 0: - # New object - pass - elif now_index.equals(prev_index): - # Index is the same - pass - else: - keep_index = prev_index.isin(now_index) - new_index = ~now_index.isin(prev_index) - - if new_index.sum() == 0 or ( - keep_index.sum() + new_index.sum() == len(now_index) and len(now_index) > len(prev_index) - ): - # Another length (filtered) or new modality added - # 
Update .obsm/.varm (size might have changed) - # NOTE: .get_index doesn't work with duplicated indices - if any(prev_index.duplicated()): - # Assume the relative order of duplicates hasn't changed - # NOTE: .get_loc() for each element is too slow - # We will rename duplicated in prev_index and now_index - # in order to use .get_indexer - # index_order = [ - # prev_index.get_loc(i) if i in prev_index else -1 for i in now_index - # ] - prev_values = prev_index.values.copy() - now_values = now_index.values.copy() - for value in prev_index[np.where(prev_index.duplicated())[0]]: - v_now = np.where(now_index == value)[0] - v_prev = np.where(prev_index.get_loc(value))[0] - for i in range(min(len(v_now), len(v_prev))): - prev_values[v_prev[i]] = f"{str(value)}-{i}" - now_values[v_now[i]] = f"{str(value)}-{i}" - - prev_index = pd.Index(prev_values) - now_index = pd.Index(now_values) - - index_order = prev_index.get_indexer(now_index) - - for mx_key in attrm.keys(): - if mx_key not in self._mod.keys(): # not a modality name - attrm[mx_key] = attrm[mx_key][index_order] - attrm[mx_key][index_order == -1] = np.nan - - # Update .obsp/.varp (size might have changed) - for mx_key in attrp.keys(): - attrp[mx_key] = attrp[mx_key][index_order, :][:, index_order] - attrp[mx_key][index_order == -1, :] = -1 - attrp[mx_key][:, index_order == -1] = -1 - - elif len(now_index) == len(prev_index): - # Renamed since new_index.sum() != 0 - # We have to assume the order hasn't changed - pass - - else: - raise NotImplementedError( - f"{attr}_names seem to have been renamed and filtered at the same time. " - "There is no way to restore the order. 
MuData object has to be re-created from these modalities:\n" - " mdata1 = MuData(mdata.mod)" - ) - - # Write _attrhash - if attr_changed: - if not hasattr(self, _attrhash): - setattr(self, _attrhash, {}) - for m, mod in self._mod.items(): - getattr(self, _attrhash)[m] = ( - sha1(np.ascontiguousarray(getattr(mod, attr).index.values)).hexdigest(), - sha1(np.ascontiguousarray(getattr(mod, attr).columns.values)).hexdigest(), - ) - - def _shrink_attr(self, attr: str, inplace=True) -> pd.DataFrame: - """Remove observations/variables for each modality from the global observations/variables table.""" - # Figure out which global columns exist - columns_global = list( - map( - all, - zip( - *([not col.startswith(mod + ":") for col in getattr(self, attr).columns] for mod in self._mod), - strict=False, - ), - ) - ) - # Make sure modname-prefix columns exist in modalities, - # keep them in place if they don't - for mod in self._mod: - for i, col in enumerate(getattr(self, attr).columns): - if col.startswith(mod + ":"): - mcol = col[len(mod) + 1 :] - if mcol not in getattr(self._mod[mod], attr).columns: - columns_global[i] = True - # Only keep data from global .obs/.var columns - newdf = getattr(self, attr).loc[:, columns_global] - if inplace: - setattr(self, attr, newdf) - return newdf - @property def n_mod(self) -> int: """Number of modalities.""" @@ -1263,13 +837,9 @@ def obs_vector(self, key: str, layer: str | None = None) -> np.ndarray: return self._attr_vector(key, "obs") def update_obs(self): - """Update :attr:`obs` indices of the object with the data from all the modalities. - - .. note:: - From v0.4, it will not pull columns from modalities by default. 
- """ + """Update :attr:`obs` indices of the object with the data from all the modalities.""" join_common = self.axis == 1 - self._update_attr("obs", axis=1, join_common=join_common) + self._update_attr("obs", axis=0, join_common=join_common) def _names_make_unique(self, attr: Literal["obs", "var"]): axis = 0 if attr == "obs" else 1 @@ -1407,13 +977,9 @@ def var_vector(self, key: str, layer: str | None = None) -> np.ndarray: return self._attr_vector(key, "var") def update_var(self): - """Update :attr:`var` indices of the object with the data from all the modalities. - - .. note:: - From v0.4, it will not pull columns from modalities by default. - """ + """Update :attr:`var` indices of the object with the data from all the modalities.""" join_common = self.axis == 0 - self._update_attr("var", axis=0, join_common=join_common) + self._update_attr("var", axis=1, join_common=join_common) def var_names_make_unique(self): """ @@ -1587,11 +1153,7 @@ def uns_keys(self) -> list[str]: return list(self._uns.keys()) def update(self): - """Update both :attr:`obs` and :attr:`var` indices of the object with the data from all the modalities. - - .. note:: - From v0.4, it will not pull columns from modalities by default. 
- """ + """Update both :attr:`obs` and :attr:`var` indices of the object with the data from all the modalities.""" if len(self._mod) > 0: self.update_var() self.update_obs() diff --git a/tests/test_io.py b/tests/test_io.py index fa979cf..3948bab 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -78,13 +78,12 @@ def test_write_read_mod_obs_colname( mdata.obs["column"] = 0 mdata.obs["mod1:column"] = 1 mdata["mod1"].obs["column"] = 2 - mdata.update() getattr(mdata, write_func)(filepath) mdata_ = getattr(md, read_func)(filepath) assert "column" in mdata_.obs.columns assert "mod1:column" in mdata_.obs.columns - # 2 should supercede 1 on .update() - assert mdata_.obs["mod1:column"].values[0] == 2 + # 2 should not overwrite 1 on .update() + assert mdata_.obs["mod1:column"].values[0] == 1 def test_h5mu_backed(mdata: md.MuData, filepath_h5mu: str | Path, filepath2_h5mu: str | Path): diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index c0155dd..a0970f2 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -8,20 +8,24 @@ @pytest.mark.parametrize("mdata", (0, 1), indirect=True) -def test_obs_global_columns(mdata: md.MuData, filepath_h5mu: str | Path): +@pytest.mark.parametrize("pull_on_update", (False, True)) +def test_obs_global_columns(mdata: md.MuData, pull_on_update: bool, filepath_h5mu: str | Path): mdata.obs.drop(columns=mdata.obs.columns, inplace=True) for m, mod in mdata.mod.items(): mod.obs.drop(columns=mod.obs.columns, inplace=True) mod.obs["demo"] = m mdata.obs["demo"] = "global" - mdata.update() - if mdata.axis == 0: - assert list(mdata.obs.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["demo"] + if pull_on_update: + with md.set_options(pull_on_update=pull_on_update): + del mdata._obshash + mdata.update() + if mdata.axis == 0 and pull_on_update: + assert mdata.obs.columns.to_list() == ["demo"] + [f"{m}:demo" for m in mdata.mod.keys()] else: - assert list(mdata.obs.columns.values) == ["demo"] + assert 
mdata.obs.columns.to_list() == ["demo"] mdata.write(filepath_h5mu) mdata_ = md.read(filepath_h5mu) - assert list(mdata_.obs.columns.values) == list(mdata.obs.columns.values) + assert (mdata_.obs.columns == mdata.obs.columns.values).all() @pytest.mark.parametrize("mdata", (0, 1), indirect=True) @@ -45,23 +49,32 @@ def test_obs_vector(mdata: md.MuData): @pytest.mark.parametrize("mdata", (0, 1), indirect=True) -def test_var_global_columns(mdata: md.MuData, filepath_h5mu: str | Path): +@pytest.mark.parametrize("pull_on_update", (False, True)) +def test_var_global_columns(mdata: md.MuData, pull_on_update, filepath_h5mu: str | Path): mdata.var.drop(columns=mdata.var.columns, inplace=True) for m, mod in mdata.mod.items(): mod.var.drop(columns=mod.var.columns, inplace=True) mod.var["demo"] = m mdata.var["global"] = "global_var" - mdata.update() - if mdata.axis == 0: - assert list(mdata.var.columns.values) == ["demo", "global"] + if pull_on_update: + with md.set_options(pull_on_update=pull_on_update): + del mdata._varhash + mdata.update() + if not pull_on_update: + assert mdata.var.columns.to_list() == ["global"] + elif mdata.axis == 0: + assert mdata.var.columns.to_list() == ["global", "demo"] else: - assert list(mdata.var.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["global"] + assert mdata.var.columns.to_list() == ["global"] + [f"{m}:demo" for m in mdata.mod.keys()] del mdata.var["global"] - mdata.update() - if mdata.axis == 0: - assert list(mdata.var.columns.values) == ["demo"] + with md.set_options(pull_on_update=pull_on_update): + mdata.update() + if not pull_on_update: + assert mdata.var.shape[1] == 0 + elif mdata.axis == 0: + assert mdata.var.columns.to_list() == ["demo"] else: - assert list(mdata.var.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + assert mdata.var.columns.to_list() == [f"{m}:demo" for m in mdata.mod.keys()] mdata.write(filepath_h5mu) mdata_ = md.read(filepath_h5mu) assert list(mdata_.var.columns.values) == 
list(mdata.var.columns.values) diff --git a/tests/test_pull_push.py b/tests/test_pull_push.py index 64d25ee..e47617b 100644 --- a/tests/test_pull_push.py +++ b/tests/test_pull_push.py @@ -6,7 +6,7 @@ import pytest from anndata import AnnData -from mudata import MuData, set_options +from mudata import MuData Axis: TypeAlias = Literal[0, 1] AxisAttr: TypeAlias = Literal["obs", "var"] @@ -37,21 +37,9 @@ def unique(request: pytest.FixtureRequest) -> bool: return request.param -@pytest.fixture -def new_update() -> None: - set_options(pull_on_update=False) - yield - set_options(pull_on_update=None) - - @pytest.fixture def mdata( - rng: np.random.Generator, - axis: Axis, - attr: AxisAttr, - n: Literal["joint", "disjoint"], - unique: bool, - new_update: None, + rng: np.random.Generator, axis: Axis, attr: AxisAttr, n: Literal["joint", "disjoint"], unique: bool ) -> MuData: n_mod = 3 mods = {} @@ -107,7 +95,7 @@ def mdata( @pytest.fixture -def mdata_for_push(rng: np.random.Generator, mdata: MuData, new_update: None) -> MuData: +def mdata_for_push(rng: np.random.Generator, mdata: MuData) -> MuData: for axis, attr in enumerate(("obs", "var")): df = getattr(mdata, attr) diff --git a/tests/test_repr.py b/tests/test_repr.py index d32da61..af7cec9 100644 --- a/tests/test_repr.py +++ b/tests/test_repr.py @@ -11,7 +11,6 @@ def test_repr(mdata: md.MuData): assert rep[0] == f"MuData object with n_obs × n_vars = {mdata.n_obs} × {mdata.n_vars}" assert rep[1].lstrip().startswith("obs:") - assert rep[2].lstrip().startswith("var:") for col in mdata.obs.columns: if not any(col.startswith(f"{mod}:") for mod in mdata.mod_names): @@ -20,14 +19,14 @@ def test_repr(mdata: md.MuData): if not any(col.startswith(f"{mod}:") for mod in mdata.mod_names): assert col in rep[2] - assert rep[3].strip() == f"{mdata.n_mod} modalities" + assert rep[2].strip() == f"{mdata.n_mod} modalities" indentation = 1e6 - for line in rep[4:]: + for line in rep[3:]: for i, char in enumerate(line): if not char.isspace(): 
indentation = min(indentation, i) - for line in rep[4:]: + for line in rep[3:]: if not line[indentation].isspace(): # modality header match = modality_header_pattern.fullmatch(line) assert match is not None diff --git a/tests/test_update.py b/tests/test_update.py index 938529f..a573c79 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -8,11 +8,31 @@ from anndata import AnnData from scipy.sparse import csc_array, csr_array -from mudata import MuData, set_options +from mudata import MuData Axis: TypeAlias = Literal[0, 1] +@pytest.fixture(params=(0, 1)) +def axis(request: pytest.FixtureRequest) -> Axis: + return request.param + + +@pytest.fixture(params=("unique", "duplicated", "extreme_duplicated")) +def mod(request: pytest.FixtureRequest) -> Literal["unique", "duplicated", "extreme_duplicated"]: + return request.param + + +@pytest.fixture(params=("intersecting",)) +def across(request: pytest.FixtureRequest) -> Literal["intersecting"]: + return request.param + + +@pytest.fixture(params=("joint", "disjoint")) +def n(request: pytest.FixtureRequest) -> Literal["joint", "disjoint"]: + return request.param + + @pytest.fixture def modalities( rng: np.random.Generator, @@ -93,17 +113,6 @@ def add_mdata_global_columns(md: MuData, rng: np.random.Generator): return md -@pytest.fixture -def mdata_legacy(rng: np.random.Generator, modalities: Mapping[str, AnnData], axis: Axis): - mdata = MuData(modalities, axis=axis) - - batches = rng.choice(["a", "b", "c"], size=mdata.n_obs, replace=True) - mdata.obs["batch"] = batches - mdata.var["genesets"] = rng.choice(["a", "b", "c"], size=mdata.n_vars, replace=True) - - return mdata - - @pytest.fixture def mdata(rng: np.random.Generator, modalities: Mapping[str, AnnData], axis: Axis): md = MuData(modalities, axis=axis) @@ -111,402 +120,252 @@ def mdata(rng: np.random.Generator, modalities: Mapping[str, AnnData], axis: Axi return add_mdata_global_columns(md, rng) -@pytest.mark.parametrize("axis", [0, 1]) 
-@pytest.mark.parametrize("mod", ["unique", "duplicated", "extreme_duplicated"]) -@pytest.mark.parametrize("across", ["intersecting"]) -@pytest.mark.parametrize("n", ["joint", "disjoint"]) -class TestMuData: - @pytest.fixture(autouse=True) - def new_update(self): - set_options(pull_on_update=False) - yield - set_options(pull_on_update=None) - - @staticmethod - def get_attrm_values(mdata: MuData, attr: str, key: str, names: Sequence[str]): - attrm = getattr(mdata, f"{attr}m") - index = getattr(mdata, f"{attr}_names") - return np.concatenate([np.atleast_1d(attrm[key][np.nonzero(index == name)[0]]) for name in names]) - - @staticmethod - def assert_dtypes(df: pd.DataFrame): - assert pd.api.types.is_integer_dtype(df["dtype-int"]) - assert pd.api.types.is_float_dtype(df["dtype-float"]) - assert pd.api.types.is_bool_dtype(df["dtype-bool"]) - assert pd.api.types.is_categorical_dtype(df["dtype-categorical"]) - assert pd.api.types.is_string_dtype(df["batch"]) or df["batch"].dtype == object - - def test_update_simple(self, mdata: MuData, axis: Axis): - """ - Update should work when - - obs_names are the same across modalities, - - var_names are unique to each modality - """ - attr = "obs" if axis == 0 else "var" - oattr = "var" if axis == 0 else "obs" - - for mod in mdata.mod.keys(): - assert mdata.obsmap[mod].dtype.kind == "u" - assert mdata.varmap[mod].dtype.kind == "u" - - # names along non-axis are concatenated - assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) - assert ( - getattr(mdata, f"{oattr}_names") - == reduce(lambda x, y: x.append(y), (getattr(mod, f"{oattr}_names") for mod in mdata.mod.values())) - ).all() - - # names along axis are unioned - axisnames = reduce( - lambda x, y: x.union(y, sort=False), (getattr(mod, f"{attr}_names") for mod in mdata.mod.values()) - ) - assert mdata.shape[axis] == axisnames.shape[0] - assert (getattr(mdata, f"{attr}_names").sort_values() == axisnames.sort_values()).all() - - # guards against 
Pandas scrambling the order. This was the case for pandas < 1.4.0 when using pd.concat with an outer join on a MultiIndex. - # reprex: - # - # import numpy as np - # import pandas as pd - # df1 = pd.DataFrame({"a": np.repeat(np.arange(5), 2), "b": np.tile(np.asarray([0, 1]), 5), "c": np.arange(10)}).set_index("a").set_index("b", append=True) - # df2 = pd.DataFrame({"a": np.repeat(np.arange(10), 2), "b": np.tile(np.asarray([0, 1]), 10), "d": np.arange(20)}).set_index("a").set_index("b", append=True) - # df1 = df1.iloc[::-1, :] - # df = pd.concat((kdf1, df2), axis=1, join="outer", sort=False) - assert ( - getattr(mdata, f"{attr}_names")[: mdata["mod1"].shape[axis]] == getattr(mdata["mod1"], f"{attr}_names") - ).all() - - def test_update_add_modality(self, rng: np.random.Generator, modalities: Mapping[str, AnnData], axis: Axis): - modnames = list(modalities.keys()) - mdata = add_mdata_global_columns( - MuData({modname: modalities[modname] for modname in modnames[:-2]}, axis=axis), rng - ) - - attr = "obs" if axis == 0 else "var" - oattr = "var" if axis == 0 else "obs" - - for i in (-2, -1): - old_attrnames = getattr(mdata, f"{attr}_names") - old_oattrnames = getattr(mdata, f"{oattr}_names") - - some_obs_names = mdata.obs_names[:2] - mdata.obsm["test"] = rng.normal(size=(mdata.n_obs, 1)) - true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) - - mdata.mod[modnames[i]] = modalities[modnames[i]] - mdata.update() - - for mod in mdata.mod.keys(): - assert mdata.obsmap[mod].dtype.kind == "u" - assert mdata.varmap[mod].dtype.kind == "u" - - self.assert_dtypes(mdata.obs) - self.assert_dtypes(mdata.var) - - test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) - if axis == 1: - assert np.isnan(mdata.obsm["test"]).sum() == modalities[modnames[i]].n_obs - assert np.all(np.isnan(mdata.obsm["test"][-modalities[modnames[i]].n_obs :])) - assert np.all(~np.isnan(mdata.obsm["test"][: -modalities[modnames[i]].n_obs])) - assert 
(test_obsm_values[~np.isnan(test_obsm_values)].reshape(-1) == true_obsm_values.reshape(-1)).all() - else: - assert (test_obsm_values == true_obsm_values).all() - - attrnames = getattr(mdata, f"{attr}_names") - oattrnames = getattr(mdata, f"{oattr}_names") - assert (attrnames[: old_attrnames.size] == old_attrnames).all() - assert (oattrnames[: old_oattrnames.size] == old_oattrnames).all() - - assert ( - attrnames == old_attrnames.union(getattr(modalities[modnames[i]], f"{attr}_names"), sort=False) - ).all() - assert (oattrnames == old_oattrnames.append(getattr(modalities[modnames[i]], f"{oattr}_names"))).all() - - def test_update_delete_modality(self, mdata: MuData, axis: Axis): - modnames = list(mdata.mod.keys()) - attr = "obs" if axis == 0 else "var" - oattr = "var" if axis == 0 else "obs" - attrm = f"{attr}m" - oattrm = f"{oattr}m" - - fullbatch = getattr(mdata, attr)["batch"] - fullobatch = getattr(mdata, oattr)["batch"] - fulltestm = getattr(mdata, attrm)["test"] - fullotestm = getattr(mdata, oattrm)["test"] - keptmask = (getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0) | ( - getattr(mdata, f"{attr}map")[modnames[2]].reshape(-1) > 0 - ) - keptomask = (getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0) | ( - getattr(mdata, f"{oattr}map")[modnames[2]].reshape(-1) > 0 - ) - - del mdata.mod[modnames[0]] +def get_attrm_values(mdata: MuData, attr: str, key: str, names: Sequence[str]): + attrm = getattr(mdata, f"{attr}m") + index = getattr(mdata, f"{attr}_names") + return np.concatenate([np.atleast_1d(attrm[key][np.nonzero(index == name)[0]]) for name in names]) + + +def assert_dtypes(df: pd.DataFrame): + assert pd.api.types.is_integer_dtype(df["dtype-int"]) + assert pd.api.types.is_float_dtype(df["dtype-float"]) + assert pd.api.types.is_bool_dtype(df["dtype-bool"]) + assert pd.api.types.is_categorical_dtype(df["dtype-categorical"]) + assert pd.api.types.is_string_dtype(df["batch"]) or df["batch"].dtype == object + + +def test_update_simple(mdata: 
MuData, axis: Axis): + """ + Update should work when + - obs_names are the same across modalities, + - var_names are unique to each modality + """ + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + + # names along non-axis are concatenated + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) + assert ( + getattr(mdata, f"{oattr}_names") + == reduce(lambda x, y: x.append(y), (getattr(mod, f"{oattr}_names") for mod in mdata.mod.values())) + ).all() + + # names along axis are unioned + axisnames = reduce( + lambda x, y: x.union(y, sort=False), (getattr(mod, f"{attr}_names") for mod in mdata.mod.values()) + ) + assert mdata.shape[axis] == axisnames.shape[0] + assert (getattr(mdata, f"{attr}_names").sort_values() == axisnames.sort_values()).all() + + # guards against Pandas scrambling the order. This was the case for pandas < 1.4.0 when using pd.concat with an outer join on a MultiIndex. 
+    # reprex: +    # +    # import numpy as np +    # import pandas as pd +    # df1 = pd.DataFrame({"a": np.repeat(np.arange(5), 2), "b": np.tile(np.asarray([0, 1]), 5), "c": np.arange(10)}).set_index("a").set_index("b", append=True) +    # df2 = pd.DataFrame({"a": np.repeat(np.arange(10), 2), "b": np.tile(np.asarray([0, 1]), 10), "d": np.arange(20)}).set_index("a").set_index("b", append=True) +    # df1 = df1.iloc[::-1, :] +    # df = pd.concat((df1, df2), axis=1, join="outer", sort=False) +    assert ( +        getattr(mdata, f"{attr}_names")[: mdata["mod1"].shape[axis]] == getattr(mdata["mod1"], f"{attr}_names") +    ).all() + + +def test_update_add_modality(rng: np.random.Generator, modalities: Mapping[str, AnnData], axis: Axis): +    modnames = list(modalities.keys()) +    mdata = add_mdata_global_columns( +        MuData({modname: modalities[modname] for modname in modnames[:-2]}, axis=axis), rng +    ) + +    attr = "obs" if axis == 0 else "var" +    oattr = "var" if axis == 0 else "obs" + +    for i in (-2, -1): +        old_attrnames = getattr(mdata, f"{attr}_names") +        old_oattrnames = getattr(mdata, f"{oattr}_names") + +        some_obs_names = mdata.obs_names[:2] +        mdata.obsm["test"] = rng.normal(size=(mdata.n_obs, 1)) +        true_obsm_values = get_attrm_values(mdata, "obs", "test", some_obs_names) + +        mdata.mod[modnames[i]] = modalities[modnames[i]] mdata.update() for mod in mdata.mod.keys(): assert mdata.obsmap[mod].dtype.kind == "u" assert mdata.varmap[mod].dtype.kind == "u" -        self.assert_dtypes(mdata.obs) -        self.assert_dtypes(mdata.var) +        assert_dtypes(mdata.obs) +        assert_dtypes(mdata.var) + +        test_obsm_values = get_attrm_values(mdata, "obs", "test", some_obs_names) +        if axis == 1: +            assert np.isnan(mdata.obsm["test"]).sum() == modalities[modnames[i]].n_obs +            assert np.all(np.isnan(mdata.obsm["test"][-modalities[modnames[i]].n_obs :])) +            assert np.all(~np.isnan(mdata.obsm["test"][: -modalities[modnames[i]].n_obs])) +            assert (test_obsm_values[~np.isnan(test_obsm_values)].reshape(-1) == true_obsm_values.reshape(-1)).all() + 
else: + assert (test_obsm_values == true_obsm_values).all() + + attrnames = getattr(mdata, f"{attr}_names") + oattrnames = getattr(mdata, f"{oattr}_names") + assert (attrnames[: old_attrnames.size] == old_attrnames).all() + assert (oattrnames[: old_oattrnames.size] == old_oattrnames).all() + + assert (attrnames == old_attrnames.union(getattr(modalities[modnames[i]], f"{attr}_names"), sort=False)).all() + assert (oattrnames == old_oattrnames.append(getattr(modalities[modnames[i]], f"{oattr}_names"))).all() + + +def test_update_delete_modality(mdata: MuData, axis: Axis): + modnames = list(mdata.mod.keys()) + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + attrm = f"{attr}m" + oattrm = f"{oattr}m" + + fullbatch = getattr(mdata, attr)["batch"] + fullobatch = getattr(mdata, oattr)["batch"] + fulltestm = getattr(mdata, attrm)["test"] + fullotestm = getattr(mdata, oattrm)["test"] + keptmask = (getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0) | ( + getattr(mdata, f"{attr}map")[modnames[2]].reshape(-1) > 0 + ) + keptomask = (getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0) | ( + getattr(mdata, f"{oattr}map")[modnames[2]].reshape(-1) > 0 + ) + + del mdata.mod[modnames[0]] + mdata.update() + + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" + + assert_dtypes(mdata.obs) + assert_dtypes(mdata.var) + + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) + assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() + assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() + assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all() + assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all() + + fullbatch = getattr(mdata, attr)["batch"] + fullobatch = getattr(mdata, oattr)["batch"] + fulltestm = getattr(mdata, attrm)["test"] + fullotestm = getattr(mdata, oattrm)["test"] + 
keptmask = getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0 + keptomask = getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0 + + del mdata.mod[modnames[2]] + mdata.update() + + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) + assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() + assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() + assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all() + assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all() + + +def test_update_intersecting(rng: np.random.Generator, modalities: Mapping[str, AnnData], axis: Axis): + """ + Update should work when + - obs_names are the same across modalities, + - there are intersecting var_names, + which are unique in each modality + """ + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + for m, mod in modalities.items(): + setattr( + mod, f"{oattr}_names", [f"{m}_{oattr}{j}" if j != 0 else f"{oattr}_{j}" for j in range(mod.shape[1 - axis])] + ) - assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) - assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() - assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() - assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all() - assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all() + mdata = add_mdata_global_columns(MuData(modalities, axis=axis), rng) - fullbatch = getattr(mdata, attr)["batch"] - fullobatch = getattr(mdata, oattr)["batch"] - fulltestm = getattr(mdata, attrm)["test"] - fullotestm = getattr(mdata, oattrm)["test"] - keptmask = getattr(mdata, f"{attr}map")[modnames[1]].reshape(-1) > 0 - keptomask = getattr(mdata, f"{oattr}map")[modnames[1]].reshape(-1) > 0 + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" - del 
mdata.mod[modnames[2]] - mdata.update() + assert_dtypes(mdata.obs) + assert_dtypes(mdata.var) - assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in mdata.mod.values()) - assert (getattr(mdata, oattr)["batch"] == fullobatch[keptomask]).all() - assert (getattr(mdata, attr)["batch"] == fullbatch[keptmask]).all() - assert (getattr(mdata, attrm)["test"] == fulltestm[keptmask, :]).all() - assert (getattr(mdata, oattrm)["test"] == fullotestm[keptomask, :]).all() - - def test_update_intersecting(self, rng: np.random.Generator, modalities: Mapping[str, AnnData], axis: Axis): - """ - Update should work when - - obs_names are the same across modalities, - - there are intersecting var_names, - which are unique in each modality - """ - attr = "obs" if axis == 0 else "var" - oattr = "var" if axis == 0 else "obs" - for m, mod in modalities.items(): - setattr( - mod, - f"{oattr}_names", - [f"{m}_{oattr}{j}" if j != 0 else f"{oattr}_{j}" for j in range(mod.shape[1 - axis])], - ) - - mdata = add_mdata_global_columns(MuData(modalities, axis=axis), rng) + # names along non-axis are concatenated + assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values()) + assert ( + getattr(mdata, f"{oattr}_names") + == reduce(lambda x, y: x.append(y), (getattr(mod, f"{oattr}_names") for mod in modalities.values())) + ).all() - for mod in mdata.mod.keys(): - assert mdata.obsmap[mod].dtype.kind == "u" - assert mdata.varmap[mod].dtype.kind == "u" - - self.assert_dtypes(mdata.obs) - self.assert_dtypes(mdata.var) + # names along axis are unioned + axisnames = reduce( + lambda x, y: x.union(y, sort=False), (getattr(mod, f"{attr}_names") for mod in modalities.values()) + ) + assert mdata.shape[axis] == axisnames.shape[0] + assert (getattr(mdata, f"{attr}_names") == axisnames).all() - # names along non-axis are concatenated - assert mdata.shape[1 - axis] == sum(mod.shape[1 - axis] for mod in modalities.values()) - assert ( - getattr(mdata, f"{oattr}_names") - == 
reduce(lambda x, y: x.append(y), (getattr(mod, f"{oattr}_names") for mod in modalities.values())) - ).all() - - # names along axis are unioned - axisnames = reduce( - lambda x, y: x.union(y, sort=False), (getattr(mod, f"{attr}_names") for mod in modalities.values()) - ) - assert mdata.shape[axis] == axisnames.shape[0] - assert (getattr(mdata, f"{attr}_names") == axisnames).all() - def test_update_after_filter_obs_adata(self, mdata: MuData, axis: Axis): - """ - Check for https://github.com/scverse/muon/issues/44 - """ - # Replicate in-place filtering in muon: - # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) +def test_update_after_filter_obs_adata(mdata: MuData, axis: Axis): + """ + Check for https://github.com/scverse/muon/issues/44 + """ + # Replicate in-place filtering in muon: + # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) - old_obsnames = mdata.obs_names - old_varnames = mdata.var_names + old_obsnames = mdata.obs_names + old_varnames = mdata.var_names - filtermask = mdata["mod3"].obs["min_count"] < -2 - fullfiltermask = mdata.obsmap["mod3"].copy() > 0 - fullfiltermask[fullfiltermask] = filtermask - keptmask = (mdata.obsmap["mod1"] > 0) | (mdata.obsmap["mod2"] > 0) | fullfiltermask + filtermask = mdata["mod3"].obs["min_count"] < -2 + fullfiltermask = mdata.obsmap["mod3"].copy() > 0 + fullfiltermask[fullfiltermask] = filtermask + keptmask = (mdata.obsmap["mod1"] > 0) | (mdata.obsmap["mod2"] > 0) | fullfiltermask - some_obs_names = mdata[keptmask, :].obs_names.values[:2] - true_obsm_values = self.get_attrm_values(mdata[keptmask], "obs", "test", some_obs_names) + some_obs_names = mdata[keptmask, :].obs_names.values[:2] + true_obsm_values = get_attrm_values(mdata[keptmask], "obs", "test", some_obs_names) - mdata.mod["mod3"] = mdata["mod3"][mdata["mod3"].obs["min_count"] < -2].copy() - mdata.update() + mdata.mod["mod3"] = mdata["mod3"][mdata["mod3"].obs["min_count"] < -2].copy() + mdata.update() - for mod in mdata.mod.keys(): 
- assert mdata.obsmap[mod].dtype.kind == "u" - assert mdata.varmap[mod].dtype.kind == "u" + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" - self.assert_dtypes(mdata.obs) - self.assert_dtypes(mdata.var) + assert_dtypes(mdata.obs) + assert_dtypes(mdata.var) - assert mdata.obs["batch"].isna().sum() == 0 + assert mdata.obs["batch"].isna().sum() == 0 - assert (mdata.var_names == old_varnames).all() - if axis == 0: - # check if the order is preserved - assert (mdata.obs_names == old_obsnames[old_obsnames.isin(mdata.obs_names)]).all() + assert (mdata.var_names == old_varnames).all() + if axis == 0: + # check if the order is preserved + assert (mdata.obs_names == old_obsnames[old_obsnames.isin(mdata.obs_names)]).all() - test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) - assert (true_obsm_values == test_obsm_values).all() + test_obsm_values = get_attrm_values(mdata, "obs", "test", some_obs_names) + assert (true_obsm_values == test_obsm_values).all() - def test_update_after_obs_reordered(self, mdata: MuData): - """ - Update should work if obs are reordered. - """ - some_obs_names = mdata.obs_names.values[:2] - true_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) +def test_update_after_obs_reordered(mdata: MuData): + """ + Update should work if obs are reordered. 
+ """ + some_obs_names = mdata.obs_names.values[:2] - mdata.mod["mod1"] = mdata["mod1"][::-1].copy() - mdata.update() + true_obsm_values = get_attrm_values(mdata, "obs", "test", some_obs_names) - for mod in mdata.mod.keys(): - assert mdata.obsmap[mod].dtype.kind == "u" - assert mdata.varmap[mod].dtype.kind == "u" + mdata.mod["mod1"] = mdata["mod1"][::-1].copy() + mdata.update() - self.assert_dtypes(mdata.obs) - self.assert_dtypes(mdata.var) - - test_obsm_values = self.get_attrm_values(mdata, "obs", "test", some_obs_names) - - assert (true_obsm_values == test_obsm_values).all() - - -@pytest.mark.parametrize("axis", [0, 1]) -class TestMuDataLegacy: - @pytest.mark.parametrize("mod", ["unique"]) - @pytest.mark.parametrize("across", ["intersecting"]) - @pytest.mark.parametrize("n", ["joint", "disjoint"]) - def test_update_simple(self, modalities: Mapping[str, AnnData], axis: Axis): - """ - Update should work when - - obs_names are the same across modalities, - - var_names are unique to each modality - """ - attr = "obs" if axis == 0 else "var" - oattr = "var" if axis == 0 else "obs" - for m, mod in modalities.items(): - mod.obs["assert-bool"] = True - mod.obs[f"assert-boolean-{m}"] = False - mod.var["assert-bool"] = True - mod.var[f"assert-boolean-{m}"] = False - setattr(mod, f"{oattr}_names", [f"{m}_{oattr}{j}" for j in range(mod.shape[1 - axis])]) - mdata = MuData(modalities, axis=axis) - mdata.update() + for mod in mdata.mod.keys(): + assert mdata.obsmap[mod].dtype.kind == "u" + assert mdata.varmap[mod].dtype.kind == "u" - # Variables are different across modalities - assert "mod" in getattr(mdata, oattr).columns - assert getattr(mdata, oattr)["assert-bool"].dtype == bool - for m, mod in modalities.items(): - assert getattr(mdata, oattr)[f"{m}:assert-boolean-{m}"].dtype == "boolean" - # Observations are the same across modalities - # hence /mod/mod1/obs/mod -> /obs/mod1:mod - assert f"{m}:mod" in getattr(mdata, attr).columns - # Columns are intact in individual 
modalities - assert "mod" in mod.obs.columns - assert all(mod.obs["mod"] == m) - assert "mod" in mod.var.columns - assert all(mod.var["mod"] == m) - - @pytest.mark.parametrize("mod", ["unique", "extreme_duplicated"]) - @pytest.mark.parametrize("across", ["intersecting"]) - @pytest.mark.parametrize("n", ["joint", "disjoint"]) - def test_update_duplicates(self, modalities: Mapping[str, AnnData], axis: Axis): - """ - Update should work when - - obs_names are the same across modalities, - - there are duplicated var_names, which are not intersecting - between modalities - """ - attr = "obs" if axis == 0 else "var" - oattr = "var" if axis == 0 else "obs" - for m, mod in modalities.items(): - setattr(mod, f"{oattr}_names", [f"{m}_{oattr}{j // 2}" for j in range(mod.shape[1 - axis])]) - mdata = MuData(modalities, axis=axis) - mdata.update() + assert_dtypes(mdata.obs) + assert_dtypes(mdata.var) - # Variables are different across modalities - assert "mod" in getattr(mdata, oattr).columns - for m, mod in modalities.items(): - # Observations are the same across modalities - # hence /mod/mod1/obs/mod -> /obs/mod1:mod - assert f"{m}:mod" in getattr(mdata, attr).columns - # Columns are intact in individual modalities - assert "mod" in mod.obs.columns - assert all(mod.obs["mod"] == m) - assert "mod" in mod.var.columns - assert all(mod.var["mod"] == m) - - @pytest.mark.parametrize("mod", ["unique", "extreme_duplicated"]) - @pytest.mark.parametrize("across", ["intersecting"]) - @pytest.mark.parametrize("n", ["joint", "disjoint"]) - def test_update_intersecting(self, modalities: Mapping[str, AnnData], axis: Axis): - """ - Update should work when - - obs_names are the same across modalities, - - there are intersecting var_names, - which are unique in each modality - """ - attr = "obs" if axis == 0 else "var" - for m, mod in modalities.items(): - # [mod1] var0, mod1_var1, mod1_var2, ...; [mod2] var0, mod2_var1, mod2_var2, ... 
- setattr( - mod, f"{attr}_names", [f"{m}_{attr}{j}" if j != 0 else f"{attr}_{j}" for j in range(mod.shape[axis])] - ) - mdata = MuData(modalities, axis=axis) - mdata.update() + test_obsm_values = get_attrm_values(mdata, "obs", "test", some_obs_names) - for m, mod in modalities.items(): - # Observations are the same across modalities - # hence /mod/mod1/obs/mod -> /obs/mod1:mod - assert f"{m}:mod" in mdata.obs.columns - # Variables are intersecting - # so they won't be merged - assert f"{m}:mod" in mdata.var.columns - # Columns are intact in individual modalities - assert "mod" in mod.obs.columns - assert all(mod.obs["mod"] == m) - assert "mod" in mod.var.columns - assert all(mod.var["mod"] == m) - - @pytest.mark.parametrize("mod", ["unique"]) - @pytest.mark.parametrize("across", ["intersecting"]) - @pytest.mark.parametrize("n", ["joint", "disjoint"]) - def test_update_after_filter_obs_adata(self, mdata_legacy: MuData): - """ - Check for https://github.com/scverse/muon/issues/44 - """ - # Replicate in-place filtering in muon: - # mu.pp.filter_obs(mdata['mod1'], 'min_count', lambda x: (x < -2)) - mdata_legacy.mod["mod1"] = mdata_legacy["mod1"][mdata_legacy["mod1"].obs["min_count"] < -2].copy() - mdata_legacy.update() - assert mdata_legacy.obs["batch"].isna().sum() == 0 - - @pytest.mark.parametrize("mod", ["unique", "extreme_duplicated"]) - @pytest.mark.parametrize("across", ["intersecting"]) - @pytest.mark.parametrize("n", ["joint", "disjoint"]) - def test_update_after_obs_reordered(self, rng, mdata_legacy: MuData): - """ - Update should work if obs are reordered. 
- """ - attr = "obs" if mdata_legacy.axis == 0 else "var" - getattr(mdata_legacy, f"{attr}m")["test"] = rng.normal(size=(mdata_legacy.shape[mdata_legacy.axis], 2)) - - some_names = getattr(mdata_legacy, f"{attr}_names").values[:2] - - true_values = [ - getattr(mdata_legacy, f"{attr}m")["test"][ - np.where(getattr(mdata_legacy, f"{attr}_names").values == name)[0][0] - ] - for name in some_names - ] - - mdata_legacy.mod["mod1"] = mdata_legacy["mod1"][::-1].copy() - mdata_legacy.update() - - test_values = [ - getattr(mdata_legacy, f"{attr}m")["test"][np.where(getattr(mdata_legacy, f"{attr}_names") == name)[0][0]] - for name in some_names - ] - - assert all(all(true_values[i] == test_values[i]) for i in range(len(true_values))) + assert (true_obsm_values == test_obsm_values).all()