Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New upsampling method (PTT) for tractogram #818

Merged
merged 35 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
8f83a61
Working version, no new script
frheault Nov 23, 2023
a9f62be
Fix
Nov 23, 2023
4c00772
working ptt
Nov 24, 2023
05d3d9b
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Nov 28, 2023
3ae2cb8
Fix Antoine comment, fix compress, optimize
frheault Nov 29, 2023
0962d2a
Fix conflict new scripts
frheault Dec 4, 2023
94fe97b
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Dec 14, 2023
4504055
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
Dec 14, 2023
ed7e4a1
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Dec 18, 2023
8e6249c
Merge branch 'json_rename_scripts' of github.com:frheault/scilpy into…
frheault Dec 18, 2023
f5168ad
Change RNG to generator
frheault Dec 18, 2023
c73fd70
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Dec 20, 2023
255d7cf
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
Dec 20, 2023
a4f38e7
Fix Pep8
Dec 20, 2023
7a5fcd9
fix test
Dec 20, 2023
84fd8f4
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 10, 2024
c5fb0f7
Fix Antoine comments
frheault Jan 10, 2024
d08533d
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
frheault Jan 10, 2024
b1ab095
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 15, 2024
bc8fc56
Added unit test
frheault Jan 29, 2024
9f83929
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 29, 2024
fba8fcf
added newline
frheault Feb 1, 2024
508ce0e
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Feb 20, 2024
bac39ef
Fix conflict
frheault Feb 21, 2024
86e8f5b
Move functions and add tests, emmanuel comments
frheault Feb 22, 2024
c386b56
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Feb 27, 2024
521acab
Move function, add tests, pep8
frheault Feb 27, 2024
6765016
Fix import
frheault Feb 27, 2024
5d26160
Fix kdtree type error
frheault Feb 28, 2024
b32e164
Revert to sklearn
frheault Feb 28, 2024
ebc9205
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Mar 4, 2024
c492232
address comments
Mar 17, 2024
c50a518
Fix conflict
Mar 17, 2024
de85fd9
Fix error for pytest
Mar 18, 2024
aa8038c
Fix Emmanuelle comments in argparse
frheault Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 50 additions & 18 deletions scilpy/tractograms/tractogram_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from dipy.tracking.streamline import transform_streamlines
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np
from numpy.polynomial.polynomial import Polynomial
from scipy.ndimage import map_coordinates
from scipy.spatial import cKDTree

from scilpy.tractograms.streamline_operations import smooth_line_gaussian, \
smooth_line_spline
smooth_line_spline, resample_streamlines_step_size
from scilpy.utils.streamlines import cut_invalid_streamlines
from scilpy.utils.spatial_ops import parallel_transport_streamline

MIN_NB_POINTS = 10
KEY_INDEX = np.concatenate((range(5), range(-1, -6, -1)))
Expand Down Expand Up @@ -603,7 +605,8 @@ def transform_warp_sft(sft, linear_transfo, target, inverse=False,
return new_sft


def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
def upsample_tractogram(sft, nb, point_wise_std=None,
streamline_wise_std=None, keep_tube=True,
gaussian=None, spline=None, seed=None):
"""
Generates new streamlines by either adding gaussian noise around
Expand All @@ -622,6 +625,8 @@ def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
streamline_wise_std : float
The standard deviation of the gaussian to use to generate
streamline-wise noise on the streamlines.
keep_tube : bool
If True, simply move the streamlines along their tangent uniformely.
gaussian: float
frheault marked this conversation as resolved.
Show resolved Hide resolved
The sigma used for smoothing streamlines.
spline: (float, int)
Expand All @@ -635,27 +640,54 @@ def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
new_sft : StatefulTractogram
The upsampled tractogram.
"""
assert bool(point_wise_std) ^ bool(streamline_wise_std), \
'Can only add either point-wise or streamline-wise noise' + \
', not both nor none.'
def linear(data):
"""Scale data linearly to [-1, 1]."""
min_val = np.min(data)
max_val = np.max(data)
return 2 * (data - min_val) / (max_val - min_val) - 1

def sigmoid(data):
"""Scale data using sigmoid and then to [-1, 1]."""
return 2 * (1 / (1 + np.exp(-data))) - 1

def tanh(data):
"""Scale data using tanh to [-1, 1]."""
return np.tanh(data)

def exp(data):
"""Scale data using exp and then to [-1, 1]."""
exp_data = np.exp(data)
scaled_data = exp_data / np.max(exp_data)
return 2 * scaled_data - 1

rng = np.random.RandomState(seed)

# Get the number of streamlines to add
nb_new = nb - len(sft.streamlines)

# Get the streamlines that will serve as a base for new ones
indices = rng.choice(
len(sft.streamlines), nb_new)
new_streamlines = sft.streamlines.copy()

resample_sft = resample_streamlines_step_size(sft, 1)
frheault marked this conversation as resolved.
Show resolved Hide resolved
new_streamlines = []
indices = np.random.choice(len(resample_sft), nb,
replace=True)
# For all selected streamlines, add noise and smooth
for s in sft.streamlines[indices]:
if point_wise_std:
noise = rng.normal(scale=point_wise_std, size=s.shape)
else: # streamline_wise_std
noise = rng.normal(scale=streamline_wise_std, size=s.shape[-1])
new_s = s + noise
for s in resample_sft.streamlines[indices]:
if len(s) < 3:
new_streamlines.append(s)

new_s = parallel_transport_streamline(s, 1,
frheault marked this conversation as resolved.
Show resolved Hide resolved
streamline_wise_std)[0]

# Generate smooth noise_factor
noise = np.random.normal(loc=0, scale=point_wise_std,
size=len(s))
x = np.arange(len(noise))
poly_coeffs = np.polyfit(x, noise, 3)
polynomial = Polynomial(poly_coeffs[::-1])
frheault marked this conversation as resolved.
Show resolved Hide resolved
noise_factor = polynomial(x)
# print(noise_factor)

vec = s - new_s
vec /= np.linalg.norm(vec, axis=0)
new_s += vec * np.expand_dims(noise_factor, axis=1)

if gaussian:
new_s = smooth_line_gaussian(new_s, gaussian)
elif spline:
Expand Down
153 changes: 153 additions & 0 deletions scilpy/utils/spatial_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-

import numpy as np


def normalize_vector(v):
"""Normalize a 3D vector."""
norm = np.linalg.norm(v)
return v / norm if norm != 0 else v


def find_perpendicular_vectors(v):
frheault marked this conversation as resolved.
Show resolved Hide resolved
"""Find two perpendicular unit vectors for a given normalized 3D vector."""
if v[0] == 0 and v[1] == 0:
if v[2] == 0:
# v is a zero vector, can't find perpendicular vectors
return None, None
# v is along the z-axis, choose x and y axes as perpendicular vectors
return np.array([1, 0, 0]), np.array([0, 1, 0])

# General case, find one vector perpendicular to v and z-axis
u = np.cross(v, [0, 0, 1])
u = normalize_vector(u)

# Find another vector perpendicular to both v and u
w = np.cross(v, u)
w = normalize_vector(w)

return u, w


def project_on_plane(vector, u, w):
frheault marked this conversation as resolved.
Show resolved Hide resolved
"""Project a vector onto the plane defined by vectors u and w."""
# Calculate projections onto u and w
proj_u = np.dot(vector, u) * u
proj_w = np.dot(vector, w) * w

# Combine projections
projection = proj_u + proj_w

# Normalize and scale to match the original vector's norm
norm_vector = np.linalg.norm(vector)
normalized_projection = normalize_vector(projection)
scaled_projection = normalized_projection * norm_vector

return scaled_projection


def parallel_transport_streamline(streamline, nb_streamlines, radius):
""" Generate new streamlines by parallel transport of the input
streamline. See [0] and [1] for more details.

[0]: Hanson, A.J., & Ma, H. (1995). Parallel Transport Approach to Curve Framing. # noqa E501
[1]: TD Essentials: Parallel Transport. https://www.youtube.com/watch?v=5LedteSEgOE

Parameters
----------
streamline: ndarray (N, 3)
The streamline to transport.
nb_streamlines: int
The number of streamlines to generate.
frheault marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
new_streamlines: list of ndarray (N, 3)
The generated streamlines.
"""

def r(vec, theta):
""" Rotation matrix around a 3D vector by an angle theta.
From https://stackoverflow.com/questions/6802577/rotation-of-3d-vector

TODO?: Put this somewhere else.
frheault marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
vec: ndarray (3,)
The vector to rotate around.
theta: float
The angle of rotation in radians.

Returns
-------
rot: ndarray (3, 3)
The rotation matrix.
"""

vec = vec / np.linalg.norm(vec)
x, y, z = vec
c, s = np.cos(theta), np.sin(theta)
return np.array([[c + x**2 * (1 - c),
x * y * (1 - c) - z * s,
x * z * (1 - c) + y * s],
[y * x * (1 - c) + z * s,
c + y**2 * (1 - c),
y * z * (1 - c) - x * s],
[z * x * (1 - c) - y * s,
z * y * (1 - c) + x * s,
c + z**2 * (1 - c)]])

# Compute the tangent at each point of the streamline
T = np.gradient(streamline, axis=0)
# Normalize the tangents
T = T / np.linalg.norm(T, axis=1)[:, None]

# Placeholder for the normal vector at each point
V = np.zeros_like(T)
# Set the normal vector at the first point to be [0, 1, 0]
# (arbitrary choice)
frheault marked this conversation as resolved.
Show resolved Hide resolved
V[0] = np.random.rand(3)

# For each point
for i in range(0, T.shape[0]-1):
# Compute the torsion vector
B = np.cross(T[i], T[i+1])
# If the torsion vector is 0, the normal vector does not change
if np.linalg.norm(B) < 1e-3:
V[i+1] = V[i]
# Else, the normal vector is rotated around the torsion vector by
# the torsion.
else:
B = B / np.linalg.norm(B)
theta = np.arccos(np.dot(T[i], T[i+1]))
# Rotate the vector V[i] around the vector B by theta
# radians.
V[i+1] = np.dot(r(B, theta), V[i])

# Compute the binormal vector at each point
W = np.cross(T, V, axis=1)

# Generate the new streamlines
# TODO?: This could easily be optimized to avoid the for loop, we have to
# see if this becomes a bottleneck.
new_streamlines = []
for i in range(nb_streamlines):
# Get a random number between -1 and 1
rand_v = (np.random.rand() * 2 - 1)
rand_w = (np.random.rand() * 2 - 1)
# Compute the norm of the "displacement"
norm = np.sqrt(rand_v**2 + rand_w**2)
# Displace the normal and binormal vectors by a random amount
V_mod = V * rand_v
W_mod = W * rand_w
# Compute the displacement vector
VW = (V_mod + W_mod)
# Displace the streamline around the original one following the
# parallel frame. Make sure to normalize the displacement vector
# so that the new streamline is in a circle around the original one.

new_s = streamline + (np.random.rand() * VW / norm) * radius
new_streamlines.append(new_s)

return new_streamlines
2 changes: 1 addition & 1 deletion scilpy/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,4 @@ def recursive_print(data):
print(list(data.keys()))
recursive_print(data[list(data.keys())[0]])
else:
return
return
12 changes: 7 additions & 5 deletions scripts/scil_resample_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,14 @@ def _build_arg_parser():

# For upsampling:
upsampling_group = p.add_argument_group('Upsampling params')
std_group = upsampling_group.add_mutually_exclusive_group()
std_group.add_argument('--point_wise_std', type=float,
upsampling_group.add_argument('--point_wise_std', type=float, default=1,
help='Noise to add to existing streamlines\'' +
' points to generate new ones.')
std_group.add_argument('--streamline_wise_std', type=float,
' points to generate new ones [%(default)s].')
upsampling_group.add_argument('--streamline_wise_std', type=float, default=1,
help='Noise to add to existing whole' +
' streamlines to generate new ones.')
' streamlines to generate new ones [%(default)s].')
upsampling_group.add_argument('--keep_tube', action='store_true',
help='Keep streamlines as tube (default: False).')
sub_p = upsampling_group.add_mutually_exclusive_group()
sub_p.add_argument('--gaussian', metavar='SIGMA', type=int,
help='Sigma for smoothing. Use the value of surronding'
Expand Down Expand Up @@ -146,6 +147,7 @@ def main():
sft = upsample_tractogram(
sft, args.nb_streamlines,
args.point_wise_std, args.streamline_wise_std,
args.keep_tube,
args.gaussian, args.spline, args.seed)
elif args.nb_streamlines < original_number:
if args.downsample_per_cluster:
Expand Down