diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 9ca16aaf..be4a0fe4 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -7,6 +7,8 @@ anndata<0.12 spatialdata>=0.3.0,<0.6 pyarrow<22.0.0 py-lmd>=1.3.1 +affine +rasterio spatialdata-plot>=0.2.14 matplotlib diff --git a/src/scportrait/io/h5sc.py b/src/scportrait/io/h5sc.py index 26564765..b891b1e1 100644 --- a/src/scportrait/io/h5sc.py +++ b/src/scportrait/io/h5sc.py @@ -12,6 +12,7 @@ from scportrait.pipeline._utils.constants import ( DEFAULT_CELL_ID_NAME, + DEFAULT_IDENTIFIER_FILENAME, DEFAULT_NAME_SINGLE_CELL_IMAGES, DEFAULT_SEGMENTATION_DTYPE, DEFAULT_SINGLE_CELL_IMAGE_DTYPE, @@ -55,6 +56,9 @@ def read_h5sc(filename: str | Path) -> AnnData: _clean_uns(adata) adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] = f.get(IMAGE_DATACONTAINER_NAME) + adata.uns[DEFAULT_IDENTIFIER_FILENAME] = filename + adata.uns["_h5sc_file_handle"] = f + return adata diff --git a/src/scportrait/pipeline/_utils/constants.py b/src/scportrait/pipeline/_utils/constants.py index 01cc7548..ea6330dc 100644 --- a/src/scportrait/pipeline/_utils/constants.py +++ b/src/scportrait/pipeline/_utils/constants.py @@ -35,6 +35,7 @@ IMAGE_DATACONTAINER_NAME = f"obsm/{DEFAULT_NAME_SINGLE_CELL_IMAGES}" DEFAULT_CELL_ID_NAME = "scportrait_cell_id" INDEX_DATACONTAINER_NAME = f"obs/{DEFAULT_CELL_ID_NAME}" +DEFAULT_IDENTIFIER_FILENAME = "h5sc_source_path" DEFAULT_IMAGE_DTYPE: np.dtype = np.uint16 DEFAULT_SEGMENTATION_DTYPE: np.dtype = np.uint64 diff --git a/src/scportrait/pipeline/_utils/spatialdata_helper.py b/src/scportrait/pipeline/_utils/spatialdata_helper.py index 1bfc50db..1ef88a32 100644 --- a/src/scportrait/pipeline/_utils/spatialdata_helper.py +++ b/src/scportrait/pipeline/_utils/spatialdata_helper.py @@ -182,6 +182,7 @@ def calculate_centroids(mask: xarray.DataArray, coordinate_system: str = "global transform = get_transformation(mask, coordinate_system) if check_memory(mask): + print("Array fits in memory, using in-memory calculation.") centers, _, _ids = numba_mask_centroid(mask.values) return make_centers_object(centers, _ids, transform, coordinate_system) diff --git a/src/scportrait/pipeline/mask_filtering/region_masking.py b/src/scportrait/pipeline/mask_filtering/region_masking.py new file mode 100644 index 00000000..41e0779a --- /dev/null +++ b/src/scportrait/pipeline/mask_filtering/region_masking.py @@ -0,0 +1,102 @@ +import dask +import numpy as np +import rasterio +import spatialdata as sd +from napari_spatialdata import Interactive +from rasterio.features import geometry_mask +from scipy import ndimage +from shapely import unary_union +from shapely.geometry import Polygon, mapping +from skimage.draw import disk +from skimage.measure import find_contours +from skimage.segmentation import watershed +from spatialdata.models import Image2DModel + + +def mask_image(sdata, image, mask, invert, automatic_masking, threshold, overwrite, masked_image_name): + """ + Given an image and mask, either masks or crops the image. + + Parameters + ---------- + sdata : sd.SpatialData + spatialdata object containing the image and mask. + image : str + Name of the image in sdata.images to mask. + mask : str | shapely.geometry.Polygon + Mask, either str of the name of the shape in sdata.shapes or a shapely polygon. + invert : bool + If True, inverts the mask, such that only pixels within mask remain, while the rest gets cropped. + automatic_masking : bool + If True, uses threshold + watershed to automatically create a mask based on shapes. Threshold needs to be adjusted manually. + threshold : float + Threshold for pixel intensity values at which to segment image into foreground and background. + overwrite : bool + Whether to overwrite the image in sdata.images. + masked_image_name : None | str + Name of the masked image in sdata.images if overwrite==True. Defaults to f"{image}_masked". + Returns + ------- + sd.SpatialData + spatialdata object with masked image + """ + channels, height, width = sdata.images[image].data.shape + + if automatic_masking: + polygon = _draw_polygons(sdata.images[image].data, threshold) + elif isinstance(mask, str): + polygon = sdata.shapes[mask].iloc[0].geometry + else: + polygon = mask + + polygon_geom = [mapping(polygon)] + + transform = rasterio.transform.Affine(1, 0, 0, 0, 1, 0) # identity transform + + image_mask = geometry_mask(polygon_geom, invert=invert, out_shape=(height, width), transform=transform) + + if channels > 1: + image_mask = dask.array.broadcast_to(image_mask, (channels, height, width)) + + masked_image = sdata.images[image].data * image_mask + images = {} + images["masked_image"] = Image2DModel.parse(masked_image) + + if overwrite: + sdata.images[image] = images["masked_image"] + else: + if masked_image_name is None: + masked_image_name = f"{image}_masked" + sdata.images[masked_image_name] = images["masked_image"] + + +def _draw_polygons(image, threshold): + """ + Given an image, detect regions to turn into polygon shapes, which are then used as a mask. + + Parameters + ---------- + image : np.ndarray + Image to find regions in. + threshold : float + Threshold for pixel intensity values at which to segment image into foreground and background. + Returns + ------- + shapely.geometry.Polygon + Polygon containing the detected regions. + """ + if image.shape[0] == 1: + image = image[0] + binary_image = image > np.percentile(image.flatten(), threshold) + + distance = ndimage.distance_transform_edt(binary_image) + markers, _ = ndimage.label(distance) + + segmented = watershed(-distance, markers, mask=binary_image) + + contours = find_contours(segmented, level=0.5) + + polygons = [Polygon(contour) for contour in contours if len(contour) > 2] + polygon = unary_union(polygons) + + return polygon diff --git a/src/scportrait/pipeline/project.py b/src/scportrait/pipeline/project.py index 288db668..faddee0d 100644 --- a/src/scportrait/pipeline/project.py +++ b/src/scportrait/pipeline/project.py @@ -123,6 +123,9 @@ class Project(Logable): DEFAULT_SINGLE_CELL_IMAGE_DTYPE = DEFAULT_SINGLE_CELL_IMAGE_DTYPE DEFAULT_CELL_ID_NAME = DEFAULT_CELL_ID_NAME + _h5sc_handle = None + _h5sc_adata = None + PALETTE = [ "blue", "green", @@ -232,6 +235,15 @@ def __exit__(self): def __del__(self): self._clear_temp_dir() + def __getstate__(self): + state = self.__dict__.copy() + state["_h5sc_handle"] = None # ensure closed before pickling + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._h5sc_handle = None # will be reopened lazily + @property def sdata_path(self) -> str: return self._get_sdata_path() @@ -243,14 +255,22 @@ def sdata(self) -> SpatialData: @property def h5sc(self) -> AnnData: - if self.extraction_f is None: - raise ValueError("No extraction method has been set.") - else: - if self.extraction_f.output_path is None: - path = self.extraction_f.extraction_file - else: - path = self.extraction_f.output_path - return read_h5sc(path) + # Always safely close previous handle if it exists + if hasattr(self, "_h5sc_handle") and self._h5sc_handle is not None: + try: + self._h5sc_handle.close() + except (ValueError, RuntimeError): + # handle was already closed or in invalid state + pass + self._h5sc_handle = None + + # Load a fresh AnnData + adata = read_h5sc(self.extraction_f.output_path or self.extraction_f.extraction_file) + + # Track only the handle, not the adata + self._h5sc_handle = adata.uns["_h5sc_file_handle"] + + return adata ##### Setup Functions ##### diff --git a/src/scportrait/plotting/h5sc.py b/src/scportrait/plotting/h5sc.py index 6ae51194..0f66c44d 100644 --- a/src/scportrait/plotting/h5sc.py +++ b/src/scportrait/plotting/h5sc.py @@ -3,8 +3,10 @@ import warnings from collections.abc import Iterable +import h5py import matplotlib.pyplot as plt import numpy as np +import pandas as pd from anndata import AnnData from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -272,7 +274,20 @@ def cell_grid_single_channel( fig = ax.get_figure() spacing = spacing * single_cell_size - images = get_image_with_cellid(adata, _cell_ids, channel_id) + + # Collect images in a list + if isinstance(adata.obsm["single_cell_images"], h5py.Dataset): + # if working on a memory-backed array + images = get_image_with_cellid(adata, _cell_ids, channel_id) + + else: + # non backed h5sc adata objects can be accessed directly + # these are created by slicing original h5sc objects + col = "scportrait_cell_id" + mapping = pd.Series(data=np.arange(len(adata.obs), dtype=int), index=adata.obs[col].values) + idx = mapping.loc[_cell_ids].to_numpy() + images = adata.obsm["single_cell_images"][idx, channel_id, :, :] + _plot_image_grid( ax, images, @@ -368,7 +383,18 @@ def cell_grid_multi_channel( channel_names = adata.uns["single_cell_images"]["channel_names"] # Collect images in a list - images = get_image_with_cellid(adata, _cell_ids) + if isinstance(adata.obsm["single_cell_images"], h5py.Dataset): + # if working on a memory-backed array + images = get_image_with_cellid(adata, _cell_ids) + + else: + # non backed h5sc adata objects can be accessed directly + # these are created by slicing original h5sc objects + col = "scportrait_cell_id" + mapping = pd.Series(data=np.arange(len(adata.obs), dtype=int), index=adata.obs[col].values) + idx = mapping.loc[_cell_ids].to_numpy() + images = adata.obsm["single_cell_images"][idx] + if select_channels is not None: if not isinstance(select_channels, Iterable): select_channels = [select_channels] diff --git a/src/scportrait/plotting/sdata.py b/src/scportrait/plotting/sdata.py index 2858d847..1100f9fc 100644 --- a/src/scportrait/plotting/sdata.py +++ b/src/scportrait/plotting/sdata.py @@ -4,8 +4,10 @@ import matplotlib as mpl import matplotlib.pyplot as plt +import numpy as np import spatialdata import xarray +from geopandas.geodataframe import GeoDataFrame from matplotlib.axes import Axes PALETTE = [ @@ -48,6 +50,11 @@ def _get_shape_element(sdata, element_name) -> tuple[int, int]: """ if isinstance(sdata[element_name], xarray.DataTree): shape = sdata[element_name].scale0.image.shape + elif isinstance(sdata[element_name], GeoDataFrame): + bounds = sdata[element_name].geometry.bounds + x = int(np.ceil(bounds["maxx"] - bounds["minx"])) + y = int(np.ceil(bounds["maxy"] - bounds["miny"])) + shape = (x, y) else: shape = sdata[element_name].data.shape diff --git a/src/scportrait/tools/h5sc/__init__.py b/src/scportrait/tools/h5sc/__init__.py index 54830847..de8e4323 100644 --- a/src/scportrait/tools/h5sc/__init__.py +++ b/src/scportrait/tools/h5sc/__init__.py @@ -5,6 +5,20 @@ Functions to work with scPortrait's standardized single-cell data format. """ -from .operations import get_image_index, get_image_with_cellid +from .operations import ( + add_spatial_coordinates, + get_cell_id_index, + get_image_with_cellid, + subset_cells_region, + subset_h5sc, + update_obs_on_disk, +) -__all__ = ["get_image_with_cellid", "get_image_index"] +__all__ = [ + "update_obs_on_disk", + "get_image_with_cellid", + "get_cell_id_index", + "add_spatial_coordinates", + "subset_h5sc", + "subset_cells_region", +] diff --git a/src/scportrait/tools/h5sc/operations.py b/src/scportrait/tools/h5sc/operations.py index e672d050..09a1e1d8 100644 --- a/src/scportrait/tools/h5sc/operations.py +++ b/src/scportrait/tools/h5sc/operations.py @@ -5,14 +5,72 @@ Functions to work with scPortrait's standardized single-cell data format. """ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dask.dataframe.core import DataFrame as da_DataFrame + from spatialdata import SpatialData + +import os +import shutil +from pathlib import Path +from warnings import warn + +import dask.array as da +import geopandas as gpd +import h5py import numpy as np +from anndata import AnnData +from shapely.geometry import Point + +from scportrait.io.h5sc import read_h5sc +from scportrait.pipeline._utils.constants import ( + DEFAULT_CELL_ID_NAME, + DEFAULT_IDENTIFIER_FILENAME, + DEFAULT_NAME_SINGLE_CELL_IMAGES, + IMAGE_DATACONTAINER_NAME, +) + + +def update_obs_on_disk(adata: AnnData) -> None: + """ + Temporarily close the HDF5 handle from a read-only AnnData, + overwrite .obs on disk, then reopen it and restore the image dataset. -from scportrait.pipeline._utils.constants import DEFAULT_CELL_ID_NAME, DEFAULT_NAME_SINGLE_CELL_IMAGES + Args: + adata: AnnData object whose .obs will replace the existing one. + """ + + # 1. Get the open HDF5 file handle + file_handle = adata.uns.get("_h5sc_file_handle", None) + + # 2. Close file to release read-only lock + if file_handle: + file_handle.close() + adata.uns["_h5sc_file_handle"] = None + # 3. Write updated obs + obs_df = adata.obs.copy() + obs_df.index = obs_df.index.astype(str) -def get_image_index(adata, cell_id: int | list[int]) -> int | list[int]: + with h5py.File(adata.uns[DEFAULT_IDENTIFIER_FILENAME], "r+") as f: + if "obs" in f: + del f["obs"] + grp = f.create_group("obs") + for col in obs_df.columns: + grp.create_dataset(col, data=obs_df[col].to_numpy()) + + # 4. Reopen file handle and restore image dataset + f = h5py.File(adata.uns[DEFAULT_IDENTIFIER_FILENAME], "r") + adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] = f.get(IMAGE_DATACONTAINER_NAME) + adata.uns["_h5sc_file_handle"] = f + + +def get_cell_id_index(adata: AnnData, cell_id: int | list[int]) -> int | list[int]: """ - Retrieve the image index (row index) of a specific cell id in a H5SC object. + Retrieve the index (row index) of a specific cell id in a H5SC object. Args: adata: An AnnData object with obsm["single_cell_images"] containing a memory-backed array of the single-cell images. @@ -23,14 +81,18 @@ def get_image_index(adata, cell_id: int | list[int]) -> int | list[int]: """ lookup = dict(zip(adata.obs[DEFAULT_CELL_ID_NAME], adata.obs.index.astype(int), strict=True)) - if isinstance(cell_id, int): + assert cell_id in lookup, f"CellID {cell_id} not present in the AnnData object." return lookup[cell_id] + missing = [x for x in cell_id if x not in lookup] + assert not missing, f"CellIDs not present in the AnnData object: {missing}" return [lookup[_id] for _id in cell_id] -def get_image_with_cellid(adata, cell_id: list[int] | int, select_channel: int | list[int] | None = None) -> np.ndarray: +def get_image_with_cellid( + adata: AnnData, cell_id: list[int] | int, select_channel: int | list[int] | None = None +) -> np.ndarray: """Get single cell images from the cells with the provided cell IDs. Images are returned in the order of the cell IDs. Args: @@ -41,18 +103,15 @@ def get_image_with_cellid(adata, cell_id: list[int] | int, select_channel: int | Returns: The image(s) of the cell with the passed Cell IDs. """ - lookup = dict(zip(adata.obs[DEFAULT_CELL_ID_NAME], adata.obs.index.astype(int), strict=True)) - image_container = adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] - - if isinstance(cell_id, int): - cell_id = [cell_id] + idxs = get_cell_id_index(adata, cell_id) + if isinstance(idxs, int): + idxs = [idxs] # Ensure idxs is always a list - for x in cell_id: - assert x in lookup.keys(), f"CellID {x} is not present in the AnnData object." + # get the image container from the AnnData object + image_container = adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] images = [] - for _id in cell_id: - idx = lookup[_id] + for idx in idxs: if select_channel is None: image = image_container[idx][:] else: @@ -64,3 +123,151 @@ def get_image_with_cellid(adata, cell_id: list[int] | int, select_channel: int | return array.squeeze(axis=0) # Remove the first dimension else: return array + + +def add_spatial_coordinates( + adata: AnnData, + centers_object: da_DataFrame, + cell_id_identifier: str = "scportrait_cell_id", + update_on_disk: bool = False, +) -> None: + """Add spatial coordinates to the AnnData object from scPortrait's standardized centers object. + Args: + adata: AnnData object to add spatial coordinates to. + centers_object: Dask DataFrame containing the spatial coordinates with columns "x" and "y" and the scportrait cell id as index. + cell_id_identifier: The column name in `adata.obs` that contains the cell IDs. + update_on_disk: boolean value indicating if the updated obs containing the spatial coordinates should be written to disk. This will overwrite the existing obs. + + Returns: + Updates the obs object of the passed h5sc object. + """ + + assert cell_id_identifier in adata.obs.columns, f"{cell_id_identifier} must be a column in h5sc.obs" + assert ["x", "y"] == list(centers_object.columns), ( + "centers_object must be scportrait's standardized centers object containing columns 'x' and 'y' and the scportrait cell id as index, but detected columns are {centers_object.columns}" + ) + + if ("x" in adata.obs.columns) or ("y" in adata.obs.columns): + adata.obs.drop(columns=["x", "y"], inplace=True, errors="ignore") + warn( + "Removed existing 'x' and 'y' columns from adata.obs. If this is not intended, please check the input data.", + stacklevel=2, + ) + + adata.obs = adata.obs.merge(centers_object.compute(), left_on=cell_id_identifier, right_index=True) + + if update_on_disk: + update_obs_on_disk(adata) + + +def subset_cells_region( + adata: AnnData, + sdata: SpatialData, + region_name: str, + outpath: str | Path = None, + within_region: bool = True, + to_disk: bool = True, + return_anndata: bool = True, +) -> AnnData | None: + """ + Subset cells in the specified region. + + Args: + adata: AnnData object containing the cell data. + sdata: SpatialData object containing the region geometry. + region_name: Name of the region to subset cells from. + outpath: Path to save the subsetted AnnData object. If None, the subsetted file is saved in the same directory as the original h5sc file with a prefix "subset_{select_region}". + within_region: If True, select cells within the region. If False, select cells outside the region. + to_disk: If True, save the subsetted AnnData object to disk. If False, return the subsetted AnnData object in memory. + return_anndata: If True, return a memory mapped version of the subsetted AnnData object. + + Returns: + If `to_disk` is False, returns the subsetted AnnData object. If `to_disk` is True, saves the subsetted AnnData object to disk and returns None. + """ + if outpath is not None: + if not isinstance(outpath, (str | Path)): + raise ValueError("outpath must be a string or Path object.") + assert to_disk, "outpath is only used if to_disk is True." + + if region_name not in sdata: + raise ValueError(f"Region '{region_name}' not found in spatialdata object.") + + xs, ys = adata.obs.get(["x", "y"]).values.T + points = gpd.GeoSeries([Point(xi, yi) for xi, yi in zip(xs, ys, strict=True)]) + is_inside = points.apply(lambda p: sdata[region_name].geometry.contains(p).any()).values + + if not within_region: + selection = ~is_inside + key = "outside" + else: + selection = is_inside + key = "within" + + if not to_disk: + return adata[selection] + else: + cell_ids = adata.obs.loc[selection, DEFAULT_CELL_ID_NAME].values + + if outpath is None: + outpath = adata.uns["h5sc_source_path"].replace("single_cells.h5sc", f"subset_{key}_{region_name}.h5sc") + subset_h5sc(adata, cell_ids, outpath=outpath) + + if return_anndata: + return read_h5sc(outpath) + else: + return None + + +def subset_h5sc(adata: AnnData, cell_id: int | list[int], outpath: str | Path) -> None: + """ + Write a subset of the AnnData object to disk based on the provided cell IDs. + + Args: + adata: AnnData object containing the single-cell data. + cell_id: A single cell ID or a list of cell IDs to subset the AnnData + outpath: Path to save the subsetted AnnData object. + + Returns: + None. The AnnData object is written to disk at the specified outpath. + """ + idx = get_cell_id_index(adata, cell_id) + + if isinstance(idx, int): + idx = [idx] # Ensure idx is always a list + + obs = adata.obs.iloc[idx, :].copy() + obs.reset_index(drop=True, inplace=True) + obs.index = obs.index.astype(str) # Ensure index is string type for consistency + var = adata.var.copy() + uns = {DEFAULT_NAME_SINGLE_CELL_IMAGES: adata.uns[DEFAULT_NAME_SINGLE_CELL_IMAGES]} + + adata_subset = AnnData(obs=obs, var=var, uns=uns) + + if os.path.exists(outpath): + shutil.rmtree(outpath, ignore_errors=True) + adata_subset.write_h5ad(outpath) + + # initialize the obsm with the single cell images + orig = adata.obsm[DEFAULT_NAME_SINGLE_CELL_IMAGES] + single_cell_data_shape = (len(idx),) + orig.shape[1:] + chunks = orig.chunks if hasattr(orig, "chunks") else None + compression = orig.compression if hasattr(orig, "compression") else None + with h5py.File(outpath, "a") as hf: + hf.create_dataset( + IMAGE_DATACONTAINER_NAME, + shape=single_cell_data_shape, + chunks=chunks, + compression=compression, + dtype=orig.dtype, + ) + if hasattr(orig, "attrs"): + for key, value in orig.attrs.items(): + hf[IMAGE_DATACONTAINER_NAME].attrs[key] = value + + # transfer the images + for i, ix in enumerate(idx): + hf[IMAGE_DATACONTAINER_NAME][i] = orig[ix] + hf.close() + + print(f"Subsetted AnnData object saved to {outpath}.") + return None diff --git a/src/scportrait/tools/sdata/processing/__init__.py b/src/scportrait/tools/sdata/processing/__init__.py index 342d45d1..e7d2212b 100644 --- a/src/scportrait/tools/sdata/processing/__init__.py +++ b/src/scportrait/tools/sdata/processing/__init__.py @@ -1,4 +1,4 @@ from ._image_processing import percentile_normalize_image -from ._subset import get_bounding_box_sdata +from ._subset import get_bounding_box_sdata, mask_region -__all__ = ["percentile_normalize_image", "get_bounding_box_sdata"] +__all__ = ["percentile_normalize_image", "get_bounding_box_sdata", "mask_region"] diff --git a/src/scportrait/tools/sdata/processing/_subset.py b/src/scportrait/tools/sdata/processing/_subset.py index f489baa7..ec8399d9 100644 --- a/src/scportrait/tools/sdata/processing/_subset.py +++ b/src/scportrait/tools/sdata/processing/_subset.py @@ -1,22 +1,27 @@ -import copy import warnings -import spatialdata +import dask.array as da +import numpy as np +import xarray as xr +from affine import Affine +from rasterio.features import rasterize +from shapely.geometry import mapping +from spatialdata import SpatialData def get_bounding_box_sdata( - sdata: spatialdata, max_width: int, center_x: int, center_y: int, drop_points: bool = True -) -> spatialdata: + sdata: SpatialData, max_width: int, center_x: int, center_y: int, drop_points: bool = True +) -> SpatialData: """apply bounding box to sdata object Args: - sdata: spatialdata object + sdata: SpatialData object max_width: maximum width of the bounding box center_x: x coordinate of the center of the bounding box center_y: y coordinate of the center of the bounding box Returns: - spatialdata: spatialdata object with bounding box applied + SpatialData object with bounding box applied """ _sdata = sdata # remove points object to improve subsetting @@ -54,9 +59,96 @@ def get_bounding_box_sdata( if drop_points: # re-add points object - __sdata = spatialdata.SpatialData.read(sdata.path, selection=["points"]) + __sdata = SpatialData.read(sdata.path, selection=["points"]) for x in points_keys: sdata[x] = __sdata[x] del __sdata return _sdata + + +def mask_region( + sdata: SpatialData, + image_name: str = "input_image", + shape_name: str = "select_region", + mask: bool = True, + crop: bool = False, +) -> xr.DataArray: + """Mask and/or crop the input image to the selected region. + + Args: + sdata: SpatialData object containing the image and shape. + image_name: Name of the image to be masked/cropped. + shape_name: Name of the shape to mask/crop the image with. + mask: Whether to apply the mask to the image. Default is True. + crop: Whether to crop the image to the outer bounding box of the shape. Default is False. + Returns: + masked/cropped image as a DataArray. If crop is False, the image has the same dimensions as the input image, otherwise it has the dimensions of the outer bounding box of the shape. + """ + assert mask or crop, "Either mask or crop must be True" + + # get image and check for proper scaling + if image_name not in sdata: + raise ValueError(f"Image {image_name} not found in sdata") + image = sdata[image_name] + + if isinstance(image, xr.DataTree): + image = image.get("scale0").image + + elif isinstance(image, xr.DataArray): + image = image + + print(image.dtype) + + # get shape and check for single-shape selection + shape = sdata[shape_name].geometry + if len(shape) == 1: + shape = shape[0] + elif len(shape) > 1: + raise ValueError("Expected a single shape, but found multiple shapes. Please select only one region.") + else: + raise ValueError("No shapes found in the specified region.") + + # initialize empty array + H, W = image.sizes["y"], image.sizes["x"] + chunks_yx = (image.data.chunks[image.get_axis_num("y")], image.data.chunks[image.get_axis_num("x")]) + template = da.zeros((H, W), chunks=chunks_yx, dtype=np.uint16) + + def _mask_block(block, block_info=None): + info = block_info[None] + (y0, y1), (x0, x1) = info["array-location"][:2] + h, w = (y1 - y0), (x1 - x0) + + # shift transform to this block’s window + window_transform = Affine.translation(x0, y0) + + m = rasterize( + [(geom, 1)], + out_shape=(h, w), + transform=window_transform, + fill=0, + dtype=image.dtype, + all_touched=True, # set True if you want any touched pixel included + ) + return m.astype(bool) + + geom = mapping(shape) + mask_dask = da.map_blocks(_mask_block, template, dtype=bool) + mask_da = xr.DataArray(mask_dask, dims=("y", "x"), coords={"y": image.coords["y"], "x": image.coords["x"]}) + + other = np.array(0, dtype=image.dtype) + if mask: + if "c" in image.dims: + m = mask_da.broadcast_like(image.isel(c=0)) + masked = image.where(m, other=other) + else: + masked = image.where(mask_da, other=other) + else: + masked = image + + if crop: + minx, miny, maxx, maxy = shape.bounds + minx, miny, maxx, maxy = int(np.floor(minx)), int(np.floor(miny)), int(np.ceil(maxx)), int(np.ceil(maxy)) + return masked.isel(x=slice(minx, maxx), y=slice(miny, maxy)) + else: + return masked diff --git a/src/scportrait/tools/sdata/write/_write.py b/src/scportrait/tools/sdata/write/_write.py index 5d5f21f7..76049461 100644 --- a/src/scportrait/tools/sdata/write/_write.py +++ b/src/scportrait/tools/sdata/write/_write.py @@ -79,7 +79,7 @@ def image( f"Number of channel names ({len(channel_names)}) does not match the number of channels in the image ({image.shape[0]})." ) channel_names_old = image.coords["c"].values.tolist() - if channel_names_old != channel_names: + if any(channel_names_old != channel_names): warnings.warn( f"Channel names in the DataArray ({channel_names_old}) do not match the provided channel names ({channel_names}). The DataArray will be updated with the provided channel names.", stacklevel=2, diff --git a/tests/conftest.py b/tests/conftest.py index 29ea041c..bcce93a7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import shutil +import geopandas as gpd import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -7,8 +8,10 @@ import pytest from anndata import AnnData from matplotlib.figure import Figure +from shapely.geometry import box from spatialdata import SpatialData from spatialdata.datasets import blobs +from spatialdata.models import Image2DModel, ShapesModel from scportrait.tools.sdata.write._helper import _normalize_anndata_strings @@ -77,3 +80,31 @@ def sdata_with_labels() -> SpatialData: sdata["table"].obs["labelling_categorical"] = sdata["table"].obs["instance_id"].astype("category") sdata["table"].obs["labelling_continous"] = (sdata["table"].obs["instance_id"] > 10).astype(float) return sdata + + +@pytest.fixture +def sdata_with_selected_region(): + image = np.ones((1, 10, 10), dtype=np.uint16) + shape = box(2, 3, 7, 8) + image_model = Image2DModel.parse(image, dims=("c", "y", "x")) + shapes_gdf = gpd.GeoDataFrame({"geometry": [shape]}) + shapes_model = ShapesModel.parse(shapes_gdf) + sdata = SpatialData(images={"input_image": image_model}, shapes={"select_region": shapes_model}) + return sdata + + +@pytest.fixture +def sdata_builder(): + def _build( + image, + shapes, + image_name="input_image", + shape_name="select_region", + dims=("c", "y", "x"), + ): + image_model = Image2DModel.parse(image, dims=dims) + shapes_gdf = gpd.GeoDataFrame({"geometry": shapes}) + shapes_model = ShapesModel.parse(shapes_gdf) + return SpatialData(images={image_name: image_model}, shapes={shape_name: shapes_model}) + + return _build diff --git a/tests/unit_tests/plotting/test_h5sc.py b/tests/unit_tests/plotting/test_h5sc.py index 55a7f9fa..f6933997 100644 --- a/tests/unit_tests/plotting/test_h5sc.py +++ b/tests/unit_tests/plotting/test_h5sc.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock, patch - import matplotlib.pyplot as plt import numpy as np import pytest @@ -17,9 +15,8 @@ cell_grid_single_channel, ) -# ---------- _reshape_image_array ---------- - +# ---------- _reshape_image_array ---------- @pytest.mark.parametrize( "input_shape, expected_shape", [ @@ -35,8 +32,6 @@ def test_reshape_image_array(input_shape, expected_shape): # ---------- _plot_image_grid ---------- - - @pytest.mark.parametrize( "input_shape, nrows, ncols, col_labels, col_labels_rotation", [ @@ -94,41 +89,22 @@ def test_plot_contour_grid_draws_on_existing_grid(): # ---------- cell_grid_single_channel ---------- - - -@patch("scportrait.plotting.h5sc.get_image_with_cellid") -def test_cell_grid_single_channel_returns_figure_with_title_rotation(mock_get_img): - mock_adata = MagicMock() - mock_adata.uns = {"single_cell_images": {"channel_names": np.array(["ch0", "ch1"])}} - mock_adata.obs = MagicMock() - mock_adata.obs.__getitem__.return_value.sample.return_value.values = [101, 102] - mock_get_img.return_value = rng.random((2, 10, 10)) - +def test_cell_grid_single_channel_returns_figure_with_title_rotation(h5sc_object): fig = cell_grid_single_channel( - adata=mock_adata, - select_channel=0, + adata=h5sc_object, + select_channel="ch0", n_cells=2, return_fig=True, show_fig=False, title_rotation=30, # new parameter ) assert isinstance(fig, Figure) - # You could further check that fig.axes[0].get_title() is set, but rotation may not be directly visible # ---------- cell_grid_multi_channel ---------- - - -@patch("scportrait.plotting.h5sc.get_image_with_cellid") -def test_cell_grid_multi_channel_returns_figure_with_channel_label_rotation(mock_get_img): - mock_adata = MagicMock() - mock_adata.uns = {"single_cell_images": {"channel_names": ["ch0", "ch1"], "n_channels": 2}} - mock_adata.obs = MagicMock() - mock_adata.obs.__getitem__.return_value.sample.return_value.values = [101, 102] - mock_get_img.return_value = rng.random((2, 2, 10, 10)) - +def test_cell_grid_multi_channel_returns_figure_with_channel_label_rotation(h5sc_object): fig = cell_grid_multi_channel( - adata=mock_adata, + adata=h5sc_object, n_cells=2, return_fig=True, show_fig=False, @@ -139,17 +115,9 @@ def test_cell_grid_multi_channel_returns_figure_with_channel_label_rotation(mock # ---------- cell_grid ---------- +def test_cell_grid_dispatches_to_single_channel(h5sc_object): + cell_grid(adata=h5sc_object, select_channel="ch1", n_cells=1, show_fig=False) -@patch("scportrait.plotting.h5sc.cell_grid_single_channel") -def test_cell_grid_dispatches_to_single_channel(mock_single): - mock_adata = MagicMock() - cell_grid(adata=mock_adata, select_channel=1, n_cells=1, show_fig=False) - assert mock_single.called - - -@patch("scportrait.plotting.h5sc.cell_grid_multi_channel") -def test_cell_grid_dispatches_to_multi_channel(mock_multi): - mock_adata = MagicMock() - cell_grid(adata=mock_adata, select_channel=[0, 1], n_cells=1, show_fig=False) - assert mock_multi.called +def test_cell_grid_dispatches_to_multi_channel(h5sc_object): + cell_grid(adata=h5sc_object, select_channel=["ch0", "ch1"], n_cells=1, show_fig=False) diff --git a/tests/unit_tests/tools/h5sc/test_operations.py b/tests/unit_tests/tools/h5sc/test_operations.py new file mode 100644 index 00000000..af628f78 --- /dev/null +++ b/tests/unit_tests/tools/h5sc/test_operations.py @@ -0,0 +1,56 @@ +# tests/test_operations.py + +from pathlib import Path + +import anndata as ad +import h5py +import numpy as np +import pandas as pd +import pytest +from scportrait.tl.h5sc import ( + add_spatial_coordinates, + get_cell_id_index, + subset_cells_region, + subset_h5sc, + update_obs_on_disk, +) + +from scportrait.io import read_h5sc + +rng = np.random.default_rng() + + +def test_update_obs_on_disk(h5sc_object, tmp_path): + # Write h5ad + p = tmp_path / "test.h5ad" + h5sc_object.write(p) + + h5sc_object.uns["h5sc_source_path"] = str(p) + size = h5sc_object.obs.shape[0] + + # Modify obs + random_values = rng.integers(1, 10, size=size) + h5sc_object.obs["new_col"] = random_values + update_obs_on_disk(h5sc_object) + + # Reload and confirm updated + reloaded = read_h5sc(p) + assert "new_col" in reloaded.obs.columns + assert np.all(reloaded.obs["new_col"] == random_values) + + +def test_get_cell_id_index_single(h5sc_object): + idx = get_cell_id_index(h5sc_object, 107) + assert idx == 2 + + +def test_get_cell_id_index_list(h5sc_object): + idx = get_cell_id_index(h5sc_object, [101, 109]) + assert idx == [0, 3] + + +def test_subset_h5sc(h5sc_object, tmp_path): + out = tmp_path / "subset.h5sc" + subset_h5sc(h5sc_object, [101, 102], out) + + assert out.exists() diff --git a/tests/unit_tests/tools/sdata/test_subset.py b/tests/unit_tests/tools/sdata/test_subset.py new file mode 100644 index 00000000..0343b208 --- /dev/null +++ b/tests/unit_tests/tools/sdata/test_subset.py @@ -0,0 +1,140 @@ +import numpy as np +import pytest +from shapely.geometry import box + +from scportrait.tools.sdata.processing._subset import get_bounding_box_sdata, mask_region +from scportrait.tools.sdata.write._helper import _get_image + + +def _as_numpy(image): + data = image.data + if hasattr(data, "compute"): + data = data.compute() + return np.asarray(data) + + +def _expected_mask(masked, shape): + x_coords = masked.coords["x"].values + y_coords = masked.coords["y"].values + x_mask = (x_coords >= shape.bounds[0]) & (x_coords <= shape.bounds[2]) + y_mask = (y_coords >= shape.bounds[1]) & (y_coords <= shape.bounds[3]) + return np.outer(y_mask, x_mask).astype(bool) + + +@pytest.mark.parametrize( + "mask,crop", + [ + (True, False), + (False, True), + (True, True), + ], +) +def test_mask_region_mask_crop_combinations(sdata_with_selected_region, mask, crop): + image = _as_numpy(sdata_with_selected_region.images["input_image"]) + shape = sdata_with_selected_region["select_region"].geometry[0] + + result = mask_region( + sdata_with_selected_region, image_name="input_image", shape_name="select_region", mask=mask, crop=crop + ) + result_np = _as_numpy(result) + + minx, miny, maxx, maxy = shape.bounds + x0, y0 = int(np.floor(minx)), int(np.floor(miny)) + x1, y1 = int(np.ceil(maxx)), int(np.ceil(maxy)) + + if crop: + expected = image[:, y0:y1, x0:x1] + assert result_np.shape == expected.shape + if mask: + assert result_np.min() >= 0 + assert result_np.max() == expected.max() + else: + np.testing.assert_array_equal(result_np, expected) + else: + expected_mask = _expected_mask(result, shape) + values = result_np[0] + assert values[expected_mask].min() == 1 + assert values[~expected_mask].max() == 0 + + +def test_mask_region_requires_mask_or_crop(sdata_with_selected_region): + with pytest.raises(AssertionError): + mask_region( + sdata_with_selected_region, image_name="input_image", shape_name="select_region", mask=False, crop=False + ) + + +def test_mask_region_missing_image_raises(sdata_with_selected_region): + with pytest.raises(ValueError): + mask_region(sdata_with_selected_region, image_name="missing", shape_name="select_region", mask=True, crop=False) + + +def test_mask_region_missing_shape_raises(sdata_with_selected_region): + with pytest.raises(KeyError): + mask_region(sdata_with_selected_region, image_name="input_image", shape_name="missing", mask=True, crop=False) + + +def test_mask_region_multiple_shapes_raises(sdata_builder): + image = np.ones((1, 6, 6), dtype=np.uint16) + shapes = [box(1, 1, 3, 3), box(2, 2, 4, 4)] + sdata = sdata_builder(image, shapes) + + with pytest.raises(ValueError): + mask_region(sdata, image_name="input_image", shape_name="select_region", mask=True, crop=False) + + +@pytest.mark.parametrize( + "max_width,center_x,center_y", + [ + (4, 5, 5), + (6, 2, 8), + ], +) +def test_get_bounding_box_sdata_reduces_extent(sdata_builder, max_width, center_x, center_y): + image = np.ones((1, 10, 10), dtype=np.uint16) + shapes = [box(1, 1, 3, 3)] + sdata = sdata_builder(image, shapes) + + subset = get_bounding_box_sdata( + sdata, + max_width=max_width, + center_x=center_x, + center_y=center_y, + drop_points=False, + ) + + image_subset = _get_image(subset.images["input_image"]) + assert image_subset.sizes["x"] <= 10 + assert image_subset.sizes["y"] <= 10 + assert image_subset.sizes["x"] > 0 + assert image_subset.sizes["y"] > 0 + assert image_subset.sizes["x"] <= max_width + assert image_subset.sizes["y"] <= max_width + + +@pytest.mark.parametrize( + "center_x,center_y", + [ + (0, 0), + (-5, -5), + ], +) +def test_get_bounding_box_sdata_clamps_edges(sdata_builder, center_x, center_y): + image = np.ones((1, 10, 10), dtype=np.uint16) + shapes = [box(1, 1, 3, 3)] + sdata = sdata_builder(image, shapes) + + subset = get_bounding_box_sdata( + sdata, + max_width=4, + center_x=center_x, + center_y=center_y, + drop_points=False, + ) + + image_subset = _get_image(subset.images["input_image"]) + x_coords = image_subset.coords["x"].values + y_coords = image_subset.coords["y"].values + + assert x_coords.min() >= sdata.images["input_image"].coords["x"].values.min() + assert y_coords.min() >= sdata.images["input_image"].coords["y"].values.min()