Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ scil_sh_fusion = "scilpy.cli.scil_sh_fusion:main"
scil_sh_to_aodf = "scilpy.cli.scil_sh_to_aodf:main"
scil_sh_to_rish = "scilpy.cli.scil_sh_to_rish:main"
scil_sh_to_sf = "scilpy.cli.scil_sh_to_sf:main"
scil_fodf_global_sf_threshold = "scilpy.cli.scil_fodf_global_sf_threshold:main"
scil_stats_group_comparison = "scilpy.cli.scil_stats_group_comparison:main"
scil_surface_assign_custom_color = "scilpy.cli.scil_surface_assign_custom_color:main"
scil_surface_assign_uniform_color = "scilpy.cli.scil_surface_assign_uniform_color:main"
Expand Down
91 changes: 91 additions & 0 deletions src/scilpy/cli/scil_fodf_global_sf_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Compute a binary mask based on a global SF threshold.
The script masks voxels where the maximum SF amplitude is below
either a relative factor or an absolute threshold.

When fODFs are evaluated on a sphere (SF), the amplitude of the lobes
corresponds to the strength of the diffusion signal in those directions.
Thresholding these amplitudes is a common practice to filter out spurious
peaks arising from noise or the deconvolution process (e.g., ringing effects).

The absolute threshold can be estimated from the mean/median maximum fODF in the
ventricles, computed with scil_fodf_max_in_ventricles.
"""

import argparse
import logging

import nibabel as nib
import numpy as np

from scilpy.io.utils import (add_sh_basis_args, add_sphere_arg,
add_verbose_arg, add_overwrite_arg,
assert_inputs_exist, assert_outputs_exist,
parse_sh_basis_arg)
from scilpy.reconst.utils import compute_sf_threshold_mask
from scilpy.version import version_string


def _build_arg_parser():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.RawTextHelpFormatter,
epilog=version_string)

p.add_argument('in_odf',
help='Input ODF file (SH or Peaks) (.nii.gz).')
p.add_argument('out_mask',
help='Output binary mask (.nii.gz).')

thr_g = p.add_mutually_exclusive_group(required=True)
thr_g.add_argument('--relative', type=float,
help='Global SF threshold relative factor (0-1).')
thr_g.add_argument('--absolute', type=float,
help='Global SF absolute threshold.')

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it need to be mutually exclusive? Could be both, no?

add_sh_basis_args(p)
add_sphere_arg(p)
add_overwrite_arg(p)
add_verbose_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

assert_inputs_exist(parser, args.in_odf)
assert_outputs_exist(parser, args, args.out_mask)

sh_basis, is_legacy = parse_sh_basis_arg(args)

logging.info("Loading ODF data.")
img = nib.load(args.in_odf)
data = img.get_fdata(dtype=np.float32)

logging.info("Computing global SF threshold mask.")
mask, global_max, threshold = compute_sf_threshold_mask(
data, sphere_name=args.sphere, relative_factor=args.relative,
absolute_threshold=args.absolute, sh_basis=sh_basis,
is_legacy=is_legacy)

logging.info("Global max SF amplitude: {:.4f}".format(global_max))
if args.relative is not None:
logging.info("Relative threshold: {:.4f} (Factor: {})"
.format(threshold, args.relative))
else:
logging.info("Absolute threshold used: {:.4f}".format(args.absolute))

logging.info("Number of voxels in mask: {}".format(np.sum(mask)))

# Save mask
mask_img = nib.Nifti1Image(mask.astype(np.uint8), img.affine,
img.header)
nib.save(mask_img, args.out_mask)


if __name__ == "__main__":
main()
49 changes: 33 additions & 16 deletions src/scilpy/cli/scil_tracking_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
The tracking direction is chosen in the aperture cone defined by the
previous tracking direction and the angular constraint.

WARNING: This script DOES NOT support asymetric FODF input (aFODF).
WARNING: This script DOES NOT support asymmetric FODF input (aFODF).

Algo 'eudx': select the peak from the spherical function (SF) most closely
aligned to the previous direction, and follow an average of it and the previous
Expand Down Expand Up @@ -41,9 +41,10 @@
* Forward tracking: For GPU tracking, the `--forward_only` flag can be used
to disable backward tracking. This option isn't available for CPU
tracking.
* Random number generator seed (RNG): CPU and GPU use different RNG implementations,<
so the same `--seed` is reproducible within a backend but does not guarantee
identical streamlines across CPU vs GPU tracking.
* Random number generator seed (RNG): CPU and GPU use different RNG
implementations, so the same `--seed` is reproducible within a
backend but does not guarantee identical streamlines across CPU vs
GPU tracking.

All the input nifti files must be in isotropic resolution.

Expand All @@ -61,20 +62,20 @@
import logging
from time import perf_counter

import nibabel as nib
import numpy as np
from nibabel.streamlines import TrkFile, detect_format

from dipy.data import get_sphere
from dipy.tracking import utils as track_utils
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.stopping_criterion import BinaryStoppingCriterion
from dipy.tracking.tracker import eudx_tracking
import nibabel as nib
from nibabel.streamlines import TrkFile, detect_format
import numpy as np

from scilpy.io.image import get_data_as_mask
from scilpy.io.utils import (add_sphere_arg, add_verbose_arg,
assert_headers_compatible, assert_inputs_exist,
assert_outputs_exist, parse_sh_basis_arg,
verify_compression_th, load_matrix_in_any_format)
assert_outputs_exist, load_matrix_in_any_format,
parse_sh_basis_arg, verify_compression_th)
from scilpy.tracking.tracker import GPUTracker
from scilpy.tracking.utils import (add_mandatory_options_tracking,
add_out_options, add_seeding_options,
Expand Down Expand Up @@ -104,7 +105,7 @@ def _build_arg_parser():

# Other options, only available in this script:
track_g.add_argument('--sh_to_pmf', action='store_true',
help='If set, map sherical harmonics to spherical '
help='If set, map spherical harmonics to spherical '
'function (pmf) before \ntracking (faster, '
'requires more memory)')
track_g.add_argument('--algo', default='prob',
Expand Down Expand Up @@ -200,6 +201,18 @@ def main():
logging.debug("Loading masks and finding seeds.")
mask_data = get_data_as_mask(nib.load(args.in_mask), dtype=bool)

# ODF data for thresholding
odf_sh_data = odf_sh_img.get_fdata(dtype=np.float32)

sh_basis, is_legacy = parse_sh_basis_arg(args)

sf_mask = None
if args.global_sf_rel_thr is not None or \
args.global_sf_abs_thr is not None:
from scilpy.tracking.utils import get_global_sf_threshold_mask

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason why this is a conditional import?

sf_mask = get_global_sf_threshold_mask(
odf_sh_data, args, sh_basis, is_legacy)

if args.npv:
nb_seeds = args.npv
seed_per_vox = True
Expand Down Expand Up @@ -235,11 +248,16 @@ def main():
random_seed=args.seed)
total_nb_seeds = len(seeds)

combined_mask = mask_data
if sf_mask is not None:
combined_mask = np.logical_and(mask_data, sf_mask)

if not args.use_gpu:
# LocalTracking.maxlen is actually the maximum length
# per direction, we need to filter post-tracking.
max_steps_per_direction = int(args.max_length / args.step_size)
stopping_criterion = BinaryStoppingCriterion(mask_data)

stopping_criterion = BinaryStoppingCriterion(combined_mask)

logging.info("Starting CPU local tracking.")
if args.algo == 'eudx':
Expand All @@ -248,7 +266,7 @@ def main():
stopping_criterion,
np.eye(4),
pam=get_direction_getter(
args.in_odf, args.algo, args.sphere,
odf_sh_data, args.algo, args.sphere,
args.sub_sphere, args.theta, sh_basis,
voxel_size, args.sf_threshold, args.sh_to_pmf,
args.probe_length, args.probe_radius,
Expand All @@ -264,7 +282,7 @@ def main():
else:
streamlines_generator = LocalTracking(
get_direction_getter(
args.in_odf, args.algo, args.sphere,
odf_sh_data, args.algo, args.sphere,
args.sub_sphere, args.theta, sh_basis,
voxel_size, args.sf_threshold, args.sh_to_pmf,
args.probe_length, args.probe_radius,
Expand All @@ -284,14 +302,13 @@ def main():
max_strl_len = int(2.0 * args.max_length / args.step_size) + 1

# data volume
odf_sh = odf_sh_img.get_fdata(dtype=np.float32)

# GPU tracking needs the full sphere
sphere = get_sphere(name=args.sphere).subdivide(n=args.sub_sphere)

logging.info("Starting GPU local tracking.")
streamlines_generator = GPUTracker(
odf_sh, mask_data, seeds,
odf_sh_data, combined_mask, seeds,
vox_step_size, max_strl_len,
theta=get_theta(args.theta, args.algo),
sf_threshold=args.sf_threshold,
Expand Down
Loading
Loading