Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New upsampling method (PTT) for tractogram #818

Merged
merged 35 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
8f83a61
Working version, no new script
frheault Nov 23, 2023
a9f62be
Fix
Nov 23, 2023
4c00772
working ptt
Nov 24, 2023
05d3d9b
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Nov 28, 2023
3ae2cb8
Fix Antoine comment, fix compress, optimize
frheault Nov 29, 2023
0962d2a
Fix conflict new scripts
frheault Dec 4, 2023
94fe97b
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Dec 14, 2023
4504055
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
Dec 14, 2023
ed7e4a1
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Dec 18, 2023
8e6249c
Merge branch 'json_rename_scripts' of github.com:frheault/scilpy into…
frheault Dec 18, 2023
f5168ad
Change RNG to generator
frheault Dec 18, 2023
c73fd70
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Dec 20, 2023
255d7cf
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
Dec 20, 2023
a4f38e7
Fix Pep8
Dec 20, 2023
7a5fcd9
fix test
Dec 20, 2023
84fd8f4
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 10, 2024
c5fb0f7
Fix Antoine comments
frheault Jan 10, 2024
d08533d
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
frheault Jan 10, 2024
b1ab095
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 15, 2024
bc8fc56
Added unit test
frheault Jan 29, 2024
9f83929
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 29, 2024
fba8fcf
added newline
frheault Feb 1, 2024
508ce0e
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Feb 20, 2024
bac39ef
Fix conflict
frheault Feb 21, 2024
86e8f5b
Move functions and add tests, emmanuel comments
frheault Feb 22, 2024
c386b56
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Feb 27, 2024
521acab
Move function, add tests, pep8
frheault Feb 27, 2024
6765016
Fix import
frheault Feb 27, 2024
5d26160
Fix kdtree type error
frheault Feb 28, 2024
b32e164
Revert to sklearn
frheault Feb 28, 2024
ebc9205
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Mar 4, 2024
c492232
address comments
Mar 17, 2024
c50a518
Fix conflict
Mar 17, 2024
de85fd9
Fix error for pytest
Mar 18, 2024
aa8038c
Fix Emmanuelle comments in argparse
frheault Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions scilpy/tractograms/streamline_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,122 @@ def smooth_line_spline(streamline, smoothing_parameter, nb_ctrl_points):
smoothed_streamline[-1] = streamline[-1]

return smoothed_streamline


def rotation_around_vector_matrix(vec, theta):
""" Rotation matrix around a 3D vector by an angle theta.
From https://stackoverflow.com/questions/6802577/rotation-of-3d-vector

TODO?: Put this somewhere else.
frheault marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
vec: ndarray (3,)
The vector to rotate around.
theta: float
The angle of rotation in radians.

Returns
-------
rot: ndarray (3, 3)
The rotation matrix.
"""

vec = vec / np.linalg.norm(vec)
x, y, z = vec
c, s = np.cos(theta), np.sin(theta)
return np.array([[c + x**2 * (1 - c),
x * y * (1 - c) - z * s,
x * z * (1 - c) + y * s],
[y * x * (1 - c) + z * s,
c + y**2 * (1 - c),
y * z * (1 - c) - x * s],
[z * x * (1 - c) - y * s,
z * y * (1 - c) + x * s,
c + z**2 * (1 - c)]])

def parallel_transport_streamline(streamline, nb_streamlines, radius, rng=None):
""" Generate new streamlines by parallel transport of the input
frheault marked this conversation as resolved.
Show resolved Hide resolved
streamline. See [0] and [1] for more details.

[0]: Hanson, A.J., & Ma, H. (1995). Parallel Transport Approach to
Curve Framing. # noqa E501
[1]: TD Essentials: Parallel Transport.
https://www.youtube.com/watch?v=5LedteSEgOE

Parameters
----------
streamline: ndarray (N, 3)
The streamline to transport.
nb_streamlines: int
The number of streamlines to generate.
radius: float
The radius of the circle around the original streamline in which the
new streamlines will be generated.
rng: numpy.random.Generator, optional
The random number generator to use. If None, the default numpy
random number generator will be used.

Returns
-------
new_streamlines: list of ndarray (N, 3)
The generated streamlines.
"""

if rng is None:
rng = np.random.default_rng(0)

# Compute the tangent at each point of the streamline
T = np.gradient(streamline, axis=0)
# Normalize the tangents
T = T / np.linalg.norm(T, axis=1)[:, None]

# Placeholder for the normal vector at each point
V = np.zeros_like(T)
# Set the normal vector at the first point to kind of perpendicular to
# the first direction vector
V[0] = np.roll(streamline[0] - streamline[1], 1)
V[0] = V[0] / np.linalg.norm(V[0])
# For each point
for i in range(0, T.shape[0]-1):
# Compute the torsion vector
B = np.cross(T[i], T[i+1])
# If the torsion vector is 0, the normal vector does not change
if np.linalg.norm(B) < 1e-3:
V[i+1] = V[i]
# Else, the normal vector is rotated around the torsion vector by
# the torsion.
else:
B = B / np.linalg.norm(B)
theta = np.arccos(np.dot(T[i], T[i+1]))
# Rotate the vector V[i] around the vector B by theta
# radians.
V[i+1] = np.dot(rotation_around_vector_matrix(B, theta), V[i])

# Compute the binormal vector at each point
W = np.cross(T, V, axis=1)

# Generate the new streamlines
# TODO?: This could easily be optimized to avoid the for loop, we have to
# see if this becomes a bottleneck.
new_streamlines = []
for i in range(nb_streamlines):
# Get a random number between -1 and 1
rand_v = rng.uniform(-1, 1)
rand_w = rng.uniform(-1, 1)

# Compute the norm of the "displacement"
norm = np.sqrt(rand_v**2 + rand_w**2)
# Displace the normal and binormal vectors by a random amount
V_mod = V * rand_v
W_mod = W * rand_w
# Compute the displacement vector
VW = (V_mod + W_mod)
# Displace the streamline around the original one following the
# parallel frame. Make sure to normalize the displacement vector
# so that the new streamline is in a circle around the original one.

new_s = streamline + (rng.uniform(0, 1) * VW / norm) * radius
new_streamlines.append(new_s)

return new_streamlines
51 changes: 50 additions & 1 deletion scilpy/tractograms/tests/test_streamline_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import nibabel as nib
import numpy as np
from numpy.testing import assert_almost_equal, assert_array_equal
import pytest

from dipy.io.streamline import load_tractogram
Expand All @@ -16,7 +17,8 @@
resample_streamlines_num_points,
resample_streamlines_step_size,
smooth_line_gaussian,
smooth_line_spline)
smooth_line_spline,
rotation_around_vector_matrix)
from scilpy.tractograms.tractogram_operations import concatenate_sft

fetch_data(get_testing_files_dict(), keys=['tractograms.zip'])
Expand Down Expand Up @@ -367,3 +369,50 @@ def test_smooth_line_spline():
dist_2 = np.linalg.norm(noisy_streamline - smoothed_streamline)

assert dist_1 < dist_2


def test_output_shape_and_type():
frheault marked this conversation as resolved.
Show resolved Hide resolved
"""Test the output shape and type."""
vec = np.array([1, 0, 0])
theta = np.pi / 4 # 45 degrees
rot_matrix = rotation_around_vector_matrix(vec, theta)
assert isinstance(rot_matrix, np.ndarray)
assert np.array_equal(rot_matrix.shape, (3, 3))


def test_magnitude_preservation():
"""Test if the rotation preserves the magnitude of a vector."""
vec = np.array([1, 0, 0])
theta = np.pi / 4
rot_matrix = rotation_around_vector_matrix(vec, theta)
rotated_vec = np.dot(rot_matrix, vec)
assert_almost_equal(np.linalg.norm(rotated_vec), np.linalg.norm(vec),
decimal=5)


def test_known_rotation():
"""Test a known rotation case."""
vec = np.array([0, 0, 1]) # Rotation around z-axis
theta = np.pi / 2 # 90 degrees
rot_matrix = rotation_around_vector_matrix(vec, theta)
original_vec = np.array([1, 0, 0])
expected_rotated_vec = np.array([0, 1, 0])
rotated_vec = np.dot(rot_matrix, original_vec)
assert_almost_equal(rotated_vec, expected_rotated_vec, decimal=5)


def test_zero_rotation():
"""Test rotation with theta = 0."""
vec = np.array([1, 0, 0])
theta = 0
rot_matrix = rotation_around_vector_matrix(vec, theta)
np.array_equal(rot_matrix, np.eye(3))


def test_full_rotation():
"""Test rotation with theta = 2*pi (should be identity)."""
vec = np.array([1, 0, 0])
theta = 2 * np.pi
rot_matrix = rotation_around_vector_matrix(vec, theta)
# Allow for minor floating-point errors
assert_almost_equal(rot_matrix, np.eye(3), decimal=5)
44 changes: 42 additions & 2 deletions scilpy/tractograms/tests/test_tractogram_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import tempfile

import numpy as np
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.io.streamline import load_tractogram

from scilpy.io.fetcher import fetch_data, get_testing_files_dict, get_home
from scilpy.tractograms.tractogram_operations import flip_sft, \
shuffle_streamlines, perform_tractogram_operation_on_lines, intersection, union, \
difference, intersection_robust, difference_robust, union_robust, \
concatenate_sft, perform_tractogram_operation_on_sft
concatenate_sft, perform_tractogram_operation_on_sft, upsample_tractogram, \
split_sft_randomly, split_sft_randomly_per_cluster

# Prepare SFT
fetch_data(get_testing_files_dict(), keys='surface_vtk_fib.zip')
Expand Down Expand Up @@ -126,7 +128,7 @@ def test_robust_operations():

def test_concatenate_sft():
# Testing with different metadata
sft2 = sft.from_sft(sft.streamlines, sft)
sft2 = StatefulTractogram.from_sft(sft.streamlines, sft)
sft2.data_per_point['test2_different'] = [[['a', 'b', 'c']] * len(s)
for s in sft.streamlines]

Expand All @@ -148,3 +150,41 @@ def test_combining_sft():
perform_tractogram_operation_on_sft('union', [sft, sft], precision=None,
fake_metadata=False, no_metadata=False)


def test_upsample_tractogram():
new_sft = upsample_tractogram(sft, 1000, 0.5, 5, False, 0.1, 0)
first_chunk = [[112.64021, 35.409477, 59.42175],
[109.09777, 35.287857, 61.845505],
[110.41855, 37.077374, 56.930523]]
last_chunk = [[110.40285, 51.036686, 62.419273],
[109.698586, 48.330017, 64.50656],
[113.04737, 45.89119, 64.778534]]

assert len(new_sft) == 1000
assert len(new_sft.streamlines._data) == 8404
assert np.allclose(first_chunk, new_sft.streamlines._data[0:30:10])
assert np.allclose(last_chunk, new_sft.streamlines._data[-1:-31:-10])


def test_split_sft_randomly():
sft_copy = StatefulTractogram.from_sft(sft.streamlines, sft)
new_sft_list = split_sft_randomly(sft_copy, 2, 0)

assert len(new_sft_list) == 2 and isinstance(new_sft_list, list)
assert len(new_sft_list[0]) == 2 and len(new_sft_list[1]) == 2
assert np.allclose(new_sft_list[0].streamlines[0][0],
[112.458, 35.7144, 58.7432])
assert np.allclose(new_sft_list[1].streamlines[0][0],
[112.168, 35.259, 59.419])


def test_split_sft_randomly_per_cluster():
sft_copy = StatefulTractogram.from_sft(sft.streamlines, sft)
new_sft_list = split_sft_randomly_per_cluster(sft_copy, [2], 0,
[40, 30, 20, 10])
assert len(new_sft_list) == 2 and isinstance(new_sft_list, list)
assert len(new_sft_list[0]) == 2 and len(new_sft_list[1]) == 2
assert np.allclose(new_sft_list[0].streamlines[0][0],
[112.168, 35.259, 59.419])
assert np.allclose(new_sft_list[1].streamlines[0][0],
[112.266, 35.4188, 59.0421])
Loading