diff --git a/scilpy/reconst/fiber_coherence.py b/scilpy/reconst/fiber_coherence.py index 410569d259..1bd78ddff8 100644 --- a/scilpy/reconst/fiber_coherence.py +++ b/scilpy/reconst/fiber_coherence.py @@ -6,10 +6,23 @@ NB_FLIPS = 4 ANGLE_TH = np.pi / 6. +# directions to the 26 neighbors. +# Preparing once rather than in _compute_fiber_coherence, possibly called many +# times. +all_d = np.indices((3, 3, 3)) +all_d = all_d.T.reshape((27, 3)) - 1 +all_d = np.delete(all_d, 13, axis=0) -def compute_fiber_coherence_table(directions, values): + +def compute_fiber_coherence_fliptable(directions, values): """ - Compute fiber coherence indexes for all possible axes permutations/flips. + Compute fiber coherence indexes for all possible axes permutations/flips + (ex, originating from a flip in the gradient table). + + The mathematics are presented in : + [1] Schilling et al. A fiber coherence index for quality control of B-table + orientation in diffusion MRI scans. Magn Reson Imaging. 2019 May;58:82-89. + doi: 10.1016/j.mri.2019.01.018. Parameters ---------- @@ -26,11 +39,11 @@ def compute_fiber_coherence_table(directions, values): Transform representing each permutation/flip, in the same order as `coherence` list. """ + # Generate transforms for 24 possible permutation/flips of + # gradient directions. (Reminder. We want to verify if there was possibly + # a flip in the gradient table). permutations = list(itertools.permutations([0, 1, 2])) transforms = np.zeros((len(permutations)*NB_FLIPS, 3, 3)) - - # Generate transforms for 24 possible permutation/flips of - # gradient directions for i in range(len(permutations)): transforms[i*NB_FLIPS, np.arange(3), permutations[i]] = 1 for ii in range(3): @@ -38,6 +51,7 @@ def compute_fiber_coherence_table(directions, values): flip[ii, ii] = -1 transforms[ii+i*NB_FLIPS+1] = transforms[i*NB_FLIPS].dot(flip) + # Compute the coherence for each one. coherence = [] for t in transforms: index = compute_fiber_coherence(directions.dot(t), values) @@ -61,11 +75,7 @@ def compute_fiber_coherence(peaks, values): coherence: float Fiber coherence value. """ - # directions to neighbors - all_d = np.indices((3, 3, 3)) - all_d = all_d.T.reshape((27, 3)) - 1 - all_d = np.delete(all_d, 13, axis=0) - + # Normalizing peaks norm_peaks = np.zeros_like(peaks) norms = np.linalg.norm(peaks, axis=-1) norm_peaks[norms > 0] = peaks[norms > 0] / norms[norms > 0][..., None] @@ -78,13 +88,27 @@ def compute_fiber_coherence(peaks, values): slice_z = slice(1 + tz, peaks.shape[2] - 1 + tz) di_norm = di / np.linalg.norm(di) - I_u = np.abs(norm_peaks.dot(di_norm)) > np.cos(ANGLE_TH) + + # Spatial coherence between the peak at each voxel and the direction to + # the neighbor di. + # Ex: if the peak is aligned in x and current di is aligned in x, + # returns True (with angle < 30 ; cos angle > 30) + cos_angles = np.abs(norm_peaks.dot(di_norm)) + I_u = cos_angles > np.cos(ANGLE_TH) + + # Doing the same thing with v; results in the same image but translated + # from one voxel. (With 1 voxel padding around the border). I_v = np.zeros_like(I_u) I_v[1:-1, 1:-1, 1:-1] = I_u[slice_x, slice_y, slice_z] + # Where both conditions are met: I_uv = np.logical_and(I_u, I_v) u = np.nonzero(I_uv) + + # v = the same voxels as u, but with the neighborhood difference. v = tuple(np.array(u) + di.astype(int).reshape(3, 1)) + + # Summing the FA of those voxels coherence += np.sum(values[u]) + np.sum(values[v]) return coherence diff --git a/scilpy/reconst/tests/test_fiber_coherence.py b/scilpy/reconst/tests/test_fiber_coherence.py index 3f6feaa95a..33cd9d9a4f 100644 --- a/scilpy/reconst/tests/test_fiber_coherence.py +++ b/scilpy/reconst/tests/test_fiber_coherence.py @@ -1,11 +1,70 @@ # -*- coding: utf-8 -*- +import numpy as np +from scilpy.reconst.fiber_coherence import (compute_fiber_coherence_fliptable, + compute_fiber_coherence) -def test_compute_fiber_coherence_table(): - # toDO - pass + +def test_compute_fiber_coherence_fliptable(): + # Just checking that we get 24 values. + # See below for the real tests. + directions = np.zeros((3, 3, 5, 3), dtype=float) + fa = np.zeros((3, 3, 5), dtype=float) + coherence, transforms = compute_fiber_coherence_fliptable(directions, fa) + assert len(coherence) == 24 + assert len(transforms) == 24 def test_compute_fiber_coherence(): - # toDO - pass + # Coherence will be strong if we have voxels were the peak points towards + # the neighbor. Ex: Imagine the corpus callosum, where the voxels in X + # all have peaks in X. + + # Test 1. + # Aligned on the last dimension (z), we have 4 peaks all pointing in the + # z direction, with strong FA. + directions = np.zeros((3, 3, 5, 3), dtype=float) + directions[1, 1, :, :] = np.asarray([[0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, -1], + [0, 0, 0]], dtype=float) + fa = np.zeros((3, 3, 5), dtype=float) + fa[1, 1, :] = [1, 1, 1, 1, 0] + + # There should be a good coherence (actually we get 10). + coherence1 = compute_fiber_coherence(directions, fa) + assert coherence1 > 0 + + # Test 2. Testing symmetry: reversing the 4th voxel should not change the + # result + directions[1, 1, :, :] = np.asarray([[0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 0]], dtype=float) + coherence2 = compute_fiber_coherence(directions, fa) + assert coherence2 == coherence1 + + # Test 3 + # Same directions, but with low FA + fa = np.zeros((3, 3, 5), dtype=float) + fa[1, 1, :] = [0.2, 0.2, 0.2, 0.2, 0] + + # There should be a good coherence (actually we get 2). + coherence3 = compute_fiber_coherence(directions, fa) + assert coherence3 < coherence2 + + # Test 4. Voxels with non-zero peaks still have peaks in z, but they are + # aligned in y. + directions = np.zeros((3, 5, 3, 3), dtype=float) + directions[1, :, 1, :] = np.asarray([[0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, -1], + [0, 0, 0]], dtype=float) + fa = np.zeros((3, 3, 5), dtype=float) + fa[1, 1, :] = [1, 1, 1, 1, 0] + coherence4 = compute_fiber_coherence(directions, fa) + assert coherence4 == 0 + diff --git a/scripts/scil_gradients_validate_correct.py b/scripts/scil_gradients_validate_correct.py index 41e0929118..a0c9a197d8 100755 --- a/scripts/scil_gradients_validate_correct.py +++ b/scripts/scil_gradients_validate_correct.py @@ -31,7 +31,7 @@ from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_verbose_arg) from scilpy.io.image import get_data_as_mask -from scilpy.reconst.fiber_coherence import compute_fiber_coherence_table +from scilpy.reconst.fiber_coherence import compute_fiber_coherence_fliptable EPILOG = """ @@ -111,7 +111,7 @@ def main(): peaks[np.logical_not(mask)] = 0 peaks[fa < args.fa_th] = 0 - coherence, transform = compute_fiber_coherence_table(peaks, fa) + coherence, transform = compute_fiber_coherence_fliptable(peaks, fa) best_t = transform[np.argmax(coherence)] if (best_t == np.eye(3)).all():