Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New upsampling method (PTT) for tractogram #818

Merged
merged 35 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
8f83a61
Working version, no new script
frheault Nov 23, 2023
a9f62be
Fix
Nov 23, 2023
4c00772
working ptt
Nov 24, 2023
05d3d9b
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Nov 28, 2023
3ae2cb8
Fix Antoine comment, fix compress, optimize
frheault Nov 29, 2023
0962d2a
Fix conflict new scripts
frheault Dec 4, 2023
94fe97b
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Dec 14, 2023
4504055
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
Dec 14, 2023
ed7e4a1
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Dec 18, 2023
8e6249c
Merge branch 'json_rename_scripts' of github.com:frheault/scilpy into…
frheault Dec 18, 2023
f5168ad
Change RNG to generator
frheault Dec 18, 2023
c73fd70
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Dec 20, 2023
255d7cf
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
Dec 20, 2023
a4f38e7
Fix Pep8
Dec 20, 2023
7a5fcd9
fix test
Dec 20, 2023
84fd8f4
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 10, 2024
c5fb0f7
Fix Antoine comments
frheault Jan 10, 2024
d08533d
Merge branch 'new_upsampling_tube' of github.com:frheault/scilpy into…
frheault Jan 10, 2024
b1ab095
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 15, 2024
bc8fc56
Added unit test
frheault Jan 29, 2024
9f83929
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Jan 29, 2024
fba8fcf
added newline
frheault Feb 1, 2024
508ce0e
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Feb 20, 2024
bac39ef
Fix conflict
frheault Feb 21, 2024
86e8f5b
Move functions and add tests, emmanuel comments
frheault Feb 22, 2024
c386b56
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
frheault Feb 27, 2024
521acab
Move function, add tests, pep8
frheault Feb 27, 2024
6765016
Fix import
frheault Feb 27, 2024
5d26160
Fix kdtree type error
frheault Feb 28, 2024
b32e164
Revert to sklearn
frheault Feb 28, 2024
ebc9205
Merge branch 'master' of github.com:scilus/scilpy into new_upsampling…
Mar 4, 2024
c492232
address comments
Mar 17, 2024
c50a518
Fix conflict
Mar 17, 2024
de85fd9
Fix error for pytest
Mar 18, 2024
aa8038c
Fix Emmanuelle comments in argparse
frheault Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 53 additions & 28 deletions scilpy/tractograms/tractogram_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
from dipy.io.utils import get_reference_info, is_header_compatible
from dipy.segment.clustering import qbx_and_merge
from dipy.tracking.streamline import transform_streamlines
from dipy.tracking.streamlinespeed import compress_streamlines
from nibabel.streamlines import TrkFile, TckFile
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np
from numpy.polynomial.polynomial import Polynomial
from scipy.ndimage import map_coordinates
from scipy.spatial import cKDTree

from scilpy.tractograms.streamline_operations import smooth_line_gaussian, \
smooth_line_spline
resample_streamlines_step_size
from scilpy.utils.streamlines import cut_invalid_streamlines
from scilpy.utils.spatial_ops import parallel_transport_streamline

MIN_NB_POINTS = 10
KEY_INDEX = np.concatenate((range(5), range(-1, -6, -1)))
Expand Down Expand Up @@ -603,8 +607,19 @@ def transform_warp_sft(sft, linear_transfo, target, inverse=False,
return new_sft


def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
gaussian=None, spline=None, seed=None):
def compress_streamlines_wrapper(tractogram, error_rate):
if isinstance(tractogram, (TrkFile, TckFile)):
arnaudbore marked this conversation as resolved.
Show resolved Hide resolved
return lambda: (compress_streamlines(
s, error_rate) for s in tractogram.streamlines)
else:
if hasattr(tractogram, 'streamlines'):
tractogram = tractogram.streamlines
return [compress_streamlines(
s, error_rate) for s in tractogram]


def upsample_tractogram(sft, nb, point_wise_std=None, tube_radius=None,
gaussian=None, error_rate=None, seed=None):
"""
Generates new streamlines by either adding gaussian noise around
streamlines' points, or by translating copies of existing streamlines
Expand All @@ -619,9 +634,8 @@ def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
point_wise_std : float
The standard deviation of the gaussian to use to generate point-wise
noise on the streamlines.
streamline_wise_std : float
The standard deviation of the gaussian to use to generate
streamline-wise noise on the streamlines.
error_rate : float
The maximum distance (in mm) to the original position of any point.
frheault marked this conversation as resolved.
Show resolved Hide resolved
gaussian: float
frheault marked this conversation as resolved.
Show resolved Hide resolved
The sigma used for smoothing streamlines.
spline: (float, int)
Expand All @@ -635,35 +649,46 @@ def upsample_tractogram(sft, nb, point_wise_std=None, streamline_wise_std=None,
new_sft : StatefulTractogram
The upsampled tractogram.
"""
assert bool(point_wise_std) ^ bool(streamline_wise_std), \
'Can only add either point-wise or streamline-wise noise' + \
', not both nor none.'

rng = np.random.RandomState(seed)

# Get the number of streamlines to add
nb_new = nb - len(sft.streamlines)
rng = np.random.default_rng(seed)

# Get the streamlines that will serve as a base for new ones
indices = rng.choice(
len(sft.streamlines), nb_new)
new_streamlines = sft.streamlines.copy()
resample_sft = resample_streamlines_step_size(sft, 1)
frheault marked this conversation as resolved.
Show resolved Hide resolved
new_streamlines = []
indices = rng.choice(len(resample_sft), nb, replace=True)
unique_indices, count = np.unique(indices, return_counts=True)

# For all selected streamlines, add noise and smooth
for s in sft.streamlines[indices]:
if point_wise_std:
noise = rng.normal(scale=point_wise_std, size=s.shape)
else: # streamline_wise_std
noise = rng.normal(scale=streamline_wise_std, size=s.shape[-1])
new_s = s + noise
for i, c in zip(unique_indices, count):
s = resample_sft.streamlines[i]
if len(s) < 3:
new_streamlines.extend(np.repeat(s, c).tolist())
new_s = parallel_transport_streamline(s, c, tube_radius)

# Generate smooth noise_factor
noise = rng.normal(loc=0, scale=point_wise_std,
size=len(s))

x = np.arange(len(noise))
poly_coeffs = np.polyfit(x, noise, 3)
polynomial = Polynomial(poly_coeffs[::-1])
frheault marked this conversation as resolved.
Show resolved Hide resolved
noise_factor = polynomial(x)

vec = s - new_s
vec /= np.linalg.norm(vec, axis=0)
new_s += vec * np.expand_dims(noise_factor, axis=1)

if gaussian:
new_s = smooth_line_gaussian(new_s, gaussian)
elif spline:
new_s = smooth_line_spline(new_s, spline[0], spline[1])
new_s = [smooth_line_gaussian(s, gaussian) for s in new_s]

new_streamlines.append(new_s)
new_streamlines.extend(new_s)

if error_rate:
compressed_streamlines = compress_streamlines_wrapper(new_streamlines,
error_rate)
else:
compressed_streamlines = new_streamlines

new_sft = StatefulTractogram.from_sft(new_streamlines, sft)
new_sft = StatefulTractogram.from_sft(compressed_streamlines, sft)
return new_sft


Expand Down
115 changes: 115 additions & 0 deletions scilpy/utils/spatial_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-

import numpy as np


def parallel_transport_streamline(streamline, nb_streamlines, radius, rng=None):
frheault marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After discussion with Arnaud, we would put this in tractogram.streamline_operations.

Can you also move the r(vec, theta) method into utils.util, and name it something like def rotation_around_vector_matrix?

""" Generate new streamlines by parallel transport of the input
streamline. See [0] and [1] for more details.

[0]: Hanson, A.J., & Ma, H. (1995). Parallel Transport Approach to
Curve Framing. # noqa E501
[1]: TD Essentials: Parallel Transport.
https://www.youtube.com/watch?v=5LedteSEgOE

Parameters
----------
streamline: ndarray (N, 3)
The streamline to transport.
nb_streamlines: int
The number of streamlines to generate.
frheault marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
new_streamlines: list of ndarray (N, 3)
The generated streamlines.
"""

def r(vec, theta):
""" Rotation matrix around a 3D vector by an angle theta.
From https://stackoverflow.com/questions/6802577/rotation-of-3d-vector

TODO?: Put this somewhere else.
frheault marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
vec: ndarray (3,)
The vector to rotate around.
theta: float
The angle of rotation in radians.

Returns
-------
rot: ndarray (3, 3)
The rotation matrix.
"""

vec = vec / np.linalg.norm(vec)
x, y, z = vec
c, s = np.cos(theta), np.sin(theta)
return np.array([[c + x**2 * (1 - c),
x * y * (1 - c) - z * s,
x * z * (1 - c) + y * s],
[y * x * (1 - c) + z * s,
c + y**2 * (1 - c),
y * z * (1 - c) - x * s],
[z * x * (1 - c) - y * s,
z * y * (1 - c) + x * s,
c + z**2 * (1 - c)]])
if rng is None:
rng = np.random.default_rng(0)

# Compute the tangent at each point of the streamline
T = np.gradient(streamline, axis=0)
# Normalize the tangents
T = T / np.linalg.norm(T, axis=1)[:, None]

# Placeholder for the normal vector at each point
V = np.zeros_like(T)
# Set the normal vector at the first point to be [0, 1, 0]
# (arbitrary choice)
frheault marked this conversation as resolved.
Show resolved Hide resolved
V[0] = np.roll(streamline[0] - streamline[1], 1)
V[0] = V[0] / np.linalg.norm(V[0])
# For each point
for i in range(0, T.shape[0]-1):
# Compute the torsion vector
B = np.cross(T[i], T[i+1])
# If the torsion vector is 0, the normal vector does not change
if np.linalg.norm(B) < 1e-3:
V[i+1] = V[i]
# Else, the normal vector is rotated around the torsion vector by
# the torsion.
else:
B = B / np.linalg.norm(B)
theta = np.arccos(np.dot(T[i], T[i+1]))
# Rotate the vector V[i] around the vector B by theta
# radians.
V[i+1] = np.dot(r(B, theta), V[i])

# Compute the binormal vector at each point
W = np.cross(T, V, axis=1)

# Generate the new streamlines
# TODO?: This could easily be optimized to avoid the for loop, we have to
# see if this becomes a bottleneck.
new_streamlines = []
for i in range(nb_streamlines):
# Get a random number between -1 and 1
rand_v = rng.uniform(-1, 1)
rand_w = rng.uniform(-1, 1)

# Compute the norm of the "displacement"
norm = np.sqrt(rand_v**2 + rand_w**2)
# Displace the normal and binormal vectors by a random amount
V_mod = V * rand_v
W_mod = W * rand_w
# Compute the displacement vector
VW = (V_mod + W_mod)
# Displace the streamline around the original one following the
# parallel frame. Make sure to normalize the displacement vector
# so that the new streamline is in a circle around the original one.

new_s = streamline + (rng.uniform(0, 1) * VW / norm) * radius
new_streamlines.append(new_s)

return new_streamlines
2 changes: 1 addition & 1 deletion scilpy/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,4 @@ def recursive_print(data):
print(list(data.keys()))
recursive_print(data[list(data.keys())[0]])
else:
return
return
8 changes: 1 addition & 7 deletions scripts/scil_tractogram_compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
import argparse
import logging

from dipy.tracking.streamlinespeed import compress_streamlines
import nibabel as nib
from nibabel.streamlines import LazyTractogram
import numpy as np

from scilpy.io.streamlines import check_tracts_same_format
from scilpy.tractograms.tractogram_operations import compress_streamlines_wrapper
from scilpy.io.utils import (add_overwrite_arg,
add_verbose_arg,
assert_inputs_exist,
Expand All @@ -40,11 +40,6 @@ def _build_arg_parser():
return p


def compress_streamlines_wrapper(tractogram, error_rate):
return lambda: (compress_streamlines(
s, error_rate) for s in tractogram.streamlines)


def main():
parser = _build_arg_parser()
args = parser.parse_args()
Expand All @@ -64,7 +59,6 @@ def main():
in_tractogram = nib.streamlines.load(args.in_tractogram, lazy_load=True)
compressed_streamlines = compress_streamlines_wrapper(in_tractogram,
args.error_rate)

out_tractogram = LazyTractogram(compressed_streamlines,
affine_to_rasmm=np.eye(4))
nib.streamlines.save(out_tractogram, args.out_tractogram,
Expand Down
51 changes: 25 additions & 26 deletions scripts/scil_tractogram_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
ensure that all clusters are represented in the final tractogram.

Example usage:
$ scil_tractogram_resample.py input.trk 1000 output.trk \
--point_wise_std 0.5 --spline 5 10 --keep_invalid_streamlines
$ scil_resample_tractogram.py input.trk 1000 output.trk \
frheault marked this conversation as resolved.
Show resolved Hide resolved
--point_wise_std 0.5 --gaussian 5 --keep_invalid_streamlines
$ scil_visualize_bundles.py output.trk --local_coloring --width=0.1
"""

Expand Down Expand Up @@ -61,25 +61,24 @@ def _build_arg_parser():

# For upsampling:
upsampling_group = p.add_argument_group('Upsampling params')
std_group = upsampling_group.add_mutually_exclusive_group()
std_group.add_argument('--point_wise_std', type=float,
help='Noise to add to existing streamlines\'' +
' points to generate new ones.')
std_group.add_argument('--streamline_wise_std', type=float,
help='Noise to add to existing whole' +
' streamlines to generate new ones.')
sub_p = upsampling_group.add_mutually_exclusive_group()
sub_p.add_argument('--gaussian', metavar='SIGMA', type=int,
help='Sigma for smoothing. Use the value of surronding'
' X,Y,Z points on \nthe streamline to blur the'
' streamlines. A good sigma choice would \nbe '
'around 5.')
sub_p.add_argument('--spline', nargs=2, metavar=('SIGMA', 'NB_CTRL_POINT'),
type=int,
help='Sigma and number of points for smoothing. Models '
'each streamline \nas a spline. A good sigma '
'choice would be around 5 and control \npoints '
'around 10.')
upsampling_group.add_argument('--point_wise_std', type=float, default=1,
help='Noise to add to existing streamlines '
'points to generate new ones [%(default)s].')
upsampling_group.add_argument('--tube_radius', type=float, default=1,
help='Maximum distance to generate streamlines '
' around the original ones [%(default)s].')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

two spaces between streamlines and around.

upsampling_group.add_argument('--force_tube', action='store_true',
help='Force the use of parellel transport to '
'resample, even if the output tractogram '
'has fewer streamlines.')
frheault marked this conversation as resolved.
Show resolved Hide resolved
upsampling_group.add_argument('--gaussian', metavar='SIGMA', type=int,
help='Sigma for smoothing. Use the value of '
'surrounding X,Y,Z points on the '
'streamline to blur the streamlines.\n'
'A good sigma choice would around 5.')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be

upsampling_group.add_argument('-e', dest='error_rate', type=float, default=0.1,
help='Maximum compression distance in mm '
'[%(default)s].')

upsampling_group.add_argument(
'--keep_invalid_streamlines', action='store_true',
Expand Down Expand Up @@ -116,8 +115,8 @@ def main():
args = parser.parse_args()

if (args.point_wise_std is not None and args.point_wise_std <= 0) or \
(args.streamline_wise_std is not None and
args.streamline_wise_std <= 0):
(args.tube_radius is not None and
args.tube_radius <= 0):
parser.error('STD needs to be above 0.')

assert_inputs_exist(parser, args.in_tractogram)
Expand All @@ -139,15 +138,15 @@ def main():
logging.debug("Done. Now getting {} streamlines."
.format(args.nb_streamlines))

if args.nb_streamlines > original_number:
if args.nb_streamlines > original_number or args.tube_radius:
# Check is done here because it is not required if downsampling
if not args.point_wise_std and not args.streamline_wise_std:
parser.error("one of the arguments --point_wise_std " +
"--streamline_wise_std is required")
frheault marked this conversation as resolved.
Show resolved Hide resolved
sft = upsample_tractogram(
sft, args.nb_streamlines,
args.point_wise_std, args.streamline_wise_std,
args.gaussian, args.spline, args.seed)
args.point_wise_std, args.tube_radius,
args.gaussian, args.error_rate, args.seed)
elif args.nb_streamlines < original_number:
if args.downsample_per_cluster:
# output contains rejected streamlines, we don't use them.
Expand Down
12 changes: 12 additions & 0 deletions scripts/tests/test_tractogram_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,15 @@ def test_execution_upsample(script_runner):
'--point_wise_std', '0.5', '-f',
'--downsample_per_cluster')
assert ret.success


def test_execution_upsample_ptt(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_tracto = os.path.join(get_home(), 'tracking',
'union_shuffle_sub.trk')

ret = script_runner.run('scil_tractogram_resample.py', in_tracto,
'500', 'union_shuffle_sub_upsampled.trk', '-f',
'--point_wise_std', '10', '--tube_radius', '5',
'--force_tube')
assert ret.success