diff --git a/.github/workflows/setup-data.yml b/.github/workflows/setup-data.yml index 050ad3d..2265a39 100644 --- a/.github/workflows/setup-data.yml +++ b/.github/workflows/setup-data.yml @@ -22,7 +22,7 @@ jobs: path: data/ # The folder you want to cache # The key determines if we have a match. # Change 'v1' to 'v2' manually to force a re-download in the future. - key: test-data-v7 + key: test-data-v9 # 2. DOWNLOAD ONLY IF CACHE MISS diff --git a/src/mritk/cli.py b/src/mritk/cli.py index 10fe43a..f657c50 100644 --- a/src/mritk/cli.py +++ b/src/mritk/cli.py @@ -9,7 +9,7 @@ from rich.logging import RichHandler from rich_argparse import RichHelpFormatter -from . import concentration, datasets, hybrid, info, looklocker, masks, mixed, napari, r1, show, statistics +from . import concentration, datasets, hybrid, info, looklocker, masks, mixed, napari, r1, segmentation, show, statistics def version_info(): @@ -75,6 +75,9 @@ def setup_parser(): napari_parser = subparsers.add_parser("napari", help="Show MRI data using napari", formatter_class=parser.formatter_class) napari.add_arguments(napari_parser) + segmentation_parser = subparsers.add_parser("seg", help="Perform segmentation tasks", formatter_class=parser.formatter_class) + segmentation.add_arguments(segmentation_parser, extra_args_cb=add_extra_arguments) + looklocker_parser = subparsers.add_parser( "looklocker", help="Process Look-Locker data", formatter_class=parser.formatter_class ) @@ -142,6 +145,8 @@ def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = No show.dispatch(args) elif command == "napari": napari.dispatch(args) + elif command == "seg": + segmentation.dispatch(args) elif command == "looklocker": looklocker.dispatch(args) elif command == "mask": diff --git a/src/mritk/masks.py b/src/mritk/masks.py index a50b4c6..18799e5 100644 --- a/src/mritk/masks.py +++ b/src/mritk/masks.py @@ -12,6 +12,7 @@ import skimage from .data import MRIData +from .segmentation import CSFSegmentation, Segmentation from .testing import assert_same_space @@ -81,12 +82,12 @@ def compute_csf_mask_array( return binary -def csf_mask(input: Path, connectivity: int | None = 2, use_li: bool = False) -> MRIData: +def csf_mask(input: MRIData, connectivity: int | None = 2, use_li: bool = False) -> MRIData: """ I/O wrapper for generating and saving a CSF mask from a NIfTI file. Args: - input (Path): Path to the input NIfTI image. + input (MRIData): An MRIData object containing the input volume (typically T2-weighted or Spin-Echo). connectivity (Optional[int], optional): Connectivity distance. Defaults to 2. use_li (bool, optional): If True, uses Li thresholding. Defaults to False. output (Optional[Path], optional): Path to save the resulting mask. Defaults to None. @@ -97,12 +98,10 @@ def csf_mask(input: Path, connectivity: int | None = 2, use_li: bool = False) -> Raises: AssertionError: If the resulting mask contains no voxels. """ - input_vol = MRIData.from_file(input, dtype=np.single) - mask = compute_csf_mask_array(input_vol.data, connectivity, use_li) - + mask = compute_csf_mask_array(input.data, connectivity, use_li) assert np.max(mask) > 0, "Masking failed, no voxels in mask" - mri_data = MRIData(data=mask, affine=input_vol.affine) + mri_data = MRIData(data=mask, affine=input.affine) return mri_data @@ -134,7 +133,7 @@ def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_arr return ~opened_background -def intracranial_mask(csf_segmentation_path: Path, segmentation_path: Path) -> MRIData: +def intracranial_mask(segmentation: Segmentation, csf_mask: MRIData) -> MRIData: """ I/O wrapper for generating and saving an intracranial mask from NIfTI files. @@ -142,21 +141,20 @@ def intracranial_mask(csf_segmentation_path: Path, segmentation_path: Path) -> M delegates the array computation. Args: - csf_segmentation_path (Path): Path to the CSF segmentation NIfTI file. - segmentation_path (Path): Path to the brain segmentation NIfTI file. - output (Optional[Path], optional): Path to save the resulting mask. Defaults to None. + segmentation (MRIData): The refined segmentation (MRIData), \ + generated by the segmentation refinement module. + csf_mask (MRIData): The CSF mask (MRIData), generated by the csf mask module. Returns: MRIData: An MRIData object containing the intracranial mask. """ - input_csf_mask = MRIData.from_file(csf_segmentation_path, dtype=bool) - segmentation_data = MRIData.from_file(segmentation_path, dtype=bool) + csf_seg = CSFSegmentation(segmentation, csf_mask).to_csf_segmentation() # Validate spatial alignment before array operations - assert_same_space(input_csf_mask, segmentation_data) + assert_same_space(csf_seg, segmentation.mri) - mask_data = compute_intracranial_mask_array(input_csf_mask.data, segmentation_data.data) - mri_data = MRIData(data=mask_data, affine=segmentation_data.affine) + mask_data = compute_intracranial_mask_array(csf_seg.data, segmentation.mri.data) + mri_data = MRIData(data=mask_data, affine=segmentation.mri.affine) return mri_data @@ -183,8 +181,18 @@ def add_arguments( intracranial_mask_parser = subparser.add_parser( "intracranial", help="Compute intracranial mask", formatter_class=parser.formatter_class ) - intracranial_mask_parser.add_argument("--csf-segmentation-path", type=Path, help="Path to the CSF segmentation NIfTI file") - intracranial_mask_parser.add_argument("--segmentation-path", type=Path, help="Path to the brain segmentation NIfTI file") + intracranial_mask_parser.add_argument( + "--segmentation-path", + type=Path, + help="Path to refined segmentation file, generated by \ + the segmentation refinement module, i.e. mritk seg refine", + ) + intracranial_mask_parser.add_argument( + "--csf-mask-path", + type=Path, + help="Path to the CSF mask NIfTI file, generated by \ + the csf mask module, i.e. mritk mask csf", + ) intracranial_mask_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resulting mask") if extra_args_cb is not None: @@ -195,11 +203,14 @@ def add_arguments( def dispatch(args): command = args.pop("mask-command") if command == "csf": - csf_mask_data = csf_mask(input=args.pop("input"), connectivity=args.pop("connectivity"), use_li=args.pop("use_li")) + csf_mask_data = csf_mask( + input=MRIData.from_file(args.pop("input")), connectivity=args.pop("connectivity"), use_li=args.pop("use_li") + ) csf_mask_data.save(args.pop("output"), dtype=np.uint8) elif command == "intracranial": intracranial_mask_data = intracranial_mask( - csf_segmentation_path=args.pop("csf_segmentation_path"), segmentation_path=args.pop("segmentation_path") + segmentation=MRIData.from_file(args.pop("segmentation_path"), dtype=np.single), + csf_mask=MRIData.from_file(args.pop("csf_mask_path"), dtype=np.single), ) intracranial_mask_data.save(args.pop("output"), dtype=np.uint8) else: diff --git a/src/mritk/segmentation.py b/src/mritk/segmentation.py index 0f97fe5..db00e95 100644 --- a/src/mritk/segmentation.py +++ b/src/mritk/segmentation.py @@ -4,17 +4,23 @@ # Copyright (C) 2026 Cécile Daversin-Catty (cecile@simula.no) # Copyright (C) 2026 Simula Research Laboratory +import argparse +import itertools import logging import os import re +from collections.abc import Callable +from dataclasses import dataclass from pathlib import Path from urllib.request import urlretrieve import numpy as np import numpy.typing as npt import pandas as pd +import scipy -from .data import MRIData, load_mri_data +from .data import MRIData, apply_affine +from .testing import assert_same_space logger = logging.getLogger(__name__) @@ -79,7 +85,8 @@ } -class Segmentation(MRIData): +@dataclass +class Segmentation: """ Base class for MRI segmentations, linking spatial data with anatomical lookup tables. @@ -88,7 +95,7 @@ class Segmentation(MRIData): labels to a descriptive Lookup Table (LUT). """ - def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | None = None): + def __init__(self, mri: MRIData, lut: pd.DataFrame | None = None): """ Initializes the Segmentation object. @@ -98,11 +105,10 @@ def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | Non lut (Optional[pd.DataFrame], optional): A pandas DataFrame mapping numerical labels to their descriptions. If None, a default numerical mapping is generated. Defaults to None. """ - super().__init__(data, affine) - self.data = self.data.astype(int) + self.mri = mri # Extract all unique active regions (ignoring 0/background) - self.rois = np.unique(self.data[self.data > 0]) + self.rois = np.unique(self.mri.data[self.mri.data > 0]) if lut is not None: self.lut = lut @@ -110,7 +116,66 @@ def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | Non self.lut = pd.DataFrame({"Label": self.rois}, index=self.rois) # Identify the primary label column dynamically - self._label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] + self.label_name = "Label" if "Label" in self.lut.columns else self.lut.columns[0] + + @classmethod + def from_file( + cls, seg_path: Path, dtype: npt.DTypeLike | None = None, orient: bool = True, lut_path: Path | None = None + ) -> "Segmentation": + """Loads a Segmentation from a NIfTI file. + + Args: + seg_path (Path): The file path to the segmentation NIfTI file. + dtype (npt.DTypeLike, optional): The data type for the segmentation data. Defaults to None. + orient (bool, optional): Whether to orient the data. Defaults to True. + lut_path (Path, optional): The file path to the lookup table. Defaults to None. + Returns: + Segmentation: An instance of the Segmentation class containing the loaded + segmentation data and affine transformation. + """ + logger.info(f"Loading segmentation from {seg_path}.") + mri = MRIData.from_file(seg_path, dtype=dtype, orient=orient) + + if lut_path is None and seg_path.with_suffix(".json").exists(): + lut_path = seg_path.with_suffix(".json") + + if lut_path is not None: + logger.info(f"Loading LUT from {lut_path}.") + lut = pd.read_json(lut_path) + else: + rois = np.unique(mri.data[mri.data > 0]) + lut = pd.DataFrame({"Label": rois}, index=rois) + + return cls(mri=mri, lut=lut) + + def save(self, output_path: Path, dtype: npt.DTypeLike | None = None, intent_code: int = 1006, lut_path: Path | None = None): + """Saves the Segmentation to a NIfTI file. + + Args: + output_path (Path): The file path where the segmentation will be saved. + dtype (npt.DTypeLike, optional): The data type for the saved segmentation data. Defaults to None. + intent_code (int, optional): The NIfTI intent code to set in the header. Defaults to 1006 (NIFTI_INTENT_LABEL). + """ + self.mri.save(output_path, dtype=dtype, intent_code=intent_code) + if lut_path is not None: + self.lut.to_json(lut_path, orient="index") + else: + self.lut.to_json(output_path.with_suffix(".json"), orient="index") + + def set_lut(self, lut: pd.DataFrame, label_column: str = "Label"): + """Sets the Lookup Table (LUT) for the segmentation, ensuring it matches the present ROIs. + + Args: + lut (pd.DataFrame): A pandas DataFrame mapping numerical labels + to their descriptions. If None, a default numerical mapping is generated. Defaults to None. + label_column (str, optional): The name of the column in the LUT that contains the label + descriptions. Defaults to "Label". + """ + + self.lut = lut + self.label_name = label_column + if self.label_name not in self.lut.columns: + raise ValueError(f"Specified label column '{self.label_name}' not found in LUT.") @property def num_rois(self) -> int: @@ -145,7 +210,76 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr if not np.isin(rois, self.rois).all(): raise ValueError("Some of the provided ROIs are not present in the segmentation.") - return self.lut.loc[self.lut.index.isin(rois), [self._label_name]].rename_axis("ROI").reset_index() + return self.lut.loc[self.lut.index.isin(rois), [self.label_name]].rename_axis("ROI").reset_index() + + def resample_to_reference(self, reference_mri: MRIData) -> "Segmentation": + """ + Resamples the segmentation to match the spatial dimensions and resolution of a reference MRI. + + Args: + reference_mri (MRIData): The MRI to which the segmentation will be resampled, + for example a T1-weighted anatomical scan. + Returns: + Segmentation: A new Segmentation object containing the resampled data. + """ + + shape_in = self.mri.shape + shape_out = reference_mri.shape + + # Generate a grid of voxel indices for the output space + upsampled_indices = np.fromiter( + itertools.product(*(np.arange(ni) for ni in shape_out)), + dtype=np.dtype((int, 3)), + ) + # Get voxel indices in the input segmentation space corresponding to the output grid + seg_indices = apply_affine( + np.linalg.inv(self.mri.affine), + apply_affine(reference_mri.affine, upsampled_indices), + ) + seg_indices = np.rint(seg_indices).astype(int) + + # The two images does not necessarily share field of view. + # Remove voxels which are not located within the segmentation fov. + valid_index_mask = (seg_indices > 0).all(axis=1) * (seg_indices < shape_in).all(axis=1) + upsampled_indices = upsampled_indices[valid_index_mask] + seg_indices = seg_indices[valid_index_mask] + + seg_upsampled = np.zeros(shape_out, dtype=self.mri.data.dtype) + I_in, J_in, K_in = seg_indices.T + I_out, J_out, K_out = upsampled_indices.T + seg_upsampled[I_out, J_out, K_out] = self.mri.data[I_in, J_in, K_in] + + # return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut) + mri = MRIData(data=seg_upsampled, affine=reference_mri.affine) + return Segmentation(mri=mri, lut=self.lut) + + def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> "Segmentation": + """ + Applies Gaussian smoothing to the segmentation labels to create a soft probabilistic map. + + Args: + sigma (float): The standard deviation for the Gaussian kernel. + cutoff_score (float, optional): A threshold to remove low-confidence voxels. Defaults to 0.5. + **kwargs: Additional keyword arguments passed to scipy.ndimage.gaussian_filter. + + Returns: + dict[str, np.ndarray]: A dictionary containing 'labels' (the smoothed segmentation) + and 'scores' (the confidence scores for each voxel). + """ + smoothed_rois = np.zeros_like(self.mri.data) + high_scores = np.zeros(self.mri.data.shape) + + for roi in self.rois: + scores = scipy.ndimage.gaussian_filter((self.mri.data == roi).astype(float), sigma=sigma, **kwargs) + is_new_high_score = scores > high_scores + smoothed_rois[is_new_high_score] = roi + high_scores[is_new_high_score] = scores[is_new_high_score] + + delete_scores = (high_scores < cutoff_score) * (self.mri.data == 0) + smoothed_rois[delete_scores] = 0 + + mri = MRIData(data=smoothed_rois, affine=self.mri.affine) + return Segmentation(mri=mri, lut=self.lut) class FreeSurferSegmentation(Segmentation): @@ -179,8 +313,8 @@ def from_file( # FreeSurfer LUTs index by the "label" column lut = lut.set_index("label") if "label" in lut.columns else lut - data, affine = load_mri_data(filepath, dtype=dtype, orient=orient) - return cls(data=data, affine=affine, lut=lut) + mri = MRIData.from_file(filepath, dtype=dtype, orient=orient) + return cls(mri=mri, lut=lut) class ExtendedFreeSurferSegmentation(FreeSurferSegmentation): @@ -218,7 +352,7 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr left_on="FreeSurfer_ROI", right_on="FreeSurfer_ROI", how="outer", - ).drop(columns=["FreeSurfer_ROI"])[["ROI", self._label_name, "tissue_type"]] + ).drop(columns=["FreeSurfer_ROI"])[["ROI", self.label_name, "tissue_type"]] def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFrame: """ @@ -249,6 +383,35 @@ def get_tissue_type(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataF return ret +@dataclass +class CSFSegmentation: + segmentation: Segmentation + csf_mask: MRIData + + def __init__(self, segmentation: Segmentation, csf_mask: MRIData): + assert_same_space(segmentation.mri, csf_mask) + self.segmentation = segmentation + self.csf_mask = csf_mask + + @classmethod + def from_file(cls, segmentation_path: Path, csf_mask_path: Path) -> "CSFSegmentation": + segmentation = Segmentation.from_file(segmentation_path, dtype=np.int16) + csf_mask = MRIData.from_file(csf_mask_path, dtype=bool) + assert_same_space(segmentation.mri, csf_mask) + return cls(segmentation=segmentation, csf_mask=csf_mask) + + def to_csf_segmentation(self) -> MRIData: + # Get interpolation operator + I, J, K = np.where(self.segmentation.mri.data != 0) + interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, self.segmentation.mri.data[I, J, K]) + # Interpolate segmentation values at CSF mask locations + i, j, k = np.where(self.csf_mask.data != 0) + csf_seg = np.zeros_like(self.segmentation.mri.data, dtype=np.int16) + csf_seg[i, j, k] = interp(i, j, k) + + return MRIData(data=csf_seg.astype(np.int16), affine=self.csf_mask.affine) + + def default_segmentation_groups() -> dict[str, list[int]]: """ Returns the default grouping of FreeSurfer labels into brain regions. @@ -407,3 +570,80 @@ def write_lut(filename: Path, table: pd.DataFrame): # Save as tab-separated values without headers or indices newtable.to_csv(filename, sep="\t", index=False, header=False) + + +def add_arguments( + parser: argparse.ArgumentParser, + extra_args_cb: Callable[[argparse.ArgumentParser], None] | None = None, +) -> None: + subparser = parser.add_subparsers(dest="seg-command", help="Commands for segmentation processing") + + resample_parser = subparser.add_parser( + "resample", help="Resample a segmentation to match the space of a reference MRI", formatter_class=parser.formatter_class + ) + resample_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") + resample_parser.add_argument( + "-r", + "--reference", + type=Path, + help="Path to the reference MRI \ + - usually a registered T1 weighted anatomical scan", + ) + resample_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resampled segmentation") + + smooth_parser = subparser.add_parser( + "smooth", + help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map", + formatter_class=parser.formatter_class, + ) + smooth_parser.add_argument("-i", "--input", type=Path, help="Path to the input (refined) segmentation NIfTI file") + smooth_parser.add_argument("-s", "--sigma", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") + smooth_parser.add_argument( + "-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)" + ) + smooth_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the smoothed segmentation") + + refine_parser = subparser.add_parser( + "refine", + help="Refine a segmentation by applying Gaussian smoothing to the labels", + formatter_class=parser.formatter_class, + ) + refine_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file") + refine_parser.add_argument( + "-r", + "--reference", + type=Path, + help="Path to the reference MRI \ + - usually a registered T1 weighted anatomical scan", + ) + refine_parser.add_argument("-s", "--smooth", type=float, help="Standard deviation for the Gaussian kernel used in smoothing") + refine_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the refined segmentation") + + if extra_args_cb is not None: + extra_args_cb(resample_parser) + extra_args_cb(smooth_parser) + extra_args_cb(refine_parser) + + +def dispatch(args): + command = args.pop("seg-command") + if command == "resample": + print("Resampling segmentation...") + input_seg = Segmentation.from_file(args.pop("input")) + reference_mri = MRIData.from_file(args.pop("reference")) + resampled_seg = input_seg.resample_to_reference(reference_mri) + resampled_seg.save(args.pop("output"), dtype=np.int32) + + elif command == "smooth": + smoothed = Segmentation.from_file(args.pop("input")).smooth(sigma=args.pop("sigma"), cutoff_score=args.pop("cutoff")) + smoothed.save(args.pop("output"), dtype=np.int32) + + elif command == "refine": + seg = Segmentation.from_file(args.pop("input")) + refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference"))) + smoothed = refined.smooth(sigma=args.pop("smooth")) + refined.mri.data = np.where(smoothed.mri.data > 0, smoothed.mri.data, refined.mri.data) + refined.save(args.pop("output"), dtype=np.int32) + + else: + raise ValueError(f"Unknown segmentation command: {command}") diff --git a/src/mritk/statistics/compute_stats.py b/src/mritk/statistics/compute_stats.py index 9fb4ca5..b5ccc10 100644 --- a/src/mritk/statistics/compute_stats.py +++ b/src/mritk/statistics/compute_stats.py @@ -219,7 +219,7 @@ def generate_stats_dataframe_rois( metadata: Optional[dict] = None, ) -> pd.DataFrame: # Verify that segmentation and MRI are in the same space - assert_same_space(seg, mri) + assert_same_space(seg.mri, mri) qoi_records = [] # Collects records related to qois roi_records = [] # Collects records related to ROIs, @@ -228,7 +228,7 @@ def generate_stats_dataframe_rois( finite_mask = np.isfinite(mri.data) for roi in tqdm.rich.tqdm(seg.roi_labels, total=len(seg.roi_labels)): # Identify rois in segmentation - region_mask = (seg.data == roi) * finite_mask + region_mask = (seg.mri.data == roi) * finite_mask # print(region_mask.shape) region_data = mri.data[region_mask] nb_nans = np.isnan(region_data).sum() @@ -239,7 +239,7 @@ def generate_stats_dataframe_rois( { "ROI": roi, "voxel_count": voxelcount, - "volume_ml": seg.voxel_ml_volume * voxelcount, + "volume_ml": seg.mri.voxel_ml_volume * voxelcount, "num_nan_values": nb_nans, } ) diff --git a/tests/conftest.py b/tests/conftest.py index be56825..599b4c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,7 +21,7 @@ def example_segmentation() -> Segmentation: base = np.array([0, 1, 2, 3], dtype=float) seg = np.tile(base, (100, 1)) - return Segmentation(seg, affine=np.eye(4)) + return Segmentation(MRIData(data=seg, affine=np.eye(4))) @pytest.fixture diff --git a/tests/create_test_data.py b/tests/create_test_data.py index 811a335..68571b0 100644 --- a/tests/create_test_data.py +++ b/tests/create_test_data.py @@ -21,11 +21,18 @@ def main(): "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz", "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-mixed_T1map.nii.gz", "mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-looklocker_T1map.nii.gz", - "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_acq-looklocker_T1map_registered.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aparc+aseg.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-wmparc.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aseg_refined.nii.gz", + "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz", "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz", + "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/aseg.mgz", + "freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz", ] for file in files: diff --git a/tests/test_masks.py b/tests/test_masks.py index 2494e43..d2bfa01 100644 --- a/tests/test_masks.py +++ b/tests/test_masks.py @@ -12,7 +12,15 @@ import numpy as np import mritk.cli -from mritk.masks import compute_csf_mask_array, compute_intracranial_mask_array, csf_mask, intracranial_mask, largest_island +from mritk.data import MRIData +from mritk.masks import ( + compute_csf_mask_array, + compute_intracranial_mask_array, + csf_mask, + intracranial_mask, + largest_island, +) +from mritk.segmentation import Segmentation from mritk.testing import compare_nifti_images @@ -118,7 +126,8 @@ def test_csf_mask_io(tmp_path): nii = nib.Nifti1Image(data, np.eye(4)) nib.save(nii, in_path) - result = csf_mask(input=in_path, use_li=True) + input_data = mritk.data.MRIData.from_file(in_path, dtype=np.single) + result = csf_mask(input=input_data, use_li=True) result.save(out_path, dtype=np.uint8) # Verify the file was physically saved to the filesystem @@ -146,7 +155,7 @@ def test_intracranial_mask_io(tmp_path): seg_data[4:6, 4:6, 4:6] = 1.0 nib.save(nib.Nifti1Image(seg_data, affine), seg_path) - result = intracranial_mask(csf_segmentation_path=csf_path, segmentation_path=seg_path) + result = intracranial_mask(segmentation=Segmentation(mri=MRIData(seg_data, affine)), csf_mask=MRIData(csf_data, affine)) result.save(out_path, dtype=np.uint8) # Verify the file was physically saved to the filesystem @@ -156,32 +165,35 @@ def test_intracranial_mask_io(tmp_path): @patch("mritk.masks.csf_mask") -def test_dispatch_csf_mask(mock_csf_mask): +@patch("mritk.data.MRIData.from_file") +def test_dispatch_csf_mask(mock_from_file, mock_csf_mask): """Test the CLI dispatch for the CSF mask command.""" - mritk.cli.main(["mask", "csf", "-i", "input.nii.gz", "--output", "mock_out.nii.gz", "--use-li", "--connectivity", "2"]) + mritk.cli.main(["mask", "csf", "-i", "input.nii.gz", "-o", "mock_out.nii.gz", "--use-li", "--connectivity", "2"]) - mock_csf_mask.assert_called_once_with(input=Path("input.nii.gz"), connectivity=2, use_li=True) + input_data = mock_from_file(Path("input.nii.gz"), dtype=np.single) + mock_csf_mask.assert_called_once_with(input=input_data, connectivity=2, use_li=True) @patch("mritk.masks.intracranial_mask") -def test_dispatch_intracranial_mask(mock_intracranial_mask): +@patch("mritk.data.MRIData.from_file") +def test_dispatch_intracranial_mask(mock_from_file, mock_intracranial_mask): """Test the CLI dispatch for the intracranial mask command.""" mritk.cli.main( [ "mask", "intracranial", - "--csf-segmentation-path", - "csf_segmentation.nii.gz", "--segmentation-path", "segmentation.nii.gz", + "--csf-mask-path", + "csf_mask.nii.gz", "-o", "ic_mask.nii.gz", ] ) - mock_intracranial_mask.assert_called_once_with( - csf_segmentation_path=Path("csf_segmentation.nii.gz"), segmentation_path=Path("segmentation.nii.gz") - ) + seg_data = mock_from_file(Path("segmentation.nii.gz"), dtype=np.single) + csf_data = mock_from_file(Path("csf_mask.nii.gz"), dtype=np.single) + mock_intracranial_mask.assert_called_once_with(segmentation=seg_data, csf_mask=csf_data) def test_csf_mask(tmp_path, mri_data_dir: Path): @@ -191,18 +203,22 @@ def test_csf_mask(tmp_path, mri_data_dir: Path): ref_output = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" test_output = tmp_path / "output_seg-csf_binary.nii.gz" - result = csf_mask(input=input_T2w_path, use_li=use_li) + input_T2w = mritk.data.MRIData.from_file(input_T2w_path, dtype=np.single) + result = csf_mask(input=input_T2w, use_li=use_li) result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) def test_intracranial_mask(tmp_path, mri_data_dir: Path): - csf_segmentation_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz" + csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" segmentation_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz" ref_output = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz" test_output = tmp_path / "output_seg-intracranial_binary.nii.gz" - result = intracranial_mask(csf_segmentation_path=csf_segmentation_path, segmentation_path=segmentation_path) + input_segmentation = mritk.segmentation.Segmentation.from_file(segmentation_path, dtype=np.single) + input_csf_mask = mritk.data.MRIData.from_file(csf_mask_path, dtype=np.single) + + result = intracranial_mask(segmentation=input_segmentation, csf_mask=input_csf_mask) result.save(test_output, dtype=np.uint8) compare_nifti_images(test_output, ref_output, data_tolerance=1e-12) diff --git a/tests/test_mri_io.py b/tests/test_mri_io.py index 57aa442..c98f9a3 100644 --- a/tests/test_mri_io.py +++ b/tests/test_mri_io.py @@ -55,8 +55,9 @@ def test_load_mri_data_invalid_suffix(mri_data_dir): @pytest.mark.parametrize("orient", (True, False)) def test_load_Segmentation(tmp_path, mri_data_dir, orient: bool): input_file = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz" - seg = Segmentation.from_file(input_file) - assert seg.data.dtype == int + seg = Segmentation.from_file(input_file, dtype=np.int32) + + assert seg.mri.data.dtype == np.int32 mri = MRIData.from_file(input_file, dtype=np.single, orient=orient) output_file = tmp_path.with_suffix(".nii.gz") mri.save(output_file, dtype=np.single) diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py index a368fde..899cfd4 100644 --- a/tests/test_segmentation.py +++ b/tests/test_segmentation.py @@ -5,9 +5,12 @@ import pandas as pd import pytest +import mritk.cli +from mritk.data import MRIData from mritk.segmentation import ( LUT_REGEX, VENTRICLES, + CSFSegmentation, ExtendedFreeSurferSegmentation, Segmentation, default_segmentation_groups, @@ -20,8 +23,8 @@ def test_segmentation_initialization(example_segmentation: Segmentation): - assert example_segmentation.data.shape == (100, 4) - assert example_segmentation.affine.shape == (4, 4) + assert example_segmentation.mri.data.shape == (100, 4) + assert example_segmentation.mri.affine.shape == (4, 4) assert example_segmentation.num_rois == 3 assert set(example_segmentation.roi_labels) == {1, 2, 3} assert example_segmentation.lut.shape == (3, 1) @@ -44,11 +47,11 @@ def test_freesurfer_segmentation_labels(mri_data_dir: Path): def test_extended_freesurfer_segmentation_labels(example_segmentation: Segmentation, mri_data_dir: Path): - data = example_segmentation.data + data = example_segmentation.mri.data data[0:2, 0:2] = 10001 # csf data[3:5, 3:5] = 20001 # dura - ext_fs_seg = ExtendedFreeSurferSegmentation(data, affine=np.eye(4)) + ext_fs_seg = ExtendedFreeSurferSegmentation(MRIData(data=data, affine=np.eye(4))) labels = ext_fs_seg.get_roi_labels() assert set(labels["ROI"]) == set(ext_fs_seg.roi_labels) @@ -180,3 +183,133 @@ def test_write_lut_file_io(tmp_path): # Verify the denormalization restored the original 0-255 integers assert content[0] == "4\tLeft-Lateral-Ventricle\t120\t18\t134\t0" assert content[1] == "5\tLeft-Inf-Lat-Vent\t198\t51\t122\t0" + + +# Note : Refinement is actually testing both resampling and smoothing +# @pytest.mark.xfail( +# reason=("Call to resample_to_reference fails due to shape issue when using gonzo_roi. Needs to be investigated further.") +# ) +@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) +def test_segmentation_refinement(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type: str): + # Get gonzo_roi from FS_segmentation + FS_seg_path = mri_data_dir / f"freesurfer/mri_processed_data/freesurfer/sub-01/mri/{seg_type}.mgz" + fs_seg = Segmentation.from_file(FS_seg_path) # MRIData type + vi = gonzo_roi.voxel_indices(affine=fs_seg.mri.affine) + v = fs_seg.mri.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_fs_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + # Get gonzo_roi from reference MRI to use as reference for resampling + ref_mri_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T1w_registered.nii.gz" + ref_mri = MRIData.from_file(ref_mri_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_mri.affine) + v = ref_mri.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_ref_mri_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + # Output: Refine segmentation from gonzoi_roi segmentation and ref MRI + test_output = tmp_path / "output_refined.nii.gz" + + smoothing = 1 + piece_fs_seg = Segmentation(mri=piece_fs_seg_data) + result = piece_fs_seg.resample_to_reference(piece_ref_mri_data) + smoothed = result.smooth(sigma=smoothing) + result.mri.data = smoothed.mri.data + result.save(test_output, dtype=np.int32) + + ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" + ref_output = mritk.data.MRIData.from_file(ref_output_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_output.affine) + v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape) + + mritk.testing.compare_nifti_arrays(result.mri.data, v_ref, data_tolerance=1e-12) + + +@pytest.mark.parametrize("seg_type", ["aparc+aseg", "aseg", "wmparc"]) +def test_csf_segmentation(tmp_path, mri_data_dir: Path, gonzo_roi, seg_type): + """Test the CSF segmentation logic by comparing against a known reference.""" + input_seg_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-{seg_type}_refined.nii.gz" + input_csf_mask_path = mri_data_dir / "mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf_binary.nii.gz" + + ref_output_path = mri_data_dir / f"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-{seg_type}.nii.gz" + + input_seg = MRIData.from_file(input_seg_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=input_seg.affine) + v = input_seg.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_seg_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + piece_seg = Segmentation(mri=piece_seg_data) + + input_csf_mask = MRIData.from_file(input_csf_mask_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=input_csf_mask.affine) + v = input_csf_mask.data[tuple(vi.T)].reshape(gonzo_roi.shape) + piece_csf_mask_data = mritk.data.MRIData(data=v, affine=gonzo_roi.affine) + + result = CSFSegmentation(segmentation=piece_seg, csf_mask=piece_csf_mask_data).to_csf_segmentation() + + ref_output = MRIData.from_file(ref_output_path, dtype=np.single) + vi = gonzo_roi.voxel_indices(affine=ref_output.affine) + v_ref = ref_output.data[tuple(vi.T)].reshape(gonzo_roi.shape) + + mritk.testing.compare_nifti_arrays(result.data, v_ref, data_tolerance=1e-12) + + +@patch("mritk.segmentation.MRIData") +@patch("mritk.segmentation.Segmentation") +def test_dispatch_resample(mock_seg, mock_mri_data): + """Test that dispatch correctly routes to segmentation resample.""" + + mritk.cli.main(["seg", "resample", "-i", "mock_in.nii.gz", "-r", "mock_ref.nii.gz", "-o", "mock_out.nii.gz"]) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz")) + + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) + + +@patch("mritk.segmentation.Segmentation") +def test_dispatch_smoothing(mock_seg): + """Test that dispatch correctly routes to segmentation smoothing.""" + + mritk.cli.main(["seg", "smooth", "-i", "mock_in.nii.gz", "-o", "mock_out.nii.gz", "-s", "1"]) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + inst = mock_seg.from_file.return_value # Segmentation type instance returned by from_file + inst.smooth.assert_called_once_with(sigma=1.0, cutoff_score=0.5) + + +@patch("mritk.segmentation.MRIData") +@patch("mritk.segmentation.Segmentation") +def test_dispatch_refine(mock_seg, mock_mri_data): + """Test that dispatch correctly routes to segmentation refinement.""" + + # Mock the underlying data arrays to avoid TypeError in np.where + inst = mock_seg.from_file.return_value + refined_inst = inst.resample_to_reference.return_value + smoothed_inst = refined_inst.smooth.return_value + + # Setup mock numpy arrays for the attributes used in np.where + smoothed_inst.data = np.array([1]) # In case the source code bug isn't fixed yet + refined_inst.data = np.array([0]) # In case the source code bug isn't fixed yet + refined_inst.mri.data = np.array([0]) # Correct fixed access + smoothed_inst.mri.data = np.array([1]) # Correct fixed access + + mritk.cli.main( + [ + "seg", + "refine", + "-i", + "mock_in.nii.gz", + "-r", + "mock_ref.nii.gz", + "-o", + "mock_out.nii.gz", + "-s", + "1", + ] + ) + + mock_seg.from_file.assert_called_once_with(Path("mock_in.nii.gz")) + mock_mri_data.from_file.assert_called_once_with(Path("mock_ref.nii.gz")) + + inst.resample_to_reference.assert_called_once_with(mock_mri_data.from_file.return_value) + refined_inst.smooth.assert_called_once_with(sigma=1.0) + refined_inst.save.assert_called_once_with(Path("mock_out.nii.gz"), dtype=np.int32)