From 4ba2ff8043856dd204c68e16a20a9fd00fa389e8 Mon Sep 17 00:00:00 2001 From: Laurens Lehner Date: Tue, 12 Nov 2024 00:07:45 +0100 Subject: [PATCH 01/23] Add draft function to mask tissue regions --- .../pipeline/mask_filtering/region_masking.py | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 src/scportrait/pipeline/mask_filtering/region_masking.py diff --git a/src/scportrait/pipeline/mask_filtering/region_masking.py b/src/scportrait/pipeline/mask_filtering/region_masking.py new file mode 100644 index 00000000..fb2a9c69 --- /dev/null +++ b/src/scportrait/pipeline/mask_filtering/region_masking.py @@ -0,0 +1,106 @@ +import spatialdata as sd +from napari_spatialdata import Interactive +from shapely.geometry import mapping +from rasterio.features import geometry_mask +import rasterio +import dask +from spatialdata.models import Image2DModel +import numpy as np +from scipy import ndimage +from skimage.measure import find_contours +from shapely.geometry import Polygon +from shapely import unary_union +from skimage.segmentation import watershed +from skimage.draw import disk + +def mask_image(sdata, image, mask, invert, automatic_masking, threshold, overwrite, masked_image_name) + """ + Given an image and mask, either masks or crops the image. + + Parameters + ---------- + sdata : sd.SpatialData + spatialdata object containing the image and mask. + image : str + Name of the image in sdata.images to mask. + mask : str | shapely.geometry.Polygon + Mask, either str of the name of the shape in sdata.shapes or a shapely polygon. + invert : bool + If True, inverts the mask, such that only pixels within mask remain, while the rest gets cropped. + automatic_masking : bool + If True, uses threshold + watershed to automatically create a mask based on shapes. Threshold needs to be adjusted manually. + threshold : float + Threshold for pixel intensity values at which to segment image into foreground and background. + overwrite : bool + Whether to overwrite the image in sdata.images. + masked_image_name : None | str + Name of the masked image in sdata.images if overwrite==True. Defaults to f"{image}_masked". + Returns + ------- + sd.SpatialData + spatialdata object with masked image + """ + channels, height, width = sdata.images[image].data.shape + + if automatic_masking: + polygon = _draw_polygons(sdata.images[image].data, threshold) + elif isinstance(mask, str): + polygon = sdata.shapes[mask].iloc[0].geometry + else: + polygon = mask + + polygon_geom = [mapping(polygon)] + + transform = rasterio.transform.Affine(1, 0, 0, 0, 1, 0) # identity transform + + image_mask = geometry_mask( + polygon_geom, + invert=invert, + out_shape=(height, width), + transform=transform + ) + + if channels > 1: + image_mask = dask.array.broadcast_to(image_mask, (channels, height, width)) + + masked_image = sdata.images[image].data * image_mask + images = {} + images["masked_image"] = Image2DModel.parse(masked_image) + + if overwrite: + sdata.images[image] = images["masked_image"] + else: + if masked_image_name is None: + masked_image_name = f"{image}_masked" + sdata.images[masked_image_name] = images["masked_image"] + +def _draw_polygons(image, threshold): + """ + Given an image, detect regions to turn into polygon shapes, which are then used as a mask. + + Parameters + ---------- + image : np.ndarray + Image to find regions in. + threshold : float + Threshold for pixel intensity values at which to segment image into foreground and background. + Returns + ------- + shapely.geometry.Polygon + Polygon containing the detected regions. + """ + if image.shape[0] == 1: + image = image[0] + binary_image = image > np.percentile(image.flatten(), threshold) + + distance = ndimage.distance_transform_edt(binary_image) + markers, _ = ndimage.label(distance) + + segmented = watershed(-distance, markers, mask=binary_image) + + contours = find_contours(segmented, level=0.5) + + polygons = [Polygon(contour) for contour in contours if len(contour) > 2] + polygon = unary_union(polygons) + + return polygon \ No newline at end of file From 8b224eba62219570d2d7512a1f2ea859ee3b7655 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Nov 2025 20:48:13 +0100 Subject: [PATCH 02/23] [FEATURE] store filepath when reading h5sc files Keeps track of filepath in uns from which the h5sc object was loaded. This allows you to update values (e.g. obs) on disk. --- src/scportrait/io/h5sc.py | 8 +++++++- src/scportrait/pipeline/_utils/constants.py | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/scportrait/io/h5sc.py b/src/scportrait/io/h5sc.py index 8639ee0e..427cb450 100644 --- a/src/scportrait/io/h5sc.py +++ b/src/scportrait/io/h5sc.py @@ -5,7 +5,11 @@ from anndata import AnnData from anndata._io.h5ad import _clean_uns, _read_raw, read_dataframe, read_elem -from scportrait.pipeline._utils.constants import DEFAULT_NAME_SINGLE_CELL_IMAGES, IMAGE_DATACONTAINER_NAME +from scportrait.pipeline._utils.constants import ( + DEFAULT_IDENTIFIER_FILENAME, + DEFAULT_NAME_SINGLE_CELL_IMAGES, + IMAGE_DATACONTAINER_NAME, +) def read_h5sc(filename: str | Path) -> AnnData: @@ -44,4 +48,6 @@ def read_h5sc(filename: str | Path) -> AnnData: _clean_uns(adata) adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] = f.get(IMAGE_DATACONTAINER_NAME) + adata.uns[DEFAULT_IDENTIFIER_FILENAME] = filename + return adata diff --git a/src/scportrait/pipeline/_utils/constants.py b/src/scportrait/pipeline/_utils/constants.py index 01cc7548..ea6330dc 100644 --- a/src/scportrait/pipeline/_utils/constants.py +++ b/src/scportrait/pipeline/_utils/constants.py @@ -35,6 +35,7 @@ IMAGE_DATACONTAINER_NAME = f"obsm/{DEFAULT_NAME_SINGLE_CELL_IMAGES}" DEFAULT_CELL_ID_NAME = "scportrait_cell_id" INDEX_DATACONTAINER_NAME = f"obs/{DEFAULT_CELL_ID_NAME}" +DEFAULT_IDENTIFIER_FILENAME = "h5sc_source_path" DEFAULT_IMAGE_DTYPE: np.dtype = np.uint16 DEFAULT_SEGMENTATION_DTYPE: np.dtype = np.uint64 From 33ff5b398732b8e759ec6956e43d9f2ff0752c79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Mon, 3 Nov 2025 20:48:43 +0100 Subject: [PATCH 03/23] [FIX] if no axes object is provided to plot_shapes correctly read dimensions from shape file --- src/scportrait/plotting/sdata.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/scportrait/plotting/sdata.py b/src/scportrait/plotting/sdata.py index f00377c1..3fde611a 100644 --- a/src/scportrait/plotting/sdata.py +++ b/src/scportrait/plotting/sdata.py @@ -4,8 +4,10 @@ import matplotlib as mpl import matplotlib.pyplot as plt +import numpy as np import spatialdata import xarray +from geopandas.geodataframe import GeoDataFrame from matplotlib.axes import Axes PALETTE = [ @@ -48,6 +50,11 @@ def _get_shape_element(sdata, element_name) -> tuple[int, int]: """ if isinstance(sdata[element_name], xarray.DataTree): shape = sdata[element_name].scale0.image.shape + elif isinstance(sdata[element_name], GeoDataFrame): + bounds = sdata[element_name].geometry.bounds + x = int(np.ceil(bounds["maxx"] - bounds["minx"])) + y = int(np.ceil(bounds["maxy"] - bounds["miny"])) + shape = (x, y) else: shape = sdata[element_name].data.shape From 3fc534bebce7b01fa04d8cbcf79d45bf03ff8128 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 6 Nov 2025 02:31:43 +0100 Subject: [PATCH 04/23] add functionality to update obs to disk --- src/scportrait/io/h5sc.py | 1 + src/scportrait/pipeline/project.py | 36 +++++++--- src/scportrait/tools/h5sc/__init__.py | 4 +- src/scportrait/tools/h5sc/operations.py | 95 ++++++++++++++++++++++++- 4 files changed, 125 insertions(+), 11 deletions(-) diff --git a/src/scportrait/io/h5sc.py b/src/scportrait/io/h5sc.py index 427cb450..4a749833 100644 --- a/src/scportrait/io/h5sc.py +++ b/src/scportrait/io/h5sc.py @@ -49,5 +49,6 @@ def read_h5sc(filename: str | Path) -> AnnData: adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] = f.get(IMAGE_DATACONTAINER_NAME) adata.uns[DEFAULT_IDENTIFIER_FILENAME] = filename + adata.uns["_h5sc_file_handle"] = f return adata diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index 6da737a2..e17b1c56 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -122,6 +122,9 @@ class Project(Logable): DEFAULT_SINGLE_CELL_IMAGE_DTYPE = DEFAULT_SINGLE_CELL_IMAGE_DTYPE DEFAULT_CELL_ID_NAME = DEFAULT_CELL_ID_NAME + _h5sc_handle = None + _h5sc_adata = None + PALETTE = [ "blue", "green", @@ -231,6 +234,15 @@ def __exit__(self): def __del__(self): self._clear_temp_dir() + def __getstate__(self): + state = self.__dict__.copy() + state["_h5sc_handle"] = None # ensure closed before pickling + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._h5sc_handle = None # will be reopened lazily + @property def sdata_path(self) -> str: return self._get_sdata_path() @@ -242,14 +254,22 @@ def sdata(self) -> SpatialData: @property def h5sc(self) -> AnnData: - if self.extraction_f is None: - raise ValueError("No extraction method has been set.") - else: - if self.extraction_f.output_path is None: - path = self.extraction_f.extraction_file - else: - path = self.extraction_f.output_path - return read_h5sc(path) + # Always safely close previous handle if it exists + if hasattr(self, "_h5sc_handle") and self._h5sc_handle is not None: + try: + self._h5sc_handle.close() + except (ValueError, RuntimeError): + # handle was already closed or in invalid state + pass + self._h5sc_handle = None + + # Load a fresh AnnData + adata = read_h5sc(self.extraction_f.output_path or self.extraction_f.extraction_file) + + # Track only the handle, not the adata + self._h5sc_handle = adata.uns["_h5sc_file_handle"] + + return adata ##### Setup Functions ##### diff --git a/src/scportrait/tools/h5sc/__init__.py b/src/scportrait/tools/h5sc/__init__.py index 54830847..e9ed862b 100644 --- a/src/scportrait/tools/h5sc/__init__.py +++ b/src/scportrait/tools/h5sc/__init__.py @@ -5,6 +5,6 @@ Functions to work with scPortrait's standardized single-cell data format. """ -from .operations import get_image_index, get_image_with_cellid +from .operations import add_spatial_coordinates, get_image_index, get_image_with_cellid, update_obs_on_disk -__all__ = ["get_image_with_cellid", "get_image_index"] +__all__ = ["update_obs_on_disk", "get_image_with_cellid", "get_image_index", "add_spatial_coordinates"] diff --git a/src/scportrait/tools/h5sc/operations.py b/src/scportrait/tools/h5sc/operations.py index e672d050..197915b1 100644 --- a/src/scportrait/tools/h5sc/operations.py +++ b/src/scportrait/tools/h5sc/operations.py @@ -5,9 +5,67 @@ Functions to work with scPortrait's standardized single-cell data format. """ +from __future__ import annotations + +from typing import TYPE_CHECKING +from warnings import warn + +import dask.array as da +import h5py import numpy as np -from scportrait.pipeline._utils.constants import DEFAULT_CELL_ID_NAME, DEFAULT_NAME_SINGLE_CELL_IMAGES +from scportrait.io.h5sc import read_h5sc + +if TYPE_CHECKING: + from anndata import AnnData + from dask.dataframe.core import DataFrame as da_DataFrame + +from scportrait.pipeline._utils.constants import ( + DEFAULT_CELL_ID_NAME, + DEFAULT_IDENTIFIER_FILENAME, + DEFAULT_NAME_SINGLE_CELL_IMAGES, + IMAGE_DATACONTAINER_NAME, +) + + +def _update_obs_on_disk(adata: AnnData) -> None: + """ + Temporarily close the HDF5 handle from a read-only AnnData, + overwrite .obs on disk, then reopen it and restore the image dataset. + """ + # 1. Get the open HDF5 file handle + file_handle = adata.uns.get("_h5sc_file_handle", None) + + # 2. Close file to release read-only lock + if file_handle: + file_handle.close() + adata.uns["_h5sc_file_handle"] = None + + # 3. Write updated obs + obs_df = adata.obs.copy() + obs_df.index = obs_df.index.astype(str) + + with h5py.File(adata.uns[DEFAULT_IDENTIFIER_FILENAME], "r+") as f: + if "obs" in f: + del f["obs"] + grp = f.create_group("obs") + for col in obs_df.columns: + grp.create_dataset(col, data=obs_df[col].to_numpy()) + + # 4. Reopen file handle and restore image dataset + f = h5py.File(adata.uns[DEFAULT_IDENTIFIER_FILENAME], "r") + adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] = f.get(IMAGE_DATACONTAINER_NAME) + adata.uns["_h5sc_file_handle"] = f + + +def update_obs_on_disk(adata: AnnData) -> None: + """ + Overwrite the .obs table in an existing .h5sc file on disk. + + Args: + adata: AnnData object whose .obs will replace the existing one. + """ + _update_obs_on_disk(adata) def get_image_index(adata, cell_id: int | list[int]) -> int | list[int]: @@ -64,3 +122,38 @@ def get_image_with_cellid(adata, cell_id: list[int] | int, select_channel: int | return array.squeeze(axis=0) # Remove the first dimension else: return array + + +def add_spatial_coordinates( + adata: AnnData, + centers_object: da_DataFrame, + cell_id_identifier: str = "scportrait_cell_id", + update_on_disk: bool = False, +) -> None: + """Add spatial coordinates to the AnnData object from scPortrait's standardized centers object. + Args: + adata: AnnData object to add spatial coordinates to. + centers_object: Dask DataFrame containing the spatial coordinates with columns "x" and "y" and the scportrait cell id as index. + cell_id_identifier: The column name in `adata.obs` that contains the cell IDs. + update_on_disk: boolean value indicating if the updated obs containing the spatial coordinates should be written to disk. This will overwrite the existing obs. + + Returns: + Updates the obs object of the passed h5sc object. + """ + + assert cell_id_identifier in adata.obs.columns, f"{cell_id_identifier} must be a column in h5sc.obs" + assert ( + ["x", "y"] == list(centers_object.columns) + ), "centers_object must be scportrait's standardized centers object containing columns 'x' and 'y' and the scportrait cell id as index, but detected columns are {centers_object.columns}" + + if ("x" in adata.obs.columns) or ("y" in adata.obs.columns): + adata.obs.drop(columns=["x", "y"], inplace=True, errors="ignore") + warn( + "Removed existing 'x' and 'y' columns from adata.obs. If this is not intended, please check the input data.", + stacklevel=2, + ) + + adata.obs = adata.obs.merge(centers_object.compute(), left_on=cell_id_identifier, right_index=True) + + if update_on_disk: + update_obs_on_disk(adata) From d20b5acaf82ad89426f0336ce0d86015085372af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 6 Nov 2025 02:41:48 +0100 Subject: [PATCH 05/23] add plotting support for non memory backed h5sc files --- src/scportrait/plotting/h5sc.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/src/scportrait/plotting/h5sc.py b/src/scportrait/plotting/h5sc.py index a29b9777..7f04beb0 100644 --- a/src/scportrait/plotting/h5sc.py +++ b/src/scportrait/plotting/h5sc.py @@ -3,8 +3,10 @@ import warnings from collections.abc import Iterable +import h5py import matplotlib.pyplot as plt import numpy as np +import pandas as pd from anndata import AnnData from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -221,7 +223,20 @@ def cell_grid_single_channel( fig = ax.get_figure() spacing = spacing * single_cell_size - images = get_image_with_cellid(adata, _cell_ids, channel_id) + + # Collect images in a list + if isinstance(adata.obsm["single_cell_images"], h5py.Dataset): + # if working on a memory-backed array + images = get_image_with_cellid(adata, _cell_ids, channel_id) + + else: + # non backed h5sc adata objects can be accessed directly + # these are created by slicing original h5sc objects + col = "scportrait_cell_id" + mapping = pd.Series(data=np.arange(len(adata.obs), dtype=int), index=adata.obs[col].values) + idx = mapping.loc[_cell_ids].to_numpy() + images = adata.obsm["single_cell_images"][idx, channel_id, :, :] + _plot_image_grid( ax, images, @@ -309,7 +324,18 @@ def cell_grid_multi_channel( channel_names = adata.uns["single_cell_images"]["channel_names"] # Collect images in a list - images = get_image_with_cellid(adata, _cell_ids) + if isinstance(adata.obsm["single_cell_images"], h5py.Dataset): + # if working on a memory-backed array + images = get_image_with_cellid(adata, _cell_ids) + + else: + # non backed h5sc adata objects can be accessed directly + # these are created by slicing original h5sc objects + col = "scportrait_cell_id" + mapping = pd.Series(data=np.arange(len(adata.obs), dtype=int), index=adata.obs[col].values) + idx = mapping.loc[_cell_ids].to_numpy() + images = adata.obsm["single_cell_images"][idx] + if select_channels is not None: if not isinstance(select_channels, Iterable): select_channels = [select_channels] From 14b0bc373b2e1a3fa06e7fae5f5fa31bd53c18ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:34:49 +0100 Subject: [PATCH 06/23] rename function --- src/scportrait/tools/h5sc/__init__.py | 4 ++-- src/scportrait/tools/h5sc/operations.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/scportrait/tools/h5sc/__init__.py b/src/scportrait/tools/h5sc/__init__.py index e9ed862b..76490ca6 100644 --- a/src/scportrait/tools/h5sc/__init__.py +++ b/src/scportrait/tools/h5sc/__init__.py @@ -5,6 +5,6 @@ Functions to work with scPortrait's standardized single-cell data format. """ -from .operations import add_spatial_coordinates, get_image_index, get_image_with_cellid, update_obs_on_disk +from .operations import add_spatial_coordinates, get_cell_id_index, get_image_with_cellid, update_obs_on_disk -__all__ = ["update_obs_on_disk", "get_image_with_cellid", "get_image_index", "add_spatial_coordinates"] +__all__ = ["update_obs_on_disk", "get_image_with_cellid", "get_cell_id_index", "add_spatial_coordinates"] diff --git a/src/scportrait/tools/h5sc/operations.py b/src/scportrait/tools/h5sc/operations.py index 197915b1..2594d7a4 100644 --- a/src/scportrait/tools/h5sc/operations.py +++ b/src/scportrait/tools/h5sc/operations.py @@ -68,9 +68,9 @@ def update_obs_on_disk(adata: AnnData) -> None: _update_obs_on_disk(adata) -def get_image_index(adata, cell_id: int | list[int]) -> int | list[int]: +def get_cell_id_index(adata, cell_id: int | list[int]) -> int | list[int]: """ - Retrieve the image index (row index) of a specific cell id in a H5SC object. + Retrieve the index (row index) of a specific cell id in a H5SC object. Args: adata: An AnnData object with obsm["single_cell_images"] containing a memory-backed array of the single-cell images. From ec810f816a32b8fbc8103a629e760aba5b90a23a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:36:59 +0100 Subject: [PATCH 07/23] parametrize function --- src/scportrait/tools/h5sc/operations.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/scportrait/tools/h5sc/operations.py b/src/scportrait/tools/h5sc/operations.py index 2594d7a4..e11c8be9 100644 --- a/src/scportrait/tools/h5sc/operations.py +++ b/src/scportrait/tools/h5sc/operations.py @@ -81,10 +81,12 @@ def get_cell_id_index(adata, cell_id: int | list[int]) -> int | list[int]: """ lookup = dict(zip(adata.obs[DEFAULT_CELL_ID_NAME], adata.obs.index.astype(int), strict=True)) - if isinstance(cell_id, int): + assert cell_id in lookup, f"CellID {cell_id} not present in the AnnData object." return lookup[cell_id] + missing = [x for x in cell_id if x not in lookup] + assert not missing, f"CellIDs not present in the AnnData object: {missing}" return [lookup[_id] for _id in cell_id] @@ -99,18 +101,15 @@ def get_image_with_cellid(adata, cell_id: list[int] | int, select_channel: int | Returns: The image(s) of the cell with the passed Cell IDs. """ - lookup = dict(zip(adata.obs[DEFAULT_CELL_ID_NAME], adata.obs.index.astype(int), strict=True)) - image_container = adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] + idxs = get_cell_id_index(adata, cell_id) + if isinstance(idxs, int): + idxs = [idxs] # Ensure idxs is always a list - if isinstance(cell_id, int): - cell_id = [cell_id] - - for x in cell_id: - assert x in lookup.keys(), f"CellID {x} is not present in the AnnData object." + # get the image container from the AnnData object + image_container = adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] images = [] - for _id in cell_id: - idx = lookup[_id] + for idx in idxs: if select_channel is None: image = image_container[idx][:] else: From c67cadebc94b0be0cf46dac3e673bc98210063a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Thu, 6 Nov 2025 23:37:11 +0100 Subject: [PATCH 08/23] add print statement --- src/scportrait/pipeline/_utils/spatialdata_helper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/scportrait/pipeline/_utils/spatialdata_helper.py b/src/scportrait/pipeline/_utils/spatialdata_helper.py index 1bfc50db..1ef88a32 100644 --- a/src/scportrait/pipeline/_utils/spatialdata_helper.py +++ b/src/scportrait/pipeline/_utils/spatialdata_helper.py @@ -182,6 +182,7 @@ def calculate_centroids(mask: xarray.DataArray, coordinate_system: str = "global transform = get_transformation(mask, coordinate_system) if check_memory(mask): + print("Array fits in memory, using in-memory calculation.") centers, _, _ids = numba_mask_centroid(mask.values) return make_centers_object(centers, _ids, transform, coordinate_system) From 60b1f526d214e00a285697f9f868d4333874a30e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sat, 8 Nov 2025 18:02:27 +0100 Subject: [PATCH 09/23] simplify code structure and remove duplicate function --- src/scportrait/tools/h5sc/operations.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/scportrait/tools/h5sc/operations.py b/src/scportrait/tools/h5sc/operations.py index e11c8be9..b46d4d7f 100644 --- a/src/scportrait/tools/h5sc/operations.py +++ b/src/scportrait/tools/h5sc/operations.py @@ -28,11 +28,15 @@ ) -def _update_obs_on_disk(adata: AnnData) -> None: +def update_obs_on_disk(adata: AnnData) -> None: """ Temporarily close the HDF5 handle from a read-only AnnData, overwrite .obs on disk, then reopen it and restore the image dataset. + + Args: + adata: AnnData object whose .obs will replace the existing one. """ + # 1. Get the open HDF5 file handle file_handle = adata.uns.get("_h5sc_file_handle", None) @@ -58,16 +62,6 @@ def _update_obs_on_disk(adata: AnnData) -> None: adata.uns["_h5sc_file_handle"] = f -def update_obs_on_disk(adata: AnnData) -> None: - """ - Overwrite the .obs table in an existing .h5sc file on disk. - - Args: - adata: AnnData object whose .obs will replace the existing one. - """ - _update_obs_on_disk(adata) - - def get_cell_id_index(adata, cell_id: int | list[int]) -> int | list[int]: """ Retrieve the index (row index) of a specific cell id in a H5SC object. From f6d636a4dfbc4bad646419d19621390ded279b1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sat, 8 Nov 2025 18:07:02 +0100 Subject: [PATCH 10/23] [FEATURE] add functions to subset h5sc objects --- src/scportrait/tools/h5sc/__init__.py | 18 +++- src/scportrait/tools/h5sc/operations.py | 132 ++++++++++++++++++++++-- 2 files changed, 141 insertions(+), 9 deletions(-) diff --git a/src/scportrait/tools/h5sc/__init__.py b/src/scportrait/tools/h5sc/__init__.py index 76490ca6..de8e4323 100644 --- a/src/scportrait/tools/h5sc/__init__.py +++ b/src/scportrait/tools/h5sc/__init__.py @@ -5,6 +5,20 @@ Functions to work with scPortrait's standardized single-cell data format. """ -from .operations import add_spatial_coordinates, get_cell_id_index, get_image_with_cellid, update_obs_on_disk +from .operations import ( + add_spatial_coordinates, + get_cell_id_index, + get_image_with_cellid, + subset_cells_region, + subset_h5sc, + update_obs_on_disk, +) -__all__ = ["update_obs_on_disk", "get_image_with_cellid", "get_cell_id_index", "add_spatial_coordinates"] +__all__ = [ + "update_obs_on_disk", + "get_image_with_cellid", + "get_cell_id_index", + "add_spatial_coordinates", + "subset_h5sc", + "subset_cells_region", +] diff --git a/src/scportrait/tools/h5sc/operations.py b/src/scportrait/tools/h5sc/operations.py index b46d4d7f..c5bd973d 100644 --- a/src/scportrait/tools/h5sc/operations.py +++ b/src/scportrait/tools/h5sc/operations.py @@ -8,18 +8,24 @@ from __future__ import annotations from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dask.dataframe.core import DataFrame as da_DataFrame + from spatialdata import SpatialData + +import os +import shutil +from pathlib import Path from warnings import warn import dask.array as da +import geopandas as gpd import h5py import numpy as np +from anndata import AnnData +from shapely.geometry import Point from scportrait.io.h5sc import read_h5sc - -if TYPE_CHECKING: - from anndata import AnnData - from dask.dataframe.core import DataFrame as da_DataFrame - from scportrait.pipeline._utils.constants import ( DEFAULT_CELL_ID_NAME, DEFAULT_IDENTIFIER_FILENAME, @@ -62,7 +68,7 @@ def update_obs_on_disk(adata: AnnData) -> None: adata.uns["_h5sc_file_handle"] = f -def get_cell_id_index(adata, cell_id: int | list[int]) -> int | list[int]: +def get_cell_id_index(adata: AnnData, cell_id: int | list[int]) -> int | list[int]: """ Retrieve the index (row index) of a specific cell id in a H5SC object. @@ -84,7 +90,9 @@ def get_cell_id_index(adata, cell_id: int | list[int]) -> int | list[int]: return [lookup[_id] for _id in cell_id] -def get_image_with_cellid(adata, cell_id: list[int] | int, select_channel: int | list[int] | None = None) -> np.ndarray: +def get_image_with_cellid( + adata: AnnData, cell_id: list[int] | int, select_channel: int | list[int] | None = None +) -> np.ndarray: """Get single cell images from the cells with the provided cell IDs. Images are returned in the order of the cell IDs. Args: @@ -150,3 +158,113 @@ def add_spatial_coordinates( if update_on_disk: update_obs_on_disk(adata) + + +def subset_cells_region( + adata: AnnData, + sdata: SpatialData, + region_name: str, + outpath: str | Path = None, + within_region: bool = True, + to_disk: bool = True, + return_anndata: bool = True, +) -> AnnData | None: + """ + Subset cells in the specified region. + + Args: + adata: AnnData object containing the cell data. + sdata: SpatialData object containing the region geometry. + region_name: Name of the region to subset cells from. + outpath: Path to save the subsetted AnnData object. If None, the subsetted file is saved in the same directory as the original h5sc file with a prefix "subset_{select_region}". + within_region: If True, select cells within the region. If False, select cells outside the region. + to_disk: If True, save the subsetted AnnData object to disk. If False, return the subsetted AnnData object in memory. + return_anndata: If True, return a memory mapped version of the subsetted AnnData object. + + Returns: + If `to_disk` is False, returns the subsetted AnnData object. If `to_disk` is True, saves the subsetted AnnData object to disk and returns None. + """ + if outpath is not None: + if not isinstance(outpath, (str | Path)): + raise ValueError("outpath must be a string or Path object.") + assert to_disk, "outpath is only used if to_disk is True." + + if region_name not in sdata: + raise ValueError(f"Region '{region_name}' not found in spatialdata object.") + + xs, ys = adata.obs.get(["x", "y"]).values.T + points = gpd.GeoSeries([Point(xi, yi) for xi, yi in zip(xs, ys, strict=True)]) + is_inside = points.apply(lambda p: sdata[region_name].geometry.contains(p).any()).values + + if not within_region: + selection = ~is_inside + key = "outside" + else: + selection = is_inside + key = "within" + + if not to_disk: + return adata[selection] + else: + cell_ids = adata.obs.loc[selection, DEFAULT_CELL_ID_NAME].values + + if outpath is None: + outpath = adata.uns["h5sc_source_path"].replace("single_cells.h5sc", f"subset_{key}_{region_name}.h5sc") + subset_h5sc(adata, cell_ids, outpath=outpath) + + if return_anndata: + return read_h5sc(outpath) + else: + return None + + +def subset_h5sc(adata: AnnData, cell_id: int | list[int], outpath: str | Path) -> None: + """ + Write a subset of the AnnData object to disk based on the provided cell IDs. + + Args: + adata: AnnData object containing the single-cell data. + cell_id: A single cell ID or a list of cell IDs to subset the AnnData + outpath: Path to save the subsetted AnnData object. + + Returns: + None. The AnnData object is written to disk at the specified outpath. + """ + idx = get_cell_id_index(adata, cell_id) + + if isinstance(idx, int): + idx = [idx] # Ensure idx is always a list + + obs = adata.obs.iloc[idx, :].copy() + obs.reset_index(drop=True, inplace=True) + obs.index = obs.index.astype(str) # Ensure index is string type for consistency + var = adata.var.copy() + uns = {DEFAULT_NAME_SINGLE_CELL_IMAGES: adata.uns[DEFAULT_NAME_SINGLE_CELL_IMAGES]} + + adata_subset = AnnData(obs=obs, var=var, uns=uns) + + if os.path.exists(outpath): + shutil.rmtree(outpath, ignore_errors=True) + adata_subset.write_h5ad(outpath) + + # initialize the obsm with the single cell images + orig = adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] + single_cell_data_shape = (len(idx),) + orig.shape[1:] + with h5py.File(outpath, "a") as hf: + hf.create_dataset( + IMAGE_DATACONTAINER_NAME, + shape=single_cell_data_shape, + chunks=orig.chunks, + compression=orig.compression, + dtype=orig.dtype, + ) + for key, value in orig.attrs.items(): + hf[IMAGE_DATACONTAINER_NAME].attrs[key] = value + + # transfer the images + for i, ix in enumerate(idx): + hf[IMAGE_DATACONTAINER_NAME][i] = orig[ix] + hf.close() + + print(f"Subsetted AnnData object saved to {outpath}.") + return None From 7efb1ffb9444a91e9116c5ed2d80d860cedd9a0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sat, 8 Nov 2025 18:11:21 +0100 Subject: [PATCH 11/23] ruff linting and fix pre-commit issues --- .../pipeline/mask_filtering/region_masking.py | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/scportrait/pipeline/mask_filtering/region_masking.py b/src/scportrait/pipeline/mask_filtering/region_masking.py index fb2a9c69..41e0779a 100644 --- a/src/scportrait/pipeline/mask_filtering/region_masking.py +++ b/src/scportrait/pipeline/mask_filtering/region_masking.py @@ -1,19 +1,19 @@ +import dask +import numpy as np +import rasterio import spatialdata as sd from napari_spatialdata import Interactive -from shapely.geometry import mapping from rasterio.features import geometry_mask -import rasterio -import dask -from spatialdata.models import Image2DModel -import numpy as np from scipy import ndimage -from skimage.measure import find_contours -from shapely.geometry import Polygon from shapely import unary_union -from skimage.segmentation import watershed +from shapely.geometry import Polygon, mapping from skimage.draw import disk +from skimage.measure import find_contours +from skimage.segmentation import watershed +from spatialdata.models import Image2DModel + -def mask_image(sdata, image, mask, invert, automatic_masking, threshold, overwrite, masked_image_name) +def mask_image(sdata, image, mask, invert, automatic_masking, threshold, overwrite, masked_image_name): """ Given an image and mask, either masks or crops the image. @@ -51,14 +51,9 @@ def mask_image(sdata, image, mask, invert, automatic_masking, threshold, overwri polygon_geom = [mapping(polygon)] - transform = rasterio.transform.Affine(1, 0, 0, 0, 1, 0) # identity transform + transform = rasterio.transform.Affine(1, 0, 0, 0, 1, 0) # identity transform - image_mask = geometry_mask( - polygon_geom, - invert=invert, - out_shape=(height, width), - transform=transform - ) + image_mask = geometry_mask(polygon_geom, invert=invert, out_shape=(height, width), transform=transform) if channels > 1: image_mask = dask.array.broadcast_to(image_mask, (channels, height, width)) @@ -74,6 +69,7 @@ def mask_image(sdata, image, mask, invert, automatic_masking, threshold, overwri masked_image_name = f"{image}_masked" sdata.images[masked_image_name] = images["masked_image"] + def _draw_polygons(image, threshold): """ Given an image, detect regions to turn into polygon shapes, which are then used as a mask. @@ -101,6 +97,6 @@ def _draw_polygons(image, threshold): contours = find_contours(segmented, level=0.5) polygons = [Polygon(contour) for contour in contours if len(contour) > 2] - polygon = unary_union(polygons) + polygon = unary_union(polygons) - return polygon \ No newline at end of file + return polygon From ba50312379e17c8f288d829348a696d6bf690214 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sat, 8 Nov 2025 18:50:18 +0100 Subject: [PATCH 12/23] add general synthetic test objects that closely mimic real objects --- tests/conftest.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..4379f509 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,39 @@ +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData + +rng = np.random.default_rng() + + +@pytest.fixture +def h5sc_object(): + # Two cells, two channels, small images + n_cells = 2 + channel_names = np.array(["seg_all_nucleus", "ch0", "ch1"]) + channel_mapping = np.array(["mask", "image", "image"]) # or whatever mapping your code expects + n_channels = len(channel_names) + H, W = 10, 10 + + # --- obs --- + obs = pd.DataFrame({"scportrait_cell_id": [101, 102]}, index=[0, 1]) + + # --- var (channel metadata) --- + var = pd.DataFrame(index=np.arange(n_channels).astype("str")) + var["channels"] = channel_names + var["channel_mapping"] = channel_mapping + + adata = AnnData(obs=obs, var=var) + adata.obsm["single_cell_images"] = rng.random((n_cells, n_channels, H, W)) + adata.uns["single_cell_images"] = { + "channel_mapping": channel_mapping, + "channel_names": channel_names, + "compression": "lzf", + "image_size": np.int64(H), + "n_cells": np.int64(n_cells), + "n_channels": np.int64(n_channels), + "n_image_channels": np.int64(n_channels - 1), + "n_masks": np.int64(1), + } + + return adata From 471a8e64cb0749773f4ed808034cf21a6fc68a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sat, 8 Nov 2025 18:50:34 +0100 Subject: [PATCH 13/23] [TEST] improve h5sc plotting tests --- tests/unit_tests/plotting/test_h5sc.py | 52 +++++--------------------- 1 file changed, 10 insertions(+), 42 deletions(-) diff --git a/tests/unit_tests/plotting/test_h5sc.py b/tests/unit_tests/plotting/test_h5sc.py index bc92298a..12e73a1a 100644 --- a/tests/unit_tests/plotting/test_h5sc.py +++ b/tests/unit_tests/plotting/test_h5sc.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock, patch - import matplotlib.pyplot as plt import numpy as np import pytest @@ -16,9 +14,8 @@ cell_grid_single_channel, ) -# ---------- _reshape_image_array ---------- - +# ---------- _reshape_image_array ---------- @pytest.mark.parametrize( "input_shape, expected_shape", [ @@ -34,8 +31,6 @@ def test_reshape_image_array(input_shape, expected_shape): # ---------- _plot_image_grid ---------- - - @pytest.mark.parametrize( "input_shape, nrows, ncols, col_labels, col_labels_rotation", [ @@ -71,41 +66,22 @@ def test_plot_image_grid(input_shape, nrows, ncols, col_labels, col_labels_rotat # ---------- cell_grid_single_channel ---------- - - -@patch("scportrait.plotting.h5sc.get_image_with_cellid") -def test_cell_grid_single_channel_returns_figure_with_title_rotation(mock_get_img): - mock_adata = MagicMock() - mock_adata.uns = {"single_cell_images": {"channel_names": np.array(["ch0", "ch1"])}} - mock_adata.obs = MagicMock() - mock_adata.obs.__getitem__.return_value.sample.return_value.values = [101, 102] - mock_get_img.return_value = rng.random((2, 10, 10)) - +def test_cell_grid_single_channel_returns_figure_with_title_rotation(h5sc_object): fig = cell_grid_single_channel( - adata=mock_adata, - select_channel=0, + adata=h5sc_object, + select_channel="ch0", n_cells=2, return_fig=True, show_fig=False, title_rotation=30, # new parameter ) assert isinstance(fig, Figure) - # You could further check that fig.axes[0].get_title() is set, but rotation may not be directly visible # ---------- cell_grid_multi_channel ---------- - - -@patch("scportrait.plotting.h5sc.get_image_with_cellid") -def test_cell_grid_multi_channel_returns_figure_with_channel_label_rotation(mock_get_img): - mock_adata = MagicMock() - mock_adata.uns = {"single_cell_images": {"channel_names": ["ch0", "ch1"], "n_channels": 2}} - mock_adata.obs = MagicMock() - mock_adata.obs.__getitem__.return_value.sample.return_value.values = [101, 102] - mock_get_img.return_value = rng.random((2, 2, 10, 10)) - +def test_cell_grid_multi_channel_returns_figure_with_channel_label_rotation(h5sc_object): fig = cell_grid_multi_channel( - adata=mock_adata, + adata=h5sc_object, n_cells=2, return_fig=True, show_fig=False, @@ -116,17 +92,9 @@ def test_cell_grid_multi_channel_returns_figure_with_channel_label_rotation(mock # ---------- cell_grid ---------- +def test_cell_grid_dispatches_to_single_channel(h5sc_object): + cell_grid(adata=h5sc_object, select_channel="ch1", n_cells=1, show_fig=False) -@patch("scportrait.plotting.h5sc.cell_grid_single_channel") -def test_cell_grid_dispatches_to_single_channel(mock_single): - mock_adata = MagicMock() - cell_grid(adata=mock_adata, select_channel=1, n_cells=1, show_fig=False) - assert mock_single.called - - -@patch("scportrait.plotting.h5sc.cell_grid_multi_channel") -def test_cell_grid_dispatches_to_multi_channel(mock_multi): - mock_adata = MagicMock() - cell_grid(adata=mock_adata, select_channel=[0, 1], n_cells=1, show_fig=False) - assert mock_multi.called +def test_cell_grid_dispatches_to_multi_channel(h5sc_object): + cell_grid(adata=h5sc_object, select_channel=["ch0", "ch1"], n_cells=1, show_fig=False) From 04ca7e8764ba1bcaba2f17918fafa502deb3ea02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sat, 8 Nov 2025 19:36:27 +0100 Subject: [PATCH 14/23] [TESTS] add tests for h5sc operations doesn't yet cover all functions --- tests/conftest.py | 23 ++++++-- .../unit_tests/tools/h5sc/test_operations.py | 56 +++++++++++++++++++ 2 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 tests/unit_tests/tools/h5sc/test_operations.py diff --git a/tests/conftest.py b/tests/conftest.py index 4379f509..fc2b1e80 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,22 +1,27 @@ +import shutil + import numpy as np import pandas as pd import pytest from anndata import AnnData +from spatialdata import SpatialData +from spatialdata.datasets import blobs rng = np.random.default_rng() @pytest.fixture -def h5sc_object(): +def h5sc_object() -> AnnData: # Two cells, two channels, small images - n_cells = 2 + cell_ids = [101, 102, 107, 109] + n_cells = 4 channel_names = np.array(["seg_all_nucleus", "ch0", "ch1"]) channel_mapping = np.array(["mask", "image", "image"]) # or whatever mapping your code expects n_channels = len(channel_names) H, W = 10, 10 # --- obs --- - obs = pd.DataFrame({"scportrait_cell_id": [101, 102]}, index=[0, 1]) + obs = pd.DataFrame({"scportrait_cell_id": cell_ids}, index=np.arange(n_cells)) # --- var (channel metadata) --- var = pd.DataFrame(index=np.arange(n_channels).astype("str")) @@ -36,4 +41,14 @@ def h5sc_object(): "n_masks": np.int64(1), } - return adata + yield adata + + +@pytest.fixture() +def sdata(tmp_path) -> SpatialData: + sdata = blobs() + # Write to temporary location + sdata_path = tmp_path / "sdata.zarr" + sdata.write(sdata_path) + yield sdata + shutil.rmtree(sdata_path) diff --git a/tests/unit_tests/tools/h5sc/test_operations.py b/tests/unit_tests/tools/h5sc/test_operations.py new file mode 100644 index 00000000..f0966307 --- /dev/null +++ b/tests/unit_tests/tools/h5sc/test_operations.py @@ -0,0 +1,56 @@ +# tests/test_operations.py + +from pathlib import Path + +import anndata as ad +import h5py +import numpy as np +import pandas as pd +import pytest + +from scportrait.io import read_h5sc +from scportrait.tl.h5sc import ( + add_spatial_coordinates, + get_cell_id_index, + subset_cells_region, + subset_h5sc, + update_obs_on_disk, +) + +rng = np.random.default_rng() + + +def test_update_obs_on_disk(h5sc_object, tmp_path): + # Write h5ad + p = tmp_path / "test.h5ad" + h5sc_object.write(p) + + h5sc_object.uns["h5sc_source_path"] = str(p) + size = h5sc_object.obs.shape[0] + + # Modify obs + random_values = rng.integers(1, 10, size=size) + h5sc_object.obs["new_col"] = random_values + update_obs_on_disk(h5sc_object) + + # Reload and confirm updated + reloaded = read_h5sc(p) + assert "new_col" in reloaded.obs.columns + assert np.all(reloaded.obs["new_col"] == random_values) + + +def test_get_cell_id_index_single(h5sc_object): + idx = get_cell_id_index(h5sc_object, 107) + assert idx == 2 + + +def test_get_cell_id_index_list(h5sc_object): + idx = get_cell_id_index(h5sc_object, [101, 109]) + assert idx == [0, 3] + + +def test_subset_h5sc(h5sc_object, tmp_path): + out = tmp_path / "subset.h5sc" + subset_h5sc(h5sc_object, [101, 102], out) + + assert out.exists() From 9823f94bef0f99df7105b9c7b6137b0738232bfe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 4 Jan 2026 10:40:53 +0000 Subject: [PATCH 15/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scportrait/tools/h5sc/operations.py | 6 +++--- tests/unit_tests/tools/h5sc/test_operations.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/scportrait/tools/h5sc/operations.py b/src/scportrait/tools/h5sc/operations.py index c5bd973d..4d73dcc7 100644 --- a/src/scportrait/tools/h5sc/operations.py +++ b/src/scportrait/tools/h5sc/operations.py @@ -143,9 +143,9 @@ def add_spatial_coordinates( """ assert cell_id_identifier in adata.obs.columns, f"{cell_id_identifier} must be a column in h5sc.obs" - assert ( - ["x", "y"] == list(centers_object.columns) - ), "centers_object must be scportrait's standardized centers object containing columns 'x' and 'y' and the scportrait cell id as index, but detected columns are {centers_object.columns}" + assert ["x", "y"] == list(centers_object.columns), ( + "centers_object must be scportrait's standardized centers object containing columns 'x' and 'y' and the scportrait cell id as index, but detected columns are {centers_object.columns}" + ) if ("x" in adata.obs.columns) or ("y" in adata.obs.columns): adata.obs.drop(columns=["x", "y"], inplace=True, errors="ignore") diff --git a/tests/unit_tests/tools/h5sc/test_operations.py b/tests/unit_tests/tools/h5sc/test_operations.py index f0966307..af628f78 100644 --- a/tests/unit_tests/tools/h5sc/test_operations.py +++ b/tests/unit_tests/tools/h5sc/test_operations.py @@ -7,8 +7,6 @@ import numpy as np import pandas as pd import pytest - -from scportrait.io import read_h5sc from scportrait.tl.h5sc import ( add_spatial_coordinates, get_cell_id_index, @@ -17,6 +15,8 @@ update_obs_on_disk, ) +from scportrait.io import read_h5sc + rng = np.random.default_rng() From d3f488e27df3d62e6a3a50b3cbba0f5b63af5187 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sun, 18 Jan 2026 17:31:12 +0100 Subject: [PATCH 16/23] [FIX] incorrect warning condition --- src/scportrait/tools/sdata/write/_write.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scportrait/tools/sdata/write/_write.py b/src/scportrait/tools/sdata/write/_write.py index 5d5f21f7..76049461 100644 --- a/src/scportrait/tools/sdata/write/_write.py +++ b/src/scportrait/tools/sdata/write/_write.py @@ -79,7 +79,7 @@ def image( f"Number of channel names ({len(channel_names)}) does not match the number of channels in the image ({image.shape[0]})." ) channel_names_old = image.coords["c"].values.tolist() - if channel_names_old != channel_names: + if any(channel_names_old != channel_names): warnings.warn( f"Channel names in the DataArray ({channel_names_old}) do not match the provided channel names ({channel_names}). The DataArray will be updated with the provided channel names.", stacklevel=2, From e746ca5d2e29e7ced4cd1dc467850b2fddfe75d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sun, 18 Jan 2026 18:14:07 +0100 Subject: [PATCH 17/23] [FIX] docstrings and import statements --- src/scportrait/tools/sdata/processing/_subset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/scportrait/tools/sdata/processing/_subset.py b/src/scportrait/tools/sdata/processing/_subset.py index f489baa7..2f68b304 100644 --- a/src/scportrait/tools/sdata/processing/_subset.py +++ b/src/scportrait/tools/sdata/processing/_subset.py @@ -1,12 +1,12 @@ import copy import warnings -import spatialdata +from spatialdata import SpatialData def get_bounding_box_sdata( - sdata: spatialdata, max_width: int, center_x: int, center_y: int, drop_points: bool = True -) -> spatialdata: + sdata: SpatialData, max_width: int, center_x: int, center_y: int, drop_points: bool = True +) -> SpatialData: """apply bounding box to sdata object Args: @@ -16,7 +16,7 @@ def get_bounding_box_sdata( center_y: y coordinate of the center of the bounding box Returns: - spatialdata: spatialdata object with bounding box applied + spatialdata object with bounding box applied """ _sdata = sdata # remove points object to improve subsetting @@ -54,7 +54,7 @@ def get_bounding_box_sdata( if drop_points: # re-add points object - __sdata = spatialdata.SpatialData.read(sdata.path, selection=["points"]) + __sdata = SpatialData.read(sdata.path, selection=["points"]) for x in points_keys: sdata[x] = __sdata[x] del __sdata From c3ab5b5570f2e3787f8da0f9c2e95deb5464274e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sun, 18 Jan 2026 18:16:18 +0100 Subject: [PATCH 18/23] [FEATURE] first working version of the mask_region function --- .../tools/sdata/processing/__init__.py | 4 +- .../tools/sdata/processing/_subset.py | 94 ++++++++++++++++++- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/src/scportrait/tools/sdata/processing/__init__.py b/src/scportrait/tools/sdata/processing/__init__.py index 342d45d1..e7d2212b 100644 --- a/src/scportrait/tools/sdata/processing/__init__.py +++ b/src/scportrait/tools/sdata/processing/__init__.py @@ -1,4 +1,4 @@ from ._image_processing import percentile_normalize_image -from ._subset import get_bounding_box_sdata +from ._subset import get_bounding_box_sdata, mask_region -__all__ = ["percentile_normalize_image", "get_bounding_box_sdata"] +__all__ = ["percentile_normalize_image", "get_bounding_box_sdata", "mask_region"] diff --git a/src/scportrait/tools/sdata/processing/_subset.py b/src/scportrait/tools/sdata/processing/_subset.py index 2f68b304..5ca98c54 100644 --- a/src/scportrait/tools/sdata/processing/_subset.py +++ b/src/scportrait/tools/sdata/processing/_subset.py @@ -1,6 +1,11 @@ -import copy import warnings +import dask.array as da +import numpy as np +import xarray as xr +from affine import Affine +from rasterio.features import rasterize +from shapely.geometry import mapping from spatialdata import SpatialData @@ -60,3 +65,90 @@ def get_bounding_box_sdata( del __sdata return _sdata + + +def mask_region( + sdata: SpatialData, + image_name: str = "input_image", + shape_name: str = "select_region", + mask: bool = True, + crop: bool = False, +) -> xr.DataArray: + """Mask and/or crop the input image to the selected region. + + Args: + sdata: SpatialData object containing the image and shape. + image_name: Name of the image to be masked/cropped. + shape_name: Name of the shape to mask/crop the image with. + mask: Whether to apply the mask to the image. Default is True. + crop: Whether to crop the image to the outer bounding box of the shape. Default is False. + Returns: + masked/cropped image as a DataArray. If crop is False, the image has the same dimensions as the input image, otherwise it has the dimensions of the outer bounding box of the shape. + """ + assert mask or crop, "Either mask or crop must be True" + + # get image and check for proper scaling + if image_name not in sdata: + raise ValueError(f"Image {image_name} not found in sdata") + image = sdata[image_name] + + if isinstance(image, xr.DataTree): + image = image.get("scale0").image + + elif isinstance(image, xr.DataArray): + image = image + + print(image.dtype) + + # get shape and check for single-shape selection + shape = sdata[shape_name].geometry + if len(shape) == 1: + shape = shape[0] + elif len(shape) > 1: + raise ValueError("Expected a single shape, but found multiple shapes. Please select only one region.") + else: + raise ValueError("No shapes found in the specified region.") + + # initialize empty array + H, W = image.sizes["y"], image.sizes["x"] + chunks_yx = (image.data.chunks[image.get_axis_num("y")], image.data.chunks[image.get_axis_num("x")]) + template = da.zeros((H, W), chunks=chunks_yx, dtype=np.uint16) + + def _mask_block(block, block_info=None): + info = block_info[None] + (y0, y1), (x0, x1) = info["array-location"][:2] + h, w = (y1 - y0), (x1 - x0) + + # shift transform to this block’s window + window_transform = Affine.translation(x0, y0) + + m = rasterize( + [(geom, 1)], + out_shape=(h, w), + transform=window_transform, + fill=0, + dtype=image.dtype, + all_touched=True, # set True if you want any touched pixel included + ) + return m.astype(bool) + + geom = mapping(shape) + mask_dask = da.map_blocks(_mask_block, template, dtype=bool) + mask_da = xr.DataArray(mask_dask, dims=("y", "x"), coords={"y": image.coords["y"], "x": image.coords["x"]}) + + other = np.array(0, dtype=image.dtype) + if mask: + if "c" in image.dims: + m = mask_da.broadcast_like(image.isel(c=0)) + masked = image.where(m, other=other) + else: + masked = image.where(mask_da, other=other) + else: + masked = image + + if crop: + minx, miny, maxx, maxy = shape.bounds + minx, miny, maxx, maxy = int(np.floor(minx)), int(np.floor(miny)), int(np.ceil(maxx)), int(np.ceil(maxy)) + return masked.isel(x=slice(minx, maxx), y=slice(miny, maxy)) + else: + return masked From 6c35a4234cc06b3eade97e9fb6822e703913c86d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sun, 18 Jan 2026 18:36:49 +0100 Subject: [PATCH 19/23] [TEST] add tests for _subset module --- .../tools/sdata/processing/_subset.py | 4 +- tests/conftest.py | 31 ++++ tests/unit_tests/tools/sdata/test_subset.py | 140 ++++++++++++++++++ 3 files changed, 173 insertions(+), 2 deletions(-) create mode 100644 tests/unit_tests/tools/sdata/test_subset.py diff --git a/src/scportrait/tools/sdata/processing/_subset.py b/src/scportrait/tools/sdata/processing/_subset.py index 5ca98c54..ec8399d9 100644 --- a/src/scportrait/tools/sdata/processing/_subset.py +++ b/src/scportrait/tools/sdata/processing/_subset.py @@ -15,13 +15,13 @@ def get_bounding_box_sdata( """apply bounding box to sdata object Args: - sdata: spatialdata object + sdata: SpatialData object max_width: maximum width of the bounding box center_x: x coordinate of the center of the bounding box center_y: y coordinate of the center of the bounding box Returns: - spatialdata object with bounding box applied + SpatialData object with bounding box applied """ _sdata = sdata # remove points object to improve subsetting diff --git a/tests/conftest.py b/tests/conftest.py index 14a803d1..3184fd0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,14 @@ import shutil +import geopandas as gpd import numpy as np import pandas as pd import pytest from anndata import AnnData +from shapely.geometry import box from spatialdata import SpatialData from spatialdata.datasets import blobs +from spatialdata.models import Image2DModel, ShapesModel rng = np.random.default_rng() @@ -60,3 +63,31 @@ def sdata_with_labels() -> SpatialData: sdata["table"].obs["labelling_categorical"] = sdata["table"].obs["instance_id"].astype("category") sdata["table"].obs["labelling_continous"] = (sdata["table"].obs["instance_id"] > 10).astype(float) return sdata + + +@pytest.fixture +def sdata_with_selected_region(): + image = np.ones((1, 10, 10), dtype=np.uint16) + shape = box(2, 3, 7, 8) + image_model = Image2DModel.parse(image, dims=("c", "y", "x")) + shapes_gdf = gpd.GeoDataFrame({"geometry": [shape]}) + shapes_model = ShapesModel.parse(shapes_gdf) + sdata = SpatialData(images={"input_image": image_model}, shapes={"select_region": shapes_model}) + return sdata + + +@pytest.fixture +def sdata_builder(): + def _build( + image, + shapes, + image_name="input_image", + shape_name="select_region", + dims=("c", "y", "x"), + ): + image_model = Image2DModel.parse(image, dims=dims) + shapes_gdf = gpd.GeoDataFrame({"geometry": shapes}) + shapes_model = ShapesModel.parse(shapes_gdf) + return SpatialData(images={image_name: image_model}, shapes={shape_name: shapes_model}) + + return _build diff --git a/tests/unit_tests/tools/sdata/test_subset.py b/tests/unit_tests/tools/sdata/test_subset.py new file mode 100644 index 00000000..0343b208 --- /dev/null +++ b/tests/unit_tests/tools/sdata/test_subset.py @@ -0,0 +1,140 @@ +import numpy as np +import pytest +from shapely.geometry import box + +from scportrait.tools.sdata.processing._subset import get_bounding_box_sdata, mask_region +from scportrait.tools.sdata.write._helper import _get_image + + +def _as_numpy(image): + data = image.data + if hasattr(data, "compute"): + data = data.compute() + return np.asarray(data) + + +def _expected_mask(masked, shape): + x_coords = masked.coords["x"].values + y_coords = masked.coords["y"].values + x_mask = (x_coords >= shape.bounds[0]) & (x_coords <= shape.bounds[2]) + y_mask = (y_coords >= shape.bounds[1]) & (y_coords <= shape.bounds[3]) + return np.outer(y_mask, x_mask).astype(bool) + + +@pytest.mark.parametrize( + "mask,crop", + [ + (True, False), + (False, True), + (True, True), + ], +) +def test_mask_region_mask_crop_combinations(sdata_with_selected_region, mask, crop): + image = _as_numpy(sdata_with_selected_region.images["input_image"]) + shape = sdata_with_selected_region["select_region"].geometry[0] + + result = mask_region( + sdata_with_selected_region, image_name="input_image", shape_name="select_region", mask=mask, crop=crop + ) + result_np = _as_numpy(result) + + minx, miny, maxx, maxy = shape.bounds + x0, y0 = int(np.floor(minx)), int(np.floor(miny)) + x1, y1 = int(np.ceil(maxx)), int(np.ceil(maxy)) + + if crop: + expected = image[:, y0:y1, x0:x1] + assert result_np.shape == expected.shape + if mask: + assert result_np.min() >= 0 + assert result_np.max() == expected.max() + else: + np.testing.assert_array_equal(result_np, expected) + else: + expected_mask = _expected_mask(result, shape) + values = result_np[0] + assert values[expected_mask].min() == 1 + assert values[~expected_mask].max() == 0 + + +def test_mask_region_requires_mask_or_crop(sdata_with_selected_region): + with pytest.raises(AssertionError): + mask_region( + sdata_with_selected_region, image_name="input_image", shape_name="select_region", mask=False, crop=False + ) + + +def test_mask_region_missing_image_raises(sdata_with_selected_region): + with pytest.raises(ValueError): + mask_region(sdata_with_selected_region, image_name="missing", shape_name="select_region", mask=True, crop=False) + + +def test_mask_region_missing_shape_raises(sdata_with_selected_region): + with pytest.raises(KeyError): + mask_region(sdata_with_selected_region, image_name="input_image", shape_name="missing", mask=True, crop=False) + + +def test_mask_region_multiple_shapes_raises(sdata_builder): + image = np.ones((1, 6, 6), dtype=np.uint16) + shapes = [box(1, 1, 3, 3), box(2, 2, 4, 4)] + sdata = sdata_builder(image, shapes) + + with pytest.raises(ValueError): + mask_region(sdata, image_name="input_image", shape_name="select_region", mask=True, crop=False) + + +@pytest.mark.parametrize( + "max_width,center_x,center_y", + [ + (4, 5, 5), + (6, 2, 8), + ], +) +def test_get_bounding_box_sdata_reduces_extent(sdata_builder, max_width, center_x, center_y): + image = np.ones((1, 10, 10), dtype=np.uint16) + shapes = [box(1, 1, 3, 3)] + sdata = sdata_builder(image, shapes) + + subset = get_bounding_box_sdata( + sdata, + max_width=max_width, + center_x=center_x, + center_y=center_y, + drop_points=False, + ) + + image_subset = _get_image(subset.images["input_image"]) + assert image_subset.sizes["x"] <= 10 + assert image_subset.sizes["y"] <= 10 + assert image_subset.sizes["x"] > 0 + assert image_subset.sizes["y"] > 0 + assert image_subset.sizes["x"] <= max_width + assert image_subset.sizes["y"] <= max_width + + +@pytest.mark.parametrize( + "center_x,center_y", + [ + (0, 0), + (-5, -5), + ], +) +def test_get_bounding_box_sdata_clamps_edges(sdata_builder, center_x, center_y): + image = np.ones((1, 10, 10), dtype=np.uint16) + shapes = [box(1, 1, 3, 3)] + sdata = sdata_builder(image, shapes) + + subset = get_bounding_box_sdata( + sdata, + max_width=4, + center_x=center_x, + center_y=center_y, + drop_points=False, + ) + + image_subset = _get_image(subset.images["input_image"]) + x_coords = image_subset.coords["x"].values + y_coords = image_subset.coords["y"].values + + assert x_coords.min() >= sdata.images["input_image"].coords["x"].values.min() + assert y_coords.min() >= sdata.images["input_image"].coords["y"].values.min() From f95f48fbd733ef6ec3d77d76b5f1421958c76391 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sun, 18 Jan 2026 18:42:51 +0100 Subject: [PATCH 20/23] [FIX] update requirements to support new mask region function --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 8155ff6c..0071cab3 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -7,6 +7,7 @@ anndata<0.12 spatialdata>=0.3.0,<0.6 pyarrow<22.0.0 py-lmd>=1.3.1 +affine spatialdata-plot<=0.2.11 matplotlib From 15cc8ae061972c6100f5c5a8fcccb6d0c9e509f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sun, 18 Jan 2026 18:54:28 +0100 Subject: [PATCH 21/23] [FIX] add missing requirement --- requirements/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0071cab3..f200a5f3 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -8,6 +8,7 @@ spatialdata>=0.3.0,<0.6 pyarrow<22.0.0 py-lmd>=1.3.1 affine +rasterio spatialdata-plot<=0.2.11 matplotlib From 3b85f155d923b5f98a90d1ebbb3cc8a1d2d5fa70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com> Date: Sun, 18 Jan 2026 19:01:21 +0100 Subject: [PATCH 22/23] [FIX] also support numpy obsm values not just hdf5 backed --- src/scportrait/tools/h5sc/operations.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/scportrait/tools/h5sc/operations.py b/src/scportrait/tools/h5sc/operations.py index 4d73dcc7..09a1e1d8 100644 --- a/src/scportrait/tools/h5sc/operations.py +++ b/src/scportrait/tools/h5sc/operations.py @@ -250,16 +250,19 @@ def subset_h5sc(adata: AnnData, cell_id: int | list[int], outpath: str | Path) - # initialize the obsm with the single cell images orig = adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] single_cell_data_shape = (len(idx),) + orig.shape[1:] + chunks = orig.chunks if hasattr(orig, "chunks") else None + compression = orig.compression if hasattr(orig, "compression") else None with h5py.File(outpath, "a") as hf: hf.create_dataset( IMAGE_DATACONTAINER_NAME, shape=single_cell_data_shape, - chunks=orig.chunks, - compression=orig.compression, + chunks=chunks, + compression=compression, dtype=orig.dtype, ) - for key, value in orig.attrs.items(): - hf[IMAGE_DATACONTAINER_NAME].attrs[key] = value + if hasattr(orig, "attrs"): + for key, value in orig.attrs.items(): + hf[IMAGE_DATACONTAINER_NAME].attrs[key] = value # transfer the images for i, ix in enumerate(idx): From 8e1b423e8116a1642e9334c17b2eb60be22b9d9e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 10:39:13 +0000 Subject: [PATCH 23/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scportrait/io/h5sc.py | 2 +- tests/conftest.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scportrait/io/h5sc.py b/src/scportrait/io/h5sc.py index 5f3bb9f6..b891b1e1 100644 --- a/src/scportrait/io/h5sc.py +++ b/src/scportrait/io/h5sc.py @@ -11,8 +11,8 @@ from anndata._io.h5ad import _clean_uns, _read_raw, read_dataframe, read_elem from scportrait.pipeline._utils.constants import ( - DEFAULT_IDENTIFIER_FILENAME, DEFAULT_CELL_ID_NAME, + DEFAULT_IDENTIFIER_FILENAME, DEFAULT_NAME_SINGLE_CELL_IMAGES, DEFAULT_SEGMENTATION_DTYPE, DEFAULT_SINGLE_CELL_IMAGE_DTYPE, diff --git a/tests/conftest.py b/tests/conftest.py index 3fb2f7d4..bcce93a7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,8 +7,8 @@ import pandas as pd import pytest from anndata import AnnData -from shapely.geometry import box from matplotlib.figure import Figure +from shapely.geometry import box from spatialdata import SpatialData from spatialdata.datasets import blobs from spatialdata.models import Image2DModel, ShapesModel