Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
de00881
Implement segmentation refinement as in gMRI2FEM and csf segmentation…
cdaversin Apr 23, 2026
ab5ee60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2026
1479097
Add missing files in test_data
cdaversin Apr 23, 2026
e41604b
Merge branch 'cecile/segmentation_refinement' of https://github.com/s…
cdaversin Apr 23, 2026
e0b062b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2026
10ef679
Minor - removed unused file in test
cdaversin Apr 23, 2026
53ef36e
Fix mypy - convert MRIData to Segmentation type
cdaversin Apr 23, 2026
d655ea0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2026
0811a95
Reset cache for data
finsberg Apr 23, 2026
3e09374
Update masks.py function input type and move csf_segmentation to segm…
cdaversin Apr 24, 2026
9b81380
Updates in segmentation: Segmentation class does not inherit from MRI…
cdaversin Apr 24, 2026
29a5247
fix conflict
cdaversin Apr 24, 2026
b55bdf3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2026
4351949
Various fixes
finsberg Apr 24, 2026
78ef8e1
Fix shape issue in refinement test
finsberg Apr 24, 2026
97e4acf
Use gonzo roi for csf segmentation test
finsberg Apr 24, 2026
e5e28fe
Fix dispatch tests in segmentation
finsberg Apr 24, 2026
202e909
Merge pull request #47 from scientificcomputing/finsberg/segmentation…
cdaversin Apr 25, 2026
9c0ab6e
Merge branch 'main' of https://github.com/scientificcomputing/mri-too…
cdaversin Apr 25, 2026
7edd899
fixes in segmentation classes
cdaversin Apr 25, 2026
fe21f1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2026
d4e05e9
fixes in segmentation classes
cdaversin Apr 25, 2026
d6c70e0
Merge branch 'cecile/segmentation_refinement' of https://github.com/s…
cdaversin Apr 25, 2026
edca7a1
fixes in segmentation classes
cdaversin Apr 25, 2026
5b0e54b
Fixes after changes in CSFSegmentation
cdaversin Apr 25, 2026
7b56606
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2026
f833bc3
Fix tests
cdaversin Apr 25, 2026
44da573
Merge branch 'cecile/segmentation_refinement' of https://github.com/s…
cdaversin Apr 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/setup-data.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
path: data/ # The folder you want to cache
# The key determines if we have a match.
# Change 'v1' to 'v2' manually to force a re-download in the future.
key: test-data-v7
key: test-data-v8


# 2. DOWNLOAD ONLY IF CACHE MISS
Expand Down
7 changes: 6 additions & 1 deletion src/mritk/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rich.logging import RichHandler
from rich_argparse import RichHelpFormatter

from . import concentration, datasets, hybrid, info, looklocker, masks, mixed, napari, r1, show, statistics
from . import concentration, datasets, hybrid, info, looklocker, masks, mixed, napari, r1, segmentation, show, statistics


def version_info():
Expand Down Expand Up @@ -75,6 +75,9 @@ def setup_parser():
napari_parser = subparsers.add_parser("napari", help="Show MRI data using napari", formatter_class=parser.formatter_class)
napari.add_arguments(napari_parser)

segmentation_parser = subparsers.add_parser("seg", help="Perform segmentation tasks", formatter_class=parser.formatter_class)
segmentation.add_arguments(segmentation_parser, extra_args_cb=add_extra_arguments)

looklocker_parser = subparsers.add_parser(
"looklocker", help="Process Look-Locker data", formatter_class=parser.formatter_class
)
Expand Down Expand Up @@ -142,6 +145,8 @@ def dispatch(parser: argparse.ArgumentParser, argv: Optional[Sequence[str]] = No
show.dispatch(args)
elif command == "napari":
napari.dispatch(args)
elif command == "seg":
segmentation.dispatch(args)
elif command == "looklocker":
looklocker.dispatch(args)
elif command == "mask":
Expand Down
71 changes: 61 additions & 10 deletions src/mritk/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path

import numpy as np
import scipy.interpolate
import skimage

from .data import MRIData
Expand Down Expand Up @@ -107,6 +108,45 @@ def csf_mask(input: Path, connectivity: int | None = 2, use_li: bool = False) ->
return mri_data


def csf_segmentation(input_segmentation: Path | MRIData, csf_mask: Path | MRIData) -> MRIData:
Comment thread
finsberg marked this conversation as resolved.
Outdated
"""
Generates a CSF segmentation by applying a CSF mask to an anatomical segmentation.
Comment thread
finsberg marked this conversation as resolved.
Outdated

This function takes an anatomical segmentation (e.g., from FreeSurfer) and a CSF mask,
and produces a new segmentation where voxels identified as CSF in the mask are labeled
with their original segmentation values, while non-CSF voxels are set to zero.

Args:
input_segmentation (Path | MRIData): Path to the anatomical segmentation NIfTI file
or an MRIData object containing the resampled segmentation.
csf_mask (Path | MRIData): Either a path to a CSF mask NIfTI file or an MRIData object containing the mask.

Returns:
MRIData: An MRIData object containing the CSF segmentation.
"""
if isinstance(input_segmentation, Path):
seg_mri = MRIData.from_file(input_segmentation, dtype=np.int16)
else:
seg_mri = input_segmentation

if isinstance(csf_mask, Path):
csf_mask_mri = MRIData.from_file(csf_mask, dtype=bool)
else:
csf_mask_mri = csf_mask

assert_same_space(seg_mri, csf_mask_mri)

# Get interpolation operator
I, J, K = np.where(seg_mri.data != 0)
interp = scipy.interpolate.NearestNDInterpolator(np.array([I, J, K]).T, seg_mri.data[I, J, K])
# Interpolate segmentation values at CSF mask locations
i, j, k = np.where(csf_mask_mri.data != 0)
csf_seg = np.zeros_like(seg_mri.data, dtype=np.int16)
csf_seg[i, j, k] = interp(i, j, k)

return MRIData(data=csf_seg.astype(np.int16), affine=csf_mask_mri.affine)


def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_array: np.ndarray) -> np.ndarray:
"""
Combines a CSF mask array and a brain segmentation mask array into a solid intracranial mask.
Expand Down Expand Up @@ -134,28 +174,29 @@ def compute_intracranial_mask_array(csf_mask_array: np.ndarray, segmentation_arr
return ~opened_background


def intracranial_mask(csf_segmentation_path: Path, segmentation_path: Path) -> MRIData:
def intracranial_mask(segmentation_path: Path, csf_mask_path: Path) -> MRIData:
Comment thread
finsberg marked this conversation as resolved.
Outdated
"""
I/O wrapper for generating and saving an intracranial mask from NIfTI files.

Loads the masks, verifies they share the same physical coordinate space, and
delegates the array computation.

Args:
csf_segmentation_path (Path): Path to the CSF segmentation NIfTI file.
segmentation_path (Path): Path to the brain segmentation NIfTI file.
output (Optional[Path], optional): Path to save the resulting mask. Defaults to None.
segmentation_path (Path): Path to the brain (refined) segmentation NIfTI file, \
generated by the segmentation refinement module.
csf_mask_path (Path): Path to the CSF mask, generated by the csf mask module.

Returns:
MRIData: An MRIData object containing the intracranial mask.
"""
input_csf_mask = MRIData.from_file(csf_segmentation_path, dtype=bool)
# Get segmentation data and csf segmentation
segmentation_data = MRIData.from_file(segmentation_path, dtype=bool)
csf_seg = csf_segmentation(input_segmentation=segmentation_data, csf_mask=csf_mask_path)

# Validate spatial alignment before array operations
assert_same_space(input_csf_mask, segmentation_data)
assert_same_space(csf_seg, segmentation_data)

mask_data = compute_intracranial_mask_array(input_csf_mask.data, segmentation_data.data)
mask_data = compute_intracranial_mask_array(csf_seg.data, segmentation_data.data)
mri_data = MRIData(data=mask_data, affine=segmentation_data.affine)

return mri_data
Expand Down Expand Up @@ -183,8 +224,18 @@ def add_arguments(
intracranial_mask_parser = subparser.add_parser(
"intracranial", help="Compute intracranial mask", formatter_class=parser.formatter_class
)
intracranial_mask_parser.add_argument("--csf-segmentation-path", type=Path, help="Path to the CSF segmentation NIfTI file")
intracranial_mask_parser.add_argument("--segmentation-path", type=Path, help="Path to the brain segmentation NIfTI file")
intracranial_mask_parser.add_argument(
"--segmentation-path",
type=Path,
help="Path to refined segmentation file, generated by \
the segmentation refinement module, i.e. mritk seg refine",
)
intracranial_mask_parser.add_argument(
"--csf-mask-path",
type=Path,
help="Path to the CSF mask NIfTI file, generated by \
the csf mask module, i.e. mritk mask csf",
)
intracranial_mask_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resulting mask")

if extra_args_cb is not None:
Expand All @@ -199,7 +250,7 @@ def dispatch(args):
csf_mask_data.save(args.pop("output"), dtype=np.uint8)
elif command == "intracranial":
intracranial_mask_data = intracranial_mask(
csf_segmentation_path=args.pop("csf_segmentation_path"), segmentation_path=args.pop("segmentation_path")
segmentation_path=args.pop("segmentation_path"), csf_mask_path=args.pop("csf_mask_path")
)
intracranial_mask_data.save(args.pop("output"), dtype=np.uint8)
else:
Expand Down
150 changes: 149 additions & 1 deletion src/mritk/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
# Copyright (C) 2026 Cécile Daversin-Catty (cecile@simula.no)
# Copyright (C) 2026 Simula Research Laboratory

import argparse
import itertools
import logging
import os
import re
from collections.abc import Callable
from pathlib import Path
from urllib.request import urlretrieve

import numpy as np
import numpy.typing as npt
import pandas as pd
import scipy

from .data import MRIData, load_mri_data
from .data import MRIData, apply_affine, load_mri_data

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,6 +92,10 @@ class Segmentation(MRIData):
labels to a descriptive Lookup Table (LUT).
"""

mri: MRIData
rois: np.ndarray
lut: pd.DataFrame

def __init__(self, data: np.ndarray, affine: np.ndarray, lut: pd.DataFrame | None = None):
"""
Initializes the Segmentation object.
Expand Down Expand Up @@ -147,6 +155,72 @@ def get_roi_labels(self, rois: npt.NDArray[np.int32] | None = None) -> pd.DataFr

return self.lut.loc[self.lut.index.isin(rois), [self._label_name]].rename_axis("ROI").reset_index()

def resample_to_reference(self, reference_mri: MRIData):
"""
Resamples the segmentation to match the spatial dimensions and resolution of a reference MRI.

Args:
reference_mri (MRIData): The MRI to which the segmentation will be resampled,
for example a T1-weighted anatomical scan.
Returns:
Segmentation: A new Segmentation object containing the resampled data.
"""

shape_in = self.shape
shape_out = reference_mri.shape

# Generate a grid of voxel indices for the output space
upsampled_indices = np.fromiter(
itertools.product(*(np.arange(ni) for ni in shape_out)),
dtype=np.dtype((int, 3)),
)
# Get voxel indices in the input segmentation space corresponding to the output grid
seg_indices = apply_affine(
np.linalg.inv(self.affine),
apply_affine(reference_mri.affine, upsampled_indices),
)
seg_indices = np.rint(seg_indices).astype(int)

# The two images does not necessarily share field of view.
# Remove voxels which are not located within the segmentation fov.
valid_index_mask = (seg_indices > 0).all(axis=1) * (seg_indices < shape_in).all(axis=1)
upsampled_indices = upsampled_indices[valid_index_mask]
seg_indices = seg_indices[valid_index_mask]

seg_upsampled = np.zeros(shape_out, dtype=self.data.dtype)
I_in, J_in, K_in = seg_indices.T
I_out, J_out, K_out = upsampled_indices.T
seg_upsampled[I_out, J_out, K_out] = self.data[I_in, J_in, K_in]

return Segmentation(data=seg_upsampled, affine=reference_mri.affine, lut=self.lut)

def smooth(self, sigma: float, cutoff_score: float = 0.5, **kwargs) -> MRIData:
"""
Applies Gaussian smoothing to the segmentation labels to create a soft probabilistic map.

Args:
sigma (float): The standard deviation for the Gaussian kernel.
cutoff_score (float, optional): A threshold to remove low-confidence voxels. Defaults to 0.5.
**kwargs: Additional keyword arguments passed to scipy.ndimage.gaussian_filter.

Returns:
dict[str, np.ndarray]: A dictionary containing 'labels' (the smoothed segmentation)
and 'scores' (the confidence scores for each voxel).
"""
smoothed_rois = np.zeros_like(self.data)
high_scores = np.zeros(self.data.shape)

for roi in self.rois:
scores = scipy.ndimage.gaussian_filter((self.data == roi).astype(float), sigma=sigma, **kwargs)
is_new_high_score = scores > high_scores
smoothed_rois[is_new_high_score] = roi
high_scores[is_new_high_score] = scores[is_new_high_score]

delete_scores = (high_scores < cutoff_score) * (self.data == 0)
smoothed_rois[delete_scores] = 0

return MRIData(data=smoothed_rois, affine=self.affine)


class FreeSurferSegmentation(Segmentation):
"""
Expand Down Expand Up @@ -407,3 +481,77 @@ def write_lut(filename: Path, table: pd.DataFrame):

# Save as tab-separated values without headers or indices
newtable.to_csv(filename, sep="\t", index=False, header=False)


def add_arguments(
parser: argparse.ArgumentParser,
extra_args_cb: Callable[[argparse.ArgumentParser], None] | None = None,
) -> None:
subparser = parser.add_subparsers(dest="seg-command", help="Commands for segmentation processing")

resample_parser = subparser.add_parser(
"resample", help="Resample a segmentation to match the space of a reference MRI", formatter_class=parser.formatter_class
)
resample_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file")
resample_parser.add_argument(
"-r",
"--reference",
type=Path,
help="Path to the reference MRI \
- usually a registered T1 weighted anatomical scan",
)
resample_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the resampled segmentation")

smooth_parser = subparser.add_parser(
"smooth",
help="Apply Gaussian smoothing to a segmentation to create a soft probabilistic map",
formatter_class=parser.formatter_class,
)
smooth_parser.add_argument("-i", "--input", type=Path, help="Path to the input (refined) segmentation NIfTI file")
smooth_parser.add_argument("-s", "--sigma", type=float, help="Standard deviation for the Gaussian kernel used in smoothing")
smooth_parser.add_argument(
"-c", "--cutoff", type=float, default=0.5, help="Cutoff score to remove low-confidence voxels (default: 0.5)"
)
smooth_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the smoothed segmentation")

refine_parser = subparser.add_parser(
"refine",
help="Refine a segmentation by applying Gaussian smoothing to the labels",
formatter_class=parser.formatter_class,
)
refine_parser.add_argument("-i", "--input", type=Path, help="Path to the input segmentation NIfTI file")
refine_parser.add_argument(
"-r",
"--reference",
type=Path,
help="Path to the reference MRI \
- usually a registered T1 weighted anatomical scan",
)
refine_parser.add_argument("-s", "--smooth", type=float, help="Standard deviation for the Gaussian kernel used in smoothing")
refine_parser.add_argument("-o", "--output", type=Path, help="Desired output path for the refined segmentation")

if extra_args_cb is not None:
extra_args_cb(resample_parser)
extra_args_cb(smooth_parser)
extra_args_cb(refine_parser)


def dispatch(args):
command = args.pop("seg-command")
if command == "resample":
print("Resampling segmentation...")
input_seg = Segmentation.from_file(args.pop("input"))
reference_mri = Segmentation.from_file(args.pop("reference"))
Comment thread
finsberg marked this conversation as resolved.
Outdated
resampled_seg = input_seg.resample_to_reference(reference_mri)
resampled_seg.save(args.pop("output"), dtype=np.int32)
elif command == "smooth":
smoothed = Segmentation.from_file(args.pop("input")).smooth(sigma=args.pop("sigma"), cutoff_score=args.pop("cutoff"))
smoothed.save(args.pop("output"), dtype=np.int32)
elif command == "refine":
seg = Segmentation.from_file(args.pop("input"))
refined = seg.resample_to_reference(MRIData.from_file(args.pop("reference")))
smoothed = refined.smooth(sigma=args.pop("smooth"))
refined.data = np.where(smoothed.data > 0, smoothed.data, refined.data)
refined.save(args.pop("output"), dtype=np.int32)
else:
raise ValueError(f"Unknown segmentation command: {command}")
8 changes: 7 additions & 1 deletion tests/create_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,17 @@ def main():
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-intracranial_binary.nii.gz",
"mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-mixed_T1map.nii.gz",
"mri-processed/mri_dataset/derivatives/sub-01/ses-01/sub-01_ses-01_acq-looklocker_T1map.nii.gz",
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz",
"mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_acq-looklocker_T1map_registered.nii.gz",
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aseg.nii.gz",
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-aparc+aseg.nii.gz",
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-csf-wmparc.nii.gz",
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aseg_refined.nii.gz",
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-aparc+aseg_refined.nii.gz",
"mri-processed/mri_processed_data/sub-01/segmentations/sub-01_seg-wmparc_refined.nii.gz",
"mri-processed/mri_processed_data/sub-01/registered/sub-01_ses-01_T2w_registered.nii.gz",
"freesurfer/mri_processed_data/freesurfer/sub-01/mri/aparc+aseg.mgz",
"freesurfer/mri_processed_data/freesurfer/sub-01/mri/aseg.mgz",
"freesurfer/mri_processed_data/freesurfer/sub-01/mri/wmparc.mgz",
]

for file in files:
Expand Down
Loading
Loading