Skip to content

Commit

Permalink
Merge pull request #818 from frheault/new_upsampling_tube
Browse files Browse the repository at this point in the history
New upsampling method (PTT) for tractogram
  • Loading branch information
arnaudbore authored Mar 20, 2024
2 parents 380038e + aa8038c commit c584c7e
Show file tree
Hide file tree
Showing 10 changed files with 360 additions and 78 deletions.
90 changes: 90 additions & 0 deletions scilpy/tractograms/streamline_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from dipy.tracking.streamlinespeed import (length, set_number_of_points)
from scipy.interpolate import splev, splprep

from scilpy.utils.util import rotation_around_vector_matrix


def _get_streamline_pt_index(points_to_index, vox_index, from_start=True):
"""Get the index of the streamline point in the voxel.
Expand Down Expand Up @@ -84,6 +86,7 @@ def _get_point_on_line(first_point, second_point, vox_lower_corner):

return first_point + ray * (t0 + t1) / 2.


def filter_streamlines_by_length(sft, min_length=0., max_length=np.inf):
"""
Filter streamlines using minimum and max length.
Expand Down Expand Up @@ -388,3 +391,90 @@ def smooth_line_spline(streamline, smoothing_parameter, nb_ctrl_points):
smoothed_streamline[-1] = streamline[-1]

return smoothed_streamline


def parallel_transport_streamline(streamline, nb_streamlines, radius, rng=None):
""" Generate new streamlines by parallel transport of the input
streamline. See [0] and [1] for more details.
[0]: Hanson, A.J., & Ma, H. (1995). Parallel Transport Approach to
Curve Framing. # noqa E501
[1]: TD Essentials: Parallel Transport.
https://www.youtube.com/watch?v=5LedteSEgOE
Parameters
----------
streamline: ndarray (N, 3)
The streamline to transport.
nb_streamlines: int
The number of streamlines to generate.
radius: float
The radius of the circle around the original streamline in which the
new streamlines will be generated.
rng: numpy.random.Generator, optional
The random number generator to use. If None, the default numpy
random number generator will be used.
Returns
-------
new_streamlines: list of ndarray (N, 3)
The generated streamlines.
"""

if rng is None:
rng = np.random.default_rng(0)

# Compute the tangent at each point of the streamline
T = np.gradient(streamline, axis=0)
# Normalize the tangents
T = T / np.linalg.norm(T, axis=1)[:, None]

# Placeholder for the normal vector at each point
V = np.zeros_like(T)
# Set the normal vector at the first point to kind of perpendicular to
# the first direction vector
V[0] = np.roll(streamline[0] - streamline[1], 1)
V[0] = V[0] / np.linalg.norm(V[0])
# For each point
for i in range(0, T.shape[0]-1):
# Compute the torsion vector
B = np.cross(T[i], T[i+1])
# If the torsion vector is 0, the normal vector does not change
if np.linalg.norm(B) < 1e-3:
V[i+1] = V[i]
# Else, the normal vector is rotated around the torsion vector by
# the torsion.
else:
B = B / np.linalg.norm(B)
theta = np.arccos(np.dot(T[i], T[i+1]))
# Rotate the vector V[i] around the vector B by theta
# radians.
V[i+1] = np.dot(rotation_around_vector_matrix(B, theta), V[i])

# Compute the binormal vector at each point
W = np.cross(T, V, axis=1)

# Generate the new streamlines
# TODO?: This could easily be optimized to avoid the for loop, we have to
# see if this becomes a bottleneck.
new_streamlines = []
for i in range(nb_streamlines):
# Get a random number between -1 and 1
rand_v = rng.uniform(-1, 1)
rand_w = rng.uniform(-1, 1)

# Compute the norm of the "displacement"
norm = np.sqrt(rand_v**2 + rand_w**2)
# Displace the normal and binormal vectors by a random amount
V_mod = V * rand_v
W_mod = W * rand_w
# Compute the displacement vector
VW = (V_mod + W_mod)
# Displace the streamline around the original one following the
# parallel frame. Make sure to normalize the displacement vector
# so that the new streamline is in a circle around the original one.

new_s = streamline + (rng.uniform(0, 1) * VW / norm) * radius
new_streamlines.append(new_s)

return new_streamlines
26 changes: 25 additions & 1 deletion scilpy/tractograms/tests/test_streamline_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import os
import tempfile

from dipy.io.streamline import load_tractogram
from dipy.tracking.streamlinespeed import length
import nibabel as nib
import numpy as np
from numpy.testing import assert_array_almost_equal
import pytest
from dipy.io.streamline import load_tractogram
from dipy.tracking.streamlinespeed import length
Expand All @@ -16,7 +19,8 @@
resample_streamlines_num_points,
resample_streamlines_step_size,
smooth_line_gaussian,
smooth_line_spline)
smooth_line_spline,
parallel_transport_streamline)
from scilpy.tractograms.tractogram_operations import concatenate_sft


Expand Down Expand Up @@ -368,3 +372,23 @@ def test_smooth_line_spline():
dist_2 = np.linalg.norm(noisy_streamline - smoothed_streamline)

assert dist_1 < dist_2


def test_parallel_transport_streamline():
sft, _ = _setup_files()
streamline = sft.streamlines[0]

rng = np.random.default_rng(3018)
pt_streamlines = parallel_transport_streamline(
streamline, 20, 5, rng)

avg_streamline = np.mean(pt_streamlines, axis=0)

assert_array_almost_equal(avg_streamline[0],
[-26.999582, -116.320145, 6.3678055],
decimal=4)
assert_array_almost_equal(avg_streamline[-1],
[-155.99944, -116.56515, 6.2451267],
decimal=4)
assert [len(s) for s in pt_streamlines] == [130] * 20
assert len(pt_streamlines) == 20
44 changes: 43 additions & 1 deletion scilpy/tractograms/tests/test_tractogram_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile

import numpy as np
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.io.streamline import load_tractogram

from scilpy import SCILPY_HOME
Expand All @@ -18,6 +19,9 @@
perform_tractogram_operation_on_lines,
perform_tractogram_operation_on_sft,
shuffle_streamlines,
split_sft_randomly,
split_sft_randomly_per_cluster,
upsample_tractogram,
union,
union_robust)

Expand Down Expand Up @@ -136,7 +140,7 @@ def test_robust_operations():

def test_concatenate_sft():
# Testing with different metadata
sft2 = sft.from_sft(sft.streamlines, sft)
sft2 = StatefulTractogram.from_sft(sft.streamlines, sft)
sft2.data_per_point['test2_different'] = [[['a', 'b', 'c']] * len(s)
for s in sft.streamlines]

Expand All @@ -158,3 +162,41 @@ def test_combining_sft():
perform_tractogram_operation_on_sft('union', [sft, sft], precision=None,
fake_metadata=False, no_metadata=False)


def test_upsample_tractogram():
new_sft = upsample_tractogram(sft, 1000, 0.5, 5, False, 0.1, 0)
first_chunk = [[112.64021, 35.409477, 59.42175],
[109.09777, 35.287857, 61.845505],
[110.41855, 37.077374, 56.930523]]
last_chunk = [[110.40285, 51.036686, 62.419273],
[109.698586, 48.330017, 64.50656],
[113.04737, 45.89119, 64.778534]]

assert len(new_sft) == 1000
assert len(new_sft.streamlines._data) == 8404
assert np.allclose(first_chunk, new_sft.streamlines._data[0:30:10])
assert np.allclose(last_chunk, new_sft.streamlines._data[-1:-31:-10])


def test_split_sft_randomly():
sft_copy = StatefulTractogram.from_sft(sft.streamlines, sft)
new_sft_list = split_sft_randomly(sft_copy, 2, 0)

assert len(new_sft_list) == 2 and isinstance(new_sft_list, list)
assert len(new_sft_list[0]) == 2 and len(new_sft_list[1]) == 2
assert np.allclose(new_sft_list[0].streamlines[0][0],
[112.458, 35.7144, 58.7432])
assert np.allclose(new_sft_list[1].streamlines[0][0],
[112.168, 35.259, 59.419])


def test_split_sft_randomly_per_cluster():
sft_copy = StatefulTractogram.from_sft(sft.streamlines, sft)
new_sft_list = split_sft_randomly_per_cluster(sft_copy, [2], 0,
[40, 30, 20, 10])
assert len(new_sft_list) == 2 and isinstance(new_sft_list, list)
assert len(new_sft_list[0]) == 2 and len(new_sft_list[1]) == 2
assert np.allclose(new_sft_list[0].streamlines[0][0],
[112.168, 35.259, 59.419])
assert np.allclose(new_sft_list[1].streamlines[0][0],
[112.266, 35.4188, 59.0421])
107 changes: 77 additions & 30 deletions scilpy/tractograms/tractogram_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
from dipy.io.utils import get_reference_info, is_header_compatible
from dipy.segment.clustering import qbx_and_merge
from dipy.tracking.streamline import transform_streamlines
from dipy.tracking.streamlinespeed import compress_streamlines
from nibabel.streamlines import TrkFile, TckFile
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np
from numpy.polynomial.polynomial import Polynomial
from scipy.ndimage import map_coordinates
from scipy.spatial import cKDTree

from scilpy.tractograms.streamline_operations import smooth_line_gaussian, \
smooth_line_spline
resample_streamlines_step_size, parallel_transport_streamline
from scilpy.utils.streamlines import cut_invalid_streamlines

MIN_NB_POINTS = 10
Expand Down Expand Up @@ -603,8 +606,36 @@ def transform_warp_sft(sft, linear_transfo, target, inverse=False,
return new_sft


def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
gaussian=None, spline=None, seed=None):
def compress_streamlines_wrapper(tractogram, error_rate):
"""
Compresses the streamlines of a tractogram.
Supports both nibabel.Tractogram dipy.StatefulTractogram
or list of streamlines.
Parameters
----------
tractogram: TrkFile, TckFile, ArraySequence, list
The tractogram to compress.
error_rate: float
The maximum distance (in mm) for point displacement during compression.
Returns
-------
compressed_streamlines: list of np.ndarray
The compressed streamlines.
"""
if isinstance(tractogram, (TrkFile, TckFile)):
return lambda: (compress_streamlines(
s, error_rate) for s in tractogram.streamlines)
else:
if hasattr(tractogram, 'streamlines'):
tractogram = tractogram.streamlines
return [compress_streamlines(
s, error_rate) for s in tractogram]


def upsample_tractogram(sft, nb, point_wise_std=None, tube_radius=None,
gaussian=None, error_rate=None, seed=None):
"""
Generates new streamlines by either adding gaussian noise around
streamlines' points, or by translating copies of existing streamlines
Expand All @@ -619,11 +650,12 @@ def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
point_wise_std : float
The standard deviation of the gaussian to use to generate point-wise
noise on the streamlines.
streamline_wise_std : float
The standard deviation of the gaussian to use to generate
streamline-wise noise on the streamlines.
tube_radius : float
The radius of the tube used to model the streamlines.
gaussian: float
The sigma used for smoothing streamlines.
error_rate : float
The maximum distance (in mm) to the original position of any point.
spline: (float, int)
Pair of sigma and number of control points used to model each
streamline as a spline and smooth it.
Expand All @@ -635,35 +667,49 @@ def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
new_sft : StatefulTractogram
The upsampled tractogram.
"""
assert bool(point_wise_std) ^ bool(streamline_wise_std), \
'Can only add either point-wise or streamline-wise noise' + \
', not both nor none.'

rng = np.random.RandomState(seed)

# Get the number of streamlines to add
nb_new = nb - len(sft.streamlines)
rng = np.random.default_rng(seed)

# Get the streamlines that will serve as a base for new ones
indices = rng.choice(
len(sft.streamlines), nb_new)
new_streamlines = sft.streamlines.copy()
resampled_sft = resample_streamlines_step_size(sft, 1)
new_streamlines = []
indices = rng.choice(len(resampled_sft), nb, replace=True)
unique_indices, count = np.unique(indices, return_counts=True)

# For all selected streamlines, add noise and smooth
for s in sft.streamlines[indices]:
if point_wise_std:
noise = rng.normal(scale=point_wise_std, size=s.shape)
else: # streamline_wise_std
noise = rng.normal(scale=streamline_wise_std, size=s.shape[-1])
new_s = s + noise
for i, c in zip(unique_indices, count):
s = resampled_sft.streamlines[i]
if len(s) < 3:
new_streamlines.extend(np.repeat(s, c).tolist())
new_s = parallel_transport_streamline(s, c, tube_radius)

# Generate smooth noise_factor
noise = rng.normal(loc=0, scale=point_wise_std,
size=len(s))

# Instead of generating random noise, we fit a polynomial to the
# noise and use it to generate a spatially smooth noise along the
# streamline (simply to avoid sharp changes in the noise factor).
x = np.arange(len(noise))
poly_coeffs = np.polyfit(x, noise, 3)
polynomial = Polynomial(poly_coeffs[::-1])
noise_factor = polynomial(x)

vec = s - new_s
vec /= np.linalg.norm(vec, axis=0)
new_s += vec * np.expand_dims(noise_factor, axis=1)

if gaussian:
new_s = smooth_line_gaussian(new_s, gaussian)
elif spline:
new_s = smooth_line_spline(new_s, spline[0], spline[1])
new_s = [smooth_line_gaussian(s, gaussian) for s in new_s]

new_streamlines.extend(new_s)

new_streamlines.append(new_s)
if error_rate:
compressed_streamlines = compress_streamlines_wrapper(new_streamlines,
error_rate)
else:
compressed_streamlines = new_streamlines

new_sft = StatefulTractogram.from_sft(new_streamlines, sft)
new_sft = StatefulTractogram.from_sft(compressed_streamlines, sft)
return new_sft


Expand Down Expand Up @@ -798,9 +844,10 @@ def split_sft_randomly_per_cluster(orig_sft, chunk_sizes, seed, thresholds):
nb_chunks = len(chunk_sizes)
percent_kept_per_chunk = [nb / len(orig_sft) for nb in chunk_sizes]

logging.info("Computing QBx")
logging.debug("Computing QBx")
rng = np.random.RandomState(seed)
clusters = qbx_and_merge(orig_sft.streamlines, thresholds, nb_pts=20,
verbose=False)
verbose=False, rng=rng)

logging.info("Done. Now getting list of indices in each of the {} "
"cluster.".format(len(clusters)))
Expand Down
Empty file added scilpy/utils/tests/__init__.py
Empty file.
Loading

0 comments on commit c584c7e

Please sign in to comment.