
Commit

Merge pull request #936 from EmmaRenauld/reconst_scripts_part4
Reconst scripts part4
arnaudbore authored Mar 7, 2024
2 parents a865a55 + 91e5856 commit 9eeb469
Showing 13 changed files with 183 additions and 162 deletions.
84 changes: 45 additions & 39 deletions scilpy/io/utils.py
@@ -180,12 +180,12 @@ def add_processes_arg(parser):
parser.add_argument('--processes', dest='nbr_processes',
metavar='NBR', type=int, default=1,
help='Number of sub-processes to start. \n'
'Default: [%(default)s]')
'Default: [%(default)s]')


def add_reference_arg(parser, arg_name=None):
if arg_name:
parser.add_argument('--'+arg_name+'_ref',
parser.add_argument('--' + arg_name + '_ref',
help='Reference anatomy for {} (if tck/vtk/fib/dpy'
') file\n'
'support (.nii or .nii.gz).'.format(arg_name))
@@ -302,34 +302,34 @@ def add_sh_basis_args(parser, mandatory=False, input_output=False):
if input_output:
nargs = 2
def_val = ['descoteaux07_legacy', 'tournier07']
input_output_msg = '\nBoth the input and output bases are ' +\
'required, in that order.'
input_output_msg = ('\nBoth the input and output bases are '
'required, in that order.')
else:
nargs = 1
def_val = ['descoteaux07_legacy']
input_output_msg = ''

choices = ['descoteaux07', 'tournier07', 'descoteaux07_legacy',
'tournier07_legacy']
help_msg = 'Spherical harmonics basis used for the SH coefficients. ' +\
input_output_msg +\
'\nMust be either \'descoteaux07\', \'tournier07\', \n' +\
'\'descoteaux07_legacy\' or \'tournier07_legacy\'' +\
' [%(default)s]:\n' +\
' \'descoteaux07\' : SH basis from the Descoteaux ' +\
'et al.\n' +\
' MRM 2007 paper\n' +\
' \'tournier07\' : SH basis from the new ' +\
'Tournier et al.\n' +\
' NeuroImage 2019 paper, as in ' +\
'MRtrix 3.\n' +\
' \'descoteaux07_legacy\': SH basis from the legacy Dipy ' +\
'implementation\n' +\
' of the ' +\
'Descoteaux et al. MRM 2007 paper\n' +\
' \'tournier07_legacy\' : SH basis from the legacy ' +\
'Tournier et al.\n' +\
' NeuroImage 2007 paper.'
help_msg = ("Spherical harmonics basis used for the SH coefficients. "
"{}\n"
"Must be either descoteaux07', 'tournier07', \n"
"'descoteaux07_legacy' or 'tournier07_legacy' [%(default)s]:\n"
" 'descoteaux07' : SH basis from the Descoteaux "
"et al.\n"
" MRM 2007 paper\n"
" 'tournier07' : SH basis from the new "
"Tournier et al.\n"
" NeuroImage 2019 paper, as in "
"MRtrix 3.\n"
" 'descoteaux07_legacy': SH basis from the legacy Dipy "
"implementation\n"
" of the Descoteaux et al. MRM 2007 "
"paper\n"
" 'tournier07_legacy' : SH basis from the legacy "
"Tournier et al.\n"
" NeuroImage 2007 paper."
.format(input_output_msg))

if mandatory:
arg_name = 'sh_basis'
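
A minimal usage sketch for the two SH-basis helpers touched in this file (not part of the commit). The optional flag name '--sh_basis' and the single-basis (name, is_legacy) return shape are assumptions based on the hunks shown here and on the parse_sh_basis_arg docstring in the next hunk.

import argparse

from scilpy.io.utils import add_sh_basis_args, parse_sh_basis_arg

p = argparse.ArgumentParser(description='SH basis demo')
add_sh_basis_args(p)                                 # assumed: optional --sh_basis, single basis
args = p.parse_args(['--sh_basis', 'descoteaux07'])

# With a single basis, parse_sh_basis_arg returns (name, is_legacy);
# with input_output=True it would return a 4-tuple (see the next hunk).
sh_basis, is_legacy = parse_sh_basis_arg(args)
print(sh_basis, is_legacy)                           # expected: descoteaux07 False
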
@@ -353,10 +353,14 @@ def parse_sh_basis_arg(args):
Returns
-------
sh_basis : string
Spherical harmonic basis name.
is_legacy : bool
Whether or not the SH basis is in its legacy form.
if args.sh_basis is a list of one string:
sh_basis : string
Spherical harmonic basis name.
is_legacy : bool
Whether the SH basis is in its legacy form.
    else: (args.sh_basis is a list of two strings)
Returns a Tuple of 4 values:
(sh_basis_in, is_legacy_in, sh_basis_out, is_legacy_out)
"""
sh_basis_name = args.sh_basis[0]
sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \
@@ -373,7 +377,7 @@ def parse_sh_basis_arg(args):


def add_nifti_screenshot_default_args(
parser, slice_ids_mandatory=True, transparency_mask_mandatory=True
parser, slice_ids_mandatory=True, transparency_mask_mandatory=True
):
_mask_prefix = "" if transparency_mask_mandatory else "--"

@@ -385,8 +389,8 @@ def add_nifti_screenshot_default_args(
"the transparency mask are selected."
_output_help = "Name of the output image(s). If multiple slices are " \
"provided (or none), their index will be append to " \
"the name (e.g. volume.jpg, volume.png becomes " \
"volume_slice_0.jpg, volume_slice_0.png)."
"the name (e.g. volume.jpg, volume.png becomes " \
"volume_slice_0.jpg, volume_slice_0.png)."

# Positional arguments
parser.add_argument(
@@ -422,12 +426,12 @@


def add_nifti_screenshot_overlays_args(
parser, labelmap_overlay=True, mask_overlay=True,
transparency_is_overlay=False
parser, labelmap_overlay=True, mask_overlay=True,
transparency_is_overlay=False
):
if labelmap_overlay:
parser.add_argument(
"--in_labelmap", help="Labelmap 3D Nifti image (.nii/.nii.gz).")
"--in_labelmap", help="Labelmap 3D Nifti image (.nii/.nii.gz).")
parser.add_argument(
"--labelmap_cmap_name", default="viridis",
help="Colormap name for the labelmap image data. [%(default)s]")
@@ -540,6 +544,7 @@ def assert_inputs_exist(parser, required, optional=None):
optional: string or list of paths
Optional paths to be checked.
"""

def check(path):
if not os.path.isfile(path):
parser.error('Input file {} does not exist'.format(path))
@@ -576,6 +581,7 @@ def assert_outputs_exist(parser, args, required, optional=None,
check_dir_exists: bool
Test if output directory exists.
"""

def check(path):
if os.path.isfile(path) and not args.overwrite:
parser.error('Output file {} exists. Use -f to force '
@@ -620,6 +626,7 @@ def assert_output_dirs_exist_and_empty(parser, args, required,
create_dir: bool
If true, create the directory if it does not exist.
"""

def check(path):
if not os.path.isdir(path):
if not create_dir:
@@ -751,18 +758,18 @@ def read_info_from_mb_bdo(filename):
geometry = root.attrib['type']
center_tag = root.find('origin')
flip = [-1, -1, 1]
center = [flip[0]*float(center_tag.attrib['x'].replace(',', '.')),
flip[1]*float(center_tag.attrib['y'].replace(',', '.')),
flip[2]*float(center_tag.attrib['z'].replace(',', '.'))]
center = [flip[0] * float(center_tag.attrib['x'].replace(',', '.')),
flip[1] * float(center_tag.attrib['y'].replace(',', '.')),
flip[2] * float(center_tag.attrib['z'].replace(',', '.'))]
row_list = tree.iter('Row')
radius = [None, None, None]
for i, row in enumerate(row_list):
for j in range(0, 3):
if j == i:
key = 'col' + str(j+1)
key = 'col' + str(j + 1)
radius[i] = float(row.attrib[key].replace(',', '.'))
else:
key = 'col' + str(j+1)
key = 'col' + str(j + 1)
value = float(row.attrib[key].replace(',', '.'))
if abs(value) > 0.01:
raise ValueError('Does not support rotation, for now \n'
@@ -912,7 +919,6 @@ def range_checker(arg: str):


def get_default_screenshotting_data(args):

volume_img = nib.load(args.in_volume)

transparency_mask_img = None
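
The assert_* helpers whose docstrings are touched above are the standard input/output guards in scilpy scripts. A hedged sketch of the usual pattern (the argument names and the '-f'/overwrite wiring are assumptions drawn from the helper bodies shown above):

import argparse

from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist

parser = argparse.ArgumentParser(description='I/O guard demo')
parser.add_argument('in_dwi')
parser.add_argument('out_fodf')
parser.add_argument('--mask')
parser.add_argument('-f', dest='overwrite', action='store_true')
args = parser.parse_args()

# Exit through parser.error() if an input is missing, or if an output file
# already exists and -f (overwrite) was not given.
assert_inputs_exist(parser, args.in_dwi, args.mask)
assert_outputs_exist(parser, args, args.out_fodf)
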
6 changes: 4 additions & 2 deletions scilpy/reconst/mti.py
@@ -363,8 +363,10 @@ def adjust_B1_map_header(B1_img, slope):
Parameters
----------
B1_img: B1 nifti image object.
slope: Slope value, obtained from the image header.
B1_img: nifti image object
The B1 map.
slope: float
The slope value, obtained from the image header.
Returns
----------
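
For context on the corrected docstring above, a hedged sketch of how the (B1_img, slope) pair could be obtained with nibabel before calling adjust_B1_map_header. That the slope corresponds to the NIfTI scaling slope is an assumption, not something this diff states.

import nibabel as nib

from scilpy.reconst.mti import adjust_B1_map_header

b1_img = nib.load('B1_map.nii.gz')        # placeholder path
slope = float(b1_img.dataobj.slope)       # NIfTI scaling slope; assumed source
adjusted = adjust_B1_map_header(b1_img, slope)   # return value per the (truncated) docstring
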
5 changes: 2 additions & 3 deletions scilpy/reconst/sh.py
@@ -577,8 +577,7 @@ def convert_sh_basis(shm_coeff, sphere, mask=None,
mask = np.sum(shm_coeff, axis=3).astype(bool)

nbr_processes = multiprocessing.cpu_count() \
if nbr_processes is None or nbr_processes < 0 \
else nbr_processes
if nbr_processes is None or nbr_processes < 0 else nbr_processes

# Ravel the first 3 dimensions while keeping the 4th intact, like a list of
# 1D time series voxels. Then separate it in chunks of len(nbr_processes).
@@ -647,7 +646,7 @@ def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32",
If True, use a full SH basis (even and odd orders) for the input SH
coefficients.
is_input_legacy : bool, optional
Whether or not the input basis is in its legacy form.
Whether the input basis is in its legacy form.
nbr_processes: int, optional
The number of subprocesses to use.
Default: multiprocessing.cpu_count()
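
The reflowed conditional above encodes the convention that nbr_processes=None or a negative value means "use every available core". A standalone restatement of that fallback (helper name is illustrative):

import multiprocessing


def resolve_nbr_processes(nbr_processes=None):
    # None or a negative value falls back to all available cores, as in
    # convert_sh_basis / convert_sh_to_sf above.
    return multiprocessing.cpu_count() \
        if nbr_processes is None or nbr_processes < 0 else nbr_processes


print(resolve_nbr_processes())       # all cores
print(resolve_nbr_processes(4))      # 4
print(resolve_nbr_processes(-1))     # all cores
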
61 changes: 37 additions & 24 deletions scripts/scil_NODDI_maps.py
@@ -26,8 +26,11 @@
add_verbose_arg,
assert_inputs_exist,
assert_output_dirs_exist_and_empty,
redirect_stdout_c)
from scilpy.gradients.bvec_bval_tools import identify_shells
redirect_stdout_c, add_tolerance_arg,
add_skip_b0_check_arg)
from scilpy.gradients.bvec_bval_tools import (check_b0_threshold,
identify_shells)


EPILOG = """
Reference:
@@ -41,7 +44,7 @@
def _build_arg_parser():
p = argparse.ArgumentParser(
description=__doc__, epilog=EPILOG,
formatter_class=argparse.RawDescriptionHelpFormatter)
formatter_class=argparse.RawTextHelpFormatter)

p.add_argument('in_dwi',
help='DWI file acquired with a NODDI compatible protocol '
@@ -56,10 +59,9 @@ def _build_arg_parser():
p.add_argument('--out_dir', default="results",
help='Output directory for the NODDI results. '
'[%(default)s]')
p.add_argument('--b_thr', type=int, default=40,
help='Limit value to consider that a b-value is on an '
'existing shell. Above this limit, the b-value is '
'placed on a new shell. This includes b0s values.')
add_tolerance_arg(p)
add_skip_b0_check_arg(p, will_overwrite_with_min=False,
b0_tol_name='--tolerance')

g1 = p.add_argument_group(title='Model options')
g1.add_argument('--para_diff', type=float, default=1.7e-3,
@@ -101,56 +103,66 @@ def main():
logging.getLogger().setLevel(logging.getLevelName(args.verbose))
redirected_stdout = redirect_stdout(sys.stdout)

# Verifications
if args.compute_only and not args.save_kernels:
parser.error('--compute_only must be used with --save_kernels.')

assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec],
args.mask)

assert_output_dirs_exist_and_empty(parser, args,
args.out_dir,
assert_output_dirs_exist_and_empty(parser, args, args.out_dir,
optional=args.save_kernels)

# Generage a scheme file from the bvals and bvecs files
# Generate a scheme file from the bvals and bvecs files
bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
_ = check_b0_threshold(bvals.min(), b0_thr=args.tolerance,
skip_b0_check=args.skip_b0_check)
shells_centroids, indices_shells = identify_shells(bvals, args.tolerance,
round_centroids=True)

nb_shells = len(shells_centroids)
if nb_shells <= 1:
raise ValueError("Amico's NODDI works with data with more than one "
"shell, but you seem to have single-shell data (we "
"found shells {}). Change tolerance if necessary."
.format(np.sort(shells_centroids)))

logging.info('Will compute NODDI with AMICO on {} shells found at {}.'
.format(len(shells_centroids), np.sort(shells_centroids)))

# Save the resulting bvals to a temporary file
tmp_dir = tempfile.TemporaryDirectory()
tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.b')
tmp_bval_filename = os.path.join(tmp_dir.name, 'bval')
bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec)
shells_centroids, indices_shells = identify_shells(bvals,
args.b_thr,
round_centroids=True)
np.savetxt(tmp_bval_filename, shells_centroids[indices_shells],
newline=' ', fmt='%i')
fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename)
logging.info('Compute NODDI with AMICO on {} shells at found '
'at {}.'.format(len(shells_centroids), shells_centroids))

with redirected_stdout:
# Load the data
amico.core.setup()
ae = amico.Evaluation('.', '.')
ae.load_data(args.in_dwi,
tmp_scheme_filename,
mask_filename=args.mask)
ae.load_data(args.in_dwi, tmp_scheme_filename, mask_filename=args.mask)

# Compute the response functions
ae.set_model("NODDI")

intra_vol_frac = np.linspace(0.1, 0.99, 12)
intra_orient_distr = np.hstack((np.array([0.03, 0.06]),
np.linspace(0.09, 0.99, 10)))

ae.model.set(args.para_diff, args.iso_diff,
intra_vol_frac, intra_orient_distr,
False)
ae.model.set(dPar=args.para_diff, dIso=args.iso_diff,
IC_VFs=intra_vol_frac, IC_ODs=intra_orient_distr,
isExvivo=False)
ae.set_solver(lambda1=args.lambda1, lambda2=args.lambda2)

# The kernels are, by default, set to be in the current directory
# Depending on the choice, manually change the saving location
if args.save_kernels:
kernels_dir = os.path.join(args.save_kernels)
kernels_dir = args.save_kernels
regenerate_kernels = True
elif args.load_kernels:
kernels_dir = os.path.join(args.load_kernels)
kernels_dir = args.load_kernels
regenerate_kernels = False
else:
kernels_dir = os.path.join(tmp_dir.name, 'kernels', ae.model.id)
@@ -166,6 +178,7 @@ def main():

# Model fit
ae.fit()

# Save the results
ae.save_results()

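
A hedged sketch of the new shell-validation step introduced in scil_NODDI_maps.py, using only the calls visible in this diff (file names and the tolerance value are placeholders):

import numpy as np
from dipy.io.gradients import read_bvals_bvecs

from scilpy.gradients.bvec_bval_tools import (check_b0_threshold,
                                              identify_shells)

bvals, _ = read_bvals_bvecs('dwi.bval', 'dwi.bvec')

# Validate the smallest b-value against the tolerance (the --skip_b0_check
# behaviour is handled inside the helper).
_ = check_b0_threshold(bvals.min(), b0_thr=20, skip_b0_check=False)

# Group b-values into shells; AMICO's NODDI needs more than one shell.
shells_centroids, indices_shells = identify_shells(bvals, 20,
                                                   round_centroids=True)
if len(shells_centroids) <= 1:
    raise ValueError('NODDI requires multi-shell data; found shells {}.'
                     .format(np.sort(shells_centroids)))
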
