Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SH basis legacy support #921

Merged
merged 20 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions scilpy/denoise/asym_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
def angle_aware_bilateral_filtering(in_sh, sh_order=8,
sh_basis='descoteaux07',
in_full_basis=False,
is_legacy=True,
sphere_str='repulsion724',
sigma_spatial=1.0, sigma_angular=1.0,
sigma_range=0.5, use_gpu=True):
Expand All @@ -27,6 +28,8 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8,
Name of SH basis used.
in_full_basis: bool, optional
True if input is expressed in full SH basis.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
sphere_str: str, optional
Name of the DIPY sphere to use for sh to sf projection.
sigma_spatial: float, optional
Expand All @@ -46,6 +49,7 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8,
if use_gpu and have_opencl:
return angle_aware_bilateral_filtering_gpu(in_sh, sh_order,
sh_basis, in_full_basis,
is_legacy,
sphere_str, sigma_spatial,
sigma_angular, sigma_range)
elif use_gpu and not have_opencl:
Expand All @@ -54,13 +58,15 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8,
else:
return angle_aware_bilateral_filtering_cpu(in_sh, sh_order,
sh_basis, in_full_basis,
is_legacy,
sphere_str, sigma_spatial,
sigma_angular, sigma_range)


def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8,
sh_basis='descoteaux07',
in_full_basis=False,
is_legacy=True,
sphere_str='repulsion724',
sigma_spatial=1.0,
sigma_angular=1.0,
Expand All @@ -78,6 +84,8 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8,
Name of SH basis used.
in_full_basis: bool, optional
True if input is expressed in full SH basis.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
sphere_str: str, optional
Name of the DIPY sphere to use for sh to sf projection.
sigma_spatial: float, optional
Expand All @@ -104,11 +112,13 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8,
sh_to_sf_mat = sh_to_sf_matrix(sphere, sh_order=sh_order,
basis_type=sh_basis,
full_basis=in_full_basis,
legacy=is_legacy,
return_inv=False)

_, sf_to_sh_mat = sh_to_sf_matrix(sphere, sh_order=sh_order,
basis_type=sh_basis,
full_basis=True,
legacy=is_legacy,
return_inv=True)

out_n_coeffs = sf_to_sh_mat.shape[1]
Expand Down Expand Up @@ -150,6 +160,7 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8,
def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8,
sh_basis='descoteaux07',
in_full_basis=False,
is_legacy=True,
sphere_str='repulsion724',
sigma_spatial=1.0,
sigma_angular=1.0,
Expand All @@ -168,6 +179,8 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8,
Name of SH basis used.
in_full_basis: bool, optional
True if input is expressed in full SH basis.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
sphere_str: str, optional
Name of the DIPY sphere to use for sh to sf projection.
sigma_spatial: float, optional
Expand All @@ -194,7 +207,8 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8,

nb_sf = len(sphere.vertices)
B = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis,
return_inv=False, full_basis=in_full_basis)
return_inv=False, full_basis=in_full_basis,
legacy=is_legacy)

mean_sf = np.zeros(in_sh.shape[:-1] + (nb_sf,))

Expand All @@ -209,7 +223,7 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8,

# Convert back to SH coefficients
_, B_inv = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis,
full_basis=True)
full_basis=True, legacy=is_legacy)
out_sh = np.array([np.dot(i, B_inv) for i in mean_sf], dtype=in_sh.dtype)
# By default, return only asymmetric SH
return out_sh
Expand Down Expand Up @@ -371,7 +385,7 @@ def _correlate_spatial(image_u, h_filter, sigma_range):


def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
in_full_basis=False, dot_sharpness=1.0,
in_full_basis=False, is_legacy=True, dot_sharpness=1.0,
sphere_str='repulsion724', sigma=1.0):
"""
Average the SH projected on a sphere using a first-neighbor gaussian
Expand All @@ -389,6 +403,8 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
SH basis of the input signal.
in_full_basis: bool, optional
True if the input is in full SH basis.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
dot_sharpness: float, optional
Exponent of the dot product. When set to 0.0, directions
are not weighted by the dot product.
Expand All @@ -411,13 +427,14 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
nb_sf = len(sphere.vertices)
mean_sf = np.zeros(np.append(in_sh.shape[:-1], nb_sf))
B = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis,
return_inv=False, full_basis=in_full_basis)
return_inv=False, full_basis=in_full_basis,
legacy=is_legacy)

# We want a B matrix to project on an inverse sphere to have the sf on
# the opposite hemisphere for a given vertice
neg_B = sh_to_sf_matrix(Sphere(xyz=-sphere.vertices), sh_order=sh_order,
basis_type=sh_basis, return_inv=False,
full_basis=in_full_basis)
full_basis=in_full_basis, legacy=is_legacy)

# Apply filter to each sphere vertice
for sf_i in range(nb_sf):
Expand All @@ -435,7 +452,8 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
# Convert back to SH coefficients
_, B_inv = sh_to_sf_matrix(sphere, sh_order=sh_order,
basis_type=sh_basis,
full_basis=True)
full_basis=True,
legacy=is_legacy)

out_sh = np.array([np.dot(i, B_inv) for i in mean_sf], dtype=in_sh.dtype)
return out_sh
Expand Down
4 changes: 2 additions & 2 deletions scilpy/denoise/tests/test_asym_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_angle_aware_bilateral_filtering():

sh_order, full_basis = get_sh_order_and_fullness(in_sh.shape[-1])
out = angle_aware_bilateral_filtering_cpu(in_sh, sh_order,
sh_basis, full_basis,
sh_basis, full_basis, True,
sphere_str, sigma_spatial,
sigma_angular, sigma_range)

Expand All @@ -40,7 +40,7 @@ def test_cosine_filtering():
sharpness = 1.0

sh_order, full_basis = get_sh_order_and_fullness(in_sh.shape[-1])
out = cosine_filtering(in_sh, sh_order, sh_basis, full_basis,
out = cosine_filtering(in_sh, sh_order, sh_basis, full_basis, True,
sharpness, sphere_str, sigma_spatial)

assert np.allclose(out, fodf_3x3_order8_descoteaux07_filtered_cosine)
79 changes: 69 additions & 10 deletions scilpy/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,36 +240,95 @@ def add_bbox_arg(parser):
'streamlines).')


def add_sh_basis_args(parser, mandatory=False):
"""Add spherical harmonics (SH) bases argument.
def add_sh_basis_args(parser, mandatory=False, input_output=False):
"""
Add spherical harmonics (SH) bases argument. For more information about
the bases, see https://docs.dipy.org/stable/theory/sh_basis.html.

Parameters
----------
parser: argparse.ArgumentParser object
Parser.
mandatory: bool
Whether this argument is mandatory.
input_output: bool
Whether this argument should expect both input and output bases or not.
If set, the sh_basis argument will expect first the input basis,
followed by the output basis.
"""
choices = ['descoteaux07', 'tournier07']
def_val = 'descoteaux07'
if input_output:
nargs = 2
def_val = ['descoteaux07_legacy', 'tournier07']
input_output_msg = '\nBoth the input and output bases are ' +\
'required, in that order.'
else:
nargs = 1
def_val = ['descoteaux07_legacy']
input_output_msg = ''

choices = ['descoteaux07', 'tournier07', 'descoteaux07_legacy',
'tournier07_legacy']
help_msg = 'Spherical harmonics basis used for the SH coefficients. ' +\
'\nMust be either \'descoteaux07\' or \'tournier07\'' +\
input_output_msg +\
'\nMust be either \'descoteaux07\', \'tournier07\', \n' +\
'\'descoteaux07_legacy\' or \'tournier07_legacy\'' +\
' [%(default)s]:\n' +\
' \'descoteaux07\': SH basis from the Descoteaux et al.\n' +\
' MRM 2007 paper\n' +\
' \'tournier07\' : SH basis from the Tournier et al.\n' +\
' NeuroImage 2007 paper.'
' \'descoteaux07\' : SH basis from the Descoteaux ' +\
'et al.\n' +\
' MRM 2007 paper\n' +\
' \'tournier07\' : SH basis from the new ' +\
'Tournier et al.\n' +\
' NeuroImage 2019 paper, as in ' +\
'MRtrix 3.\n' +\
' \'descoteaux07_legacy\': SH basis from the legacy Dipy ' +\
'implementation\n' +\
' of the ' +\
'Descoteaux et al. MRM 2007 paper\n' +\
' \'tournier07_legacy\' : SH basis from the legacy ' +\
'Tournier et al.\n' +\
' NeuroImage 2007 paper.'

if mandatory:
arg_name = 'sh_basis'
else:
arg_name = '--sh_basis'

parser.add_argument(arg_name,
parser.add_argument(arg_name, nargs=nargs,
choices=choices, default=def_val,
help=help_msg)


def parse_sh_basis_arg(args):
"""
Parser the input from args.sh_basis. If two SH bases are given,
both input/output sh_basis and is_legacy are returned.

Parameters
----------
args : ArgumentParser.parse_args
ArgumentParser.parse_args from a script.

Returns
-------
sh_basis : string
Spherical harmonic basis name.
is_legacy : bool
Whether or not the SH basis is in its legacy form.
"""
sh_basis_name = args.sh_basis[0]
sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \
else 'tournier07'
is_legacy = 'legacy' in sh_basis_name
if len(args.sh_basis) == 2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it works but you could also simply do

 basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name else 'tournier07'
 legacy = 'legacy' in sh_basis_name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but I prefer the more "general" way, without assuming descoteaux07 or tournier07. @arnaudbore ? I think Charles refers to lines 313 to 316 and 319 to 320.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would agree with you if we were not using choices in add_sh_basis_args where we are already assuming descoteaux07 and tournier07. I would go for @CHrlS98 suggestion more elegant 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it better now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool

sh_basis_name = args.sh_basis[1]
out_sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \
else 'tournier07'
is_out_legacy = 'legacy' in sh_basis_name
return sh_basis, is_legacy, out_sh_basis, is_out_legacy
else:
return sh_basis, is_legacy


def add_nifti_screenshot_default_args(
parser, slice_ids_mandatory=True, transparency_mask_mandatory=True
):
Expand Down
18 changes: 12 additions & 6 deletions scilpy/reconst/fodf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from dipy.data import get_sphere
from dipy.reconst.mcsd import MSDeconvFit
from dipy.reconst.multi_voxel import MultiVoxelFit
from dipy.reconst.shm import sh_to_sf_matrix

from scilpy.reconst.utils import find_order_from_nb_coeff, get_b_matrix
from scilpy.reconst.utils import find_order_from_nb_coeff

from dipy.utils.optpkg import optional_package
cvx, have_cvxpy, _ = optional_package("cvxpy")


def get_ventricles_max_fodf(data, fa, md, zoom, args):
def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, args,
is_legacy=True):
"""
Compute mean maximal fodf value in ventricules. Given
heuristics thresholds on FA and MD values, finds the
Expand All @@ -30,9 +32,13 @@ def get_ventricles_max_fodf(data, fa, md, zoom, args):
FA (Fractional Anisotropy) volume from DTI
md: ndarray (x, y, z)
MD (Mean Diffusivity) volume from DTI
vol: int > 0
Maximum Nnumber of voxels used to compute the mean.
zoom: int > 0
Maximum number of voxels used to compute the mean.
1000 works well at 2x2x2 = 8 mm3
sh_basis: str
Either 'tournier07' or 'descoteaux07'
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.

Returns
-------
Expand All @@ -42,7 +48,7 @@ def get_ventricles_max_fodf(data, fa, md, zoom, args):

order = find_order_from_nb_coeff(data)
sphere = get_sphere('repulsion100')
b_matrix = get_b_matrix(order, sphere, args.sh_basis)
b_matrix, _ = sh_to_sf_matrix(sphere, order, sh_basis, legacy=is_legacy)
sum_of_max = 0
count = 0

Expand Down Expand Up @@ -86,7 +92,7 @@ def get_ventricles_max_fodf(data, fa, md, zoom, args):
continue
if fa[i, j, k] < args.fa_threshold \
and md[i, j, k] > args.md_threshold:
sf = np.dot(data[i, j, k], b_matrix.T)
sf = np.dot(data[i, j, k], b_matrix)
sum_of_max += sf.max()
count += 1
mask[i, j, k] = 1
Expand Down
Loading