Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move utils.streamlines methods #950

Merged
merged 10 commits into from
Mar 27, 2024
3 changes: 2 additions & 1 deletion scilpy/segment/tractogram_from_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from scilpy.io.image import get_data_as_mask
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.segment.streamlines import filter_grid_roi, filter_grid_roi_both
from scilpy.tractanalysis.features import remove_loops_and_sharp_turns
from scilpy.tractograms.streamline_operations import \
remove_loops_and_sharp_turns
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map

from scilpy.tractograms.streamline_and_mask_operations import \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,163 @@

from itertools import count, takewhile
import logging
from multiprocessing import Pool

from dipy.segment.clustering import QuickBundles, qbx_and_merge
from dipy.segment.clustering import QuickBundles
from dipy.segment.featurespeed import ResampleFeature
from dipy.segment.metric import AveragePointwiseEuclideanMetric
from dipy.tracking import metrics as tm
import numpy as np
from dipy.tracking.streamlinespeed import set_number_of_points
from scipy.spatial import cKDTree
from sklearn.cluster import KMeans

from scilpy.tractograms.streamline_and_mask_operations import \
get_endpoints_density_map
from scilpy.tractograms.streamline_operations import \
resample_streamlines_num_points
resample_streamlines_num_points, get_streamlines_bounding_box


def get_streamlines_centroid(streamlines, nb_points):
"""
Compute centroid from streamlines using QuickBundles.

Parameters
----------
streamlines: list of ndarray
The list of streamlines from which we compute the centroid.
nb_points: int
Number of points defining the centroid streamline.

Returns
-------
List of length one, containing a np.ndarray of shape (nb_points, 3)
"""
resample_feature = ResampleFeature(nb_points=nb_points)
quick_bundle = QuickBundles(
threshold=np.inf,
metric=AveragePointwiseEuclideanMetric(resample_feature))
clusters = quick_bundle.cluster(streamlines)
centroid_streamlines = clusters.centroids

return centroid_streamlines


def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False):
"""Uniformize the streamlines in the given tractogram.

Parameters
----------
sft: StatefulTractogram
The tractogram that contains the list of streamlines to be uniformized
axis: int, optional
Orient endpoints in the given axis
ref_bundle: streamlines
Orient endpoints the same way as this bundle (or centroid)
swap: boolean, optional
Swap the orientation of streamlines
"""
old_space = sft.space
old_origin = sft.origin
sft.to_vox()
sft.to_corner()
density = get_endpoints_density_map(sft, point_to_select=3)
indices = np.argwhere(density > 0)
kmeans = KMeans(n_clusters=2, random_state=0, copy_x=True,
n_init=20).fit(indices)

labels = np.zeros(density.shape)
for i in range(len(kmeans.labels_)):
labels[tuple(indices[i])] = kmeans.labels_[i]+1

k_means_centers = kmeans.cluster_centers_
main_dir_barycenter = np.argmax(
np.abs(k_means_centers[0] - k_means_centers[-1]))

if len(sft.streamlines) > 0:
axis_name = ['x', 'y', 'z']
if axis is None or ref_bundle is not None:
if ref_bundle is not None:
ref_bundle.to_vox()
ref_bundle.to_corner()
centroid = get_streamlines_centroid(ref_bundle.streamlines,
20)[0]
else:
centroid = get_streamlines_centroid(sft.streamlines, 20)[0]
main_dir_ends = np.argmax(np.abs(centroid[0] - centroid[-1]))
main_dir_displacement = np.argmax(
np.abs(np.sum(np.gradient(centroid, axis=0), axis=0)))

if main_dir_displacement != main_dir_ends \
or main_dir_displacement != main_dir_barycenter:
logging.info('Ambiguity in orientation, you should use --axis')
axis = axis_name[main_dir_displacement]
logging.info('Orienting endpoints in the {} axis'.format(axis))
axis_pos = axis_name.index(axis)

if bool(k_means_centers[0][axis_pos] >
k_means_centers[1][axis_pos]) ^ bool(swap):
labels[labels == 1] = 3
labels[labels == 2] = 1
labels[labels == 3] = 2

for i in range(len(sft.streamlines)):
if ref_bundle:
res_centroid = set_number_of_points(centroid, 20)
res_streamlines = set_number_of_points(sft.streamlines[i], 20)
norm_direct = np.sum(
np.linalg.norm(res_centroid - res_streamlines, axis=0))
norm_flip = np.sum(
np.linalg.norm(res_centroid - res_streamlines[::-1],
axis=0))
if bool(norm_direct > norm_flip) ^ bool(swap):
sft.streamlines[i] = sft.streamlines[i][::-1]
for key in sft.data_per_point[i]:
sft.data_per_point[key][i] = \
sft.data_per_point[key][i][::-1]
else:
# Bitwise XOR
if (bool(labels[tuple(sft.streamlines[i][0].astype(int))] >
labels[tuple(sft.streamlines[i][-1].astype(int))])
^ bool(swap)):
sft.streamlines[i] = sft.streamlines[i][::-1]
for key in sft.data_per_point[i]:
sft.data_per_point[key][i] = \
sft.data_per_point[key][i][::-1]
sft.to_space(old_space)
sft.to_origin(old_origin)


def uniformize_bundle_sft_using_mask(sft, mask, swap=False):
"""Uniformize the streamlines in the given tractogram so head is closer to
to a region of interest.

Parameters
----------
sft: StatefulTractogram
The tractogram that contains the list of streamlines to be uniformized
mask: np.ndarray
Mask to use as a reference for the ROI.
swap: boolean, optional
Swap the orientation of streamlines
"""

# barycenter = np.average(np.argwhere(mask), axis=0)
old_space = sft.space
old_origin = sft.origin
sft.to_vox()
sft.to_corner()

tree = cKDTree(np.argwhere(mask))
for i in range(len(sft.streamlines)):
head_dist = tree.query(sft.streamlines[i][0])[0]
tail_dist = tree.query(sft.streamlines[i][-1])[0]
if bool(head_dist > tail_dist) ^ bool(swap):
sft.streamlines[i] = sft.streamlines[i][::-1]
for key in sft.data_per_point[i]:
sft.data_per_point[key][i] = \
sft.data_per_point[key][i][::-1]

sft.to_space(old_space)
sft.to_origin(old_origin)


def detect_ushape(sft, minU, maxU):
Expand Down Expand Up @@ -57,94 +204,6 @@ def detect_ushape(sft, minU, maxU):
return ids


def remove_loops_and_sharp_turns(streamlines,
max_angle,
use_qb=False,
qb_threshold=15.,
qb_seed=0,
num_processes=1):
"""
Remove loops and sharp turns from a list of streamlines.

Parameters
----------
streamlines: list of ndarray
The list of streamlines from which to remove loops and sharp turns.
max_angle: float
Maximal winding angle a streamline can have before
being classified as a loop.
use_qb: bool
Set to True if the additional QuickBundles pass is done.
This will help remove sharp turns. Should only be used on
bundled streamlines, not on whole-brain tractograms.
qb_threshold: float
Quickbundles distance threshold, only used if use_qb is True.
qb_seed: int
Seed to initialize randomness in QuickBundles
num_processes : int
Split the calculation to a pool of children processes.

Returns
-------
list: the ids of clean streamlines
Only the ids are returned so proper filtering can be done afterwards
"""

streamlines_clean = []
ids = []
pool = Pool(num_processes)

windings = pool.map(tm.winding, streamlines)
pool.close()
streamlines_clean = streamlines[np.array(windings) < max_angle]
ids = list(np.where(np.array(windings) < max_angle)[0])

if use_qb:
ids = []
if len(streamlines_clean) > 1:
curvature = []

rng = np.random.RandomState(qb_seed)
clusters = qbx_and_merge(streamlines_clean,
[40, 30, 20, qb_threshold],
rng=rng, verbose=False)

for cc in clusters.centroids:
curvature.append(tm.mean_curvature(cc))
mean_curvature = sum(curvature)/len(curvature)

for i in range(len(clusters.centroids)):
if tm.mean_curvature(clusters.centroids[i]) <= mean_curvature:
ids.extend(clusters[i].indices)
else:
logging.info("Impossible to use the use_qb option because " +
"not more than one streamline left from the\n" +
"input file.")
return ids


def get_streamlines_bounding_box(streamlines):
"""
Classify inliers and outliers from a list of streamlines.
Parameters
----------
streamlines: list of ndarray
The list of streamlines from which inliers and outliers are separated.
Returns
-------
tuple: Minimum and maximum corner coordinate of the streamlines
bounding box
"""
box_min = np.array([np.inf, np.inf, np.inf])
box_max = -np.array([np.inf, np.inf, np.inf])

for s in streamlines:
box_min = np.minimum(box_min, np.min(s, axis=0))
box_max = np.maximum(box_max, np.max(s, axis=0))

return box_min, box_max


def prune(streamlines, threshold, features):
"""
Discriminate streamlines based on a metrics, usually summary from function
Expand Down Expand Up @@ -283,27 +342,3 @@ def remove_outliers(streamlines, threshold, nb_points=12, nb_samplings=30,

return outliers_ids, inliers_ids


def get_streamlines_centroid(streamlines, nb_points):
"""
Compute centroid from streamlines using QuickBundles.

Parameters
----------
streamlines: list of ndarray
The list of streamlines from which we compute the centroid.
nb_points: int
Number of points defining the centroid streamline.

Returns
-------
List of length one, containing a np.ndarray of shape (nb_points, 3)
"""
resample_feature = ResampleFeature(nb_points=nb_points)
quick_bundle = QuickBundles(
threshold=np.inf,
metric=AveragePointwiseEuclideanMetric(resample_feature))
clusters = quick_bundle.cluster(streamlines)
centroid_streamlines = clusters.centroids

return centroid_streamlines
62 changes: 62 additions & 0 deletions scilpy/tractograms/dps_and_dpp_management.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,68 @@
# -*- coding: utf-8 -*-
import numpy as np

from scilpy.viz.utils import clip_and_normalize_data_for_cmap


def add_data_as_color_dpp(sft, cmap, data, clip_outliers, min_range, max_range,
min_cmap, max_cmap, log, LUT):
"""
Normalizes data between 0 and 1 for an easier management with colormaps.
The real lower bound and upperbound are returned.

Data can be clipped to (min_range, max_range) before normalization.
Alternatively, data can be kept as is,

Parameters
----------
sft: StatefulTractogram
The tractogram
cmap: plt colormap
The colormap
data: np.ndarray
The data to convert to color. Expecting one value per point to add as
dpp. If instead data has one value per streamline, setting the same
color to all points of the streamline (as dpp).
clip_outliers: bool
See description of the following parameters in
clip_and_normalize_data_for_cmap.
min_range: float
EmmaRenauld marked this conversation as resolved.
Show resolved Hide resolved
max_range: float
min_cmap: float
max_cmap: float
log: bool
LUT: np.ndarray

Returns
-------
sft: StatefulTractogram
The tractogram, with dpp 'color' added.
lbound: float
The lower bound of the associated colormap.
ubound: float
The upper bound of the associated colormap.
"""
values, lbound, ubound = clip_and_normalize_data_for_cmap(
data, clip_outliers, min_range, max_range,
min_cmap, max_cmap, log, LUT)

color = cmap(values)[:, 0:3] * 255
if len(color) == len(sft):
tmp = [np.tile([color[i][0], color[i][1], color[i][2]],
(len(sft.streamlines[i]), 1))
for i in range(len(sft.streamlines))]
sft.data_per_point['color'] = tmp
elif len(color) == len(sft.streamlines._data):
sft.data_per_point['color'] = sft.streamlines
sft.data_per_point['color']._data = color
else:
raise ValueError("Error in the code... Colors do not have the right "
"shape. Expecting either one color per streamline "
"({}) or one per point ({}) but got {}."
.format(len(sft), len(sft.streamlines._data),
len(color)))
return sft, lbound, ubound


def convert_dps_to_dpp(sft, keys, overwrite=False):
"""
Expand Down
Loading