Skip to content

Commit

Permalink
Unit tests for the fiber coherence
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Feb 29, 2024
1 parent 4c2848c commit 5ecc85d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 18 deletions.
46 changes: 35 additions & 11 deletions scilpy/reconst/fiber_coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,23 @@
NB_FLIPS = 4
ANGLE_TH = np.pi / 6.

# directions to the 26 neighbors.
# Preparing once rather than in _compute_fiber_coherence, possibly called many
# times.
all_d = np.indices((3, 3, 3))
all_d = all_d.T.reshape((27, 3)) - 1
all_d = np.delete(all_d, 13, axis=0)

def compute_fiber_coherence_table(directions, values):

def compute_fiber_coherence_fliptable(directions, values):
"""
Compute fiber coherence indexes for all possible axes permutations/flips.
Compute fiber coherence indexes for all possible axes permutations/flips
(ex, originating from a flip in the gradient table).
The mathematics are presented in :
[1] Schilling et al. A fiber coherence index for quality control of B-table
orientation in diffusion MRI scans. Magn Reson Imaging. 2019 May;58:82-89.
doi: 10.1016/j.mri.2019.01.018.
Parameters
----------
Expand All @@ -26,18 +39,19 @@ def compute_fiber_coherence_table(directions, values):
Transform representing each permutation/flip, in the same
order as `coherence` list.
"""
# Generate transforms for 24 possible permutation/flips of
# gradient directions. (Reminder. We want to verify if there was possibly
# a flip in the gradient table).
permutations = list(itertools.permutations([0, 1, 2]))
transforms = np.zeros((len(permutations)*NB_FLIPS, 3, 3))

# Generate transforms for 24 possible permutation/flips of
# gradient directions
for i in range(len(permutations)):
transforms[i*NB_FLIPS, np.arange(3), permutations[i]] = 1
for ii in range(3):
flip = np.eye(3)
flip[ii, ii] = -1
transforms[ii+i*NB_FLIPS+1] = transforms[i*NB_FLIPS].dot(flip)

# Compute the coherence for each one.
coherence = []
for t in transforms:
index = compute_fiber_coherence(directions.dot(t), values)
Expand All @@ -61,11 +75,7 @@ def compute_fiber_coherence(peaks, values):
coherence: float
Fiber coherence value.
"""
# directions to neighbors
all_d = np.indices((3, 3, 3))
all_d = all_d.T.reshape((27, 3)) - 1
all_d = np.delete(all_d, 13, axis=0)

# Normalizing peaks
norm_peaks = np.zeros_like(peaks)
norms = np.linalg.norm(peaks, axis=-1)
norm_peaks[norms > 0] = peaks[norms > 0] / norms[norms > 0][..., None]
Expand All @@ -78,13 +88,27 @@ def compute_fiber_coherence(peaks, values):
slice_z = slice(1 + tz, peaks.shape[2] - 1 + tz)

di_norm = di / np.linalg.norm(di)
I_u = np.abs(norm_peaks.dot(di_norm)) > np.cos(ANGLE_TH)

# Spatial coherence between the peak at each voxel and the direction to
# the neighbor di.
# Ex: if the peak is aligned in x and current di is aligned in x,
# returns True (with angle < 30 ; cos angle > 30)
cos_angles = np.abs(norm_peaks.dot(di_norm))
I_u = cos_angles > np.cos(ANGLE_TH)

# Doing the same thing with v; results in the same image but translated
# from one voxel. (With 1 voxel padding around the border).
I_v = np.zeros_like(I_u)
I_v[1:-1, 1:-1, 1:-1] = I_u[slice_x, slice_y, slice_z]

# Where both conditions are met:
I_uv = np.logical_and(I_u, I_v)
u = np.nonzero(I_uv)

# v = the same voxels as u, but with the neighborhood difference.
v = tuple(np.array(u) + di.astype(int).reshape(3, 1))

# Summing the FA of those voxels
coherence += np.sum(values[u]) + np.sum(values[v])

return coherence
69 changes: 64 additions & 5 deletions scilpy/reconst/tests/test_fiber_coherence.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,70 @@
# -*- coding: utf-8 -*-
import numpy as np

from scilpy.reconst.fiber_coherence import (compute_fiber_coherence_fliptable,
compute_fiber_coherence)

def test_compute_fiber_coherence_table():
# toDO
pass

def test_compute_fiber_coherence_fliptable():
# Just checking that we get 24 values.
# See below for the real tests.
directions = np.zeros((3, 3, 5, 3), dtype=float)
fa = np.zeros((3, 3, 5), dtype=float)
coherence, transforms = compute_fiber_coherence_fliptable(directions, fa)
assert len(coherence) == 24
assert len(transforms) == 24


def test_compute_fiber_coherence():
# toDO
pass
# Coherence will be strong if we have voxels were the peak points towards
# the neighbor. Ex: Imagine the corpus callosum, where the voxels in X
# all have peaks in X.

# Test 1.
# Aligned on the last dimension (z), we have 4 peaks all pointing in the
# z direction, with strong FA.
directions = np.zeros((3, 3, 5, 3), dtype=float)
directions[1, 1, :, :] = np.asarray([[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, -1],
[0, 0, 0]], dtype=float)
fa = np.zeros((3, 3, 5), dtype=float)
fa[1, 1, :] = [1, 1, 1, 1, 0]

# There should be a good coherence (actually we get 10).
coherence1 = compute_fiber_coherence(directions, fa)
assert coherence1 > 0

# Test 2. Testing symmetry: reversing the 4th voxel should not change the
# result
directions[1, 1, :, :] = np.asarray([[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, 0]], dtype=float)
coherence2 = compute_fiber_coherence(directions, fa)
assert coherence2 == coherence1

# Test 3
# Same directions, but with low FA
fa = np.zeros((3, 3, 5), dtype=float)
fa[1, 1, :] = [0.2, 0.2, 0.2, 0.2, 0]

# There should be a good coherence (actually we get 2).
coherence3 = compute_fiber_coherence(directions, fa)
assert coherence3 < coherence2

# Test 4. Voxels with non-zero peaks still have peaks in z, but they are
# aligned in y.
directions = np.zeros((3, 5, 3, 3), dtype=float)
directions[1, :, 1, :] = np.asarray([[0, 0, 1],
[0, 0, 1],
[0, 0, 1],
[0, 0, -1],
[0, 0, 0]], dtype=float)
fa = np.zeros((3, 3, 5), dtype=float)
fa[1, 1, :] = [1, 1, 1, 1, 0]
coherence4 = compute_fiber_coherence(directions, fa)
assert coherence4 == 0

4 changes: 2 additions & 2 deletions scripts/scil_gradients_validate_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist,
assert_outputs_exist, add_verbose_arg)
from scilpy.io.image import get_data_as_mask
from scilpy.reconst.fiber_coherence import compute_fiber_coherence_table
from scilpy.reconst.fiber_coherence import compute_fiber_coherence_fliptable


EPILOG = """
Expand Down Expand Up @@ -111,7 +111,7 @@ def main():
peaks[np.logical_not(mask)] = 0

peaks[fa < args.fa_th] = 0
coherence, transform = compute_fiber_coherence_table(peaks, fa)
coherence, transform = compute_fiber_coherence_fliptable(peaks, fa)

best_t = transform[np.argmax(coherence)]
if (best_t == np.eye(3)).all():
Expand Down

0 comments on commit 5ecc85d

Please sign in to comment.