diff --git a/scilpy/segment/tractogram_from_roi.py b/scilpy/segment/tractogram_from_roi.py index e53d3ca1e..431946d2e 100644 --- a/scilpy/segment/tractogram_from_roi.py +++ b/scilpy/segment/tractogram_from_roi.py @@ -14,7 +14,8 @@ from scilpy.io.image import get_data_as_mask from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.segment.streamlines import filter_grid_roi, filter_grid_roi_both -from scilpy.tractanalysis.features import remove_loops_and_sharp_turns +from scilpy.tractograms.streamline_operations import \ + remove_loops_and_sharp_turns from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractograms.streamline_and_mask_operations import \ diff --git a/scilpy/tractanalysis/features.py b/scilpy/tractanalysis/bundle_operations.py similarity index 59% rename from scilpy/tractanalysis/features.py rename to scilpy/tractanalysis/bundle_operations.py index af4696275..2a380f319 100644 --- a/scilpy/tractanalysis/features.py +++ b/scilpy/tractanalysis/bundle_operations.py @@ -2,16 +2,163 @@ from itertools import count, takewhile import logging -from multiprocessing import Pool -from dipy.segment.clustering import QuickBundles, qbx_and_merge +from dipy.segment.clustering import QuickBundles from dipy.segment.featurespeed import ResampleFeature from dipy.segment.metric import AveragePointwiseEuclideanMetric -from dipy.tracking import metrics as tm import numpy as np +from dipy.tracking.streamlinespeed import set_number_of_points +from scipy.spatial import cKDTree +from sklearn.cluster import KMeans +from scilpy.tractograms.streamline_and_mask_operations import \ + get_endpoints_density_map from scilpy.tractograms.streamline_operations import \ - resample_streamlines_num_points + resample_streamlines_num_points, get_streamlines_bounding_box + + +def get_streamlines_centroid(streamlines, nb_points): + """ + Compute centroid from streamlines using QuickBundles. + + Parameters + ---------- + streamlines: list of ndarray + The list of streamlines from which we compute the centroid. + nb_points: int + Number of points defining the centroid streamline. + + Returns + ------- + List of length one, containing a np.ndarray of shape (nb_points, 3) + """ + resample_feature = ResampleFeature(nb_points=nb_points) + quick_bundle = QuickBundles( + threshold=np.inf, + metric=AveragePointwiseEuclideanMetric(resample_feature)) + clusters = quick_bundle.cluster(streamlines) + centroid_streamlines = clusters.centroids + + return centroid_streamlines + + +def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False): + """Uniformize the streamlines in the given tractogram. + + Parameters + ---------- + sft: StatefulTractogram + The tractogram that contains the list of streamlines to be uniformized + axis: int, optional + Orient endpoints in the given axis + ref_bundle: streamlines + Orient endpoints the same way as this bundle (or centroid) + swap: boolean, optional + Swap the orientation of streamlines + """ + old_space = sft.space + old_origin = sft.origin + sft.to_vox() + sft.to_corner() + density = get_endpoints_density_map(sft, point_to_select=3) + indices = np.argwhere(density > 0) + kmeans = KMeans(n_clusters=2, random_state=0, copy_x=True, + n_init=20).fit(indices) + + labels = np.zeros(density.shape) + for i in range(len(kmeans.labels_)): + labels[tuple(indices[i])] = kmeans.labels_[i]+1 + + k_means_centers = kmeans.cluster_centers_ + main_dir_barycenter = np.argmax( + np.abs(k_means_centers[0] - k_means_centers[-1])) + + if len(sft.streamlines) > 0: + axis_name = ['x', 'y', 'z'] + if axis is None or ref_bundle is not None: + if ref_bundle is not None: + ref_bundle.to_vox() + ref_bundle.to_corner() + centroid = get_streamlines_centroid(ref_bundle.streamlines, + 20)[0] + else: + centroid = get_streamlines_centroid(sft.streamlines, 20)[0] + main_dir_ends = np.argmax(np.abs(centroid[0] - centroid[-1])) + main_dir_displacement = np.argmax( + np.abs(np.sum(np.gradient(centroid, axis=0), axis=0))) + + if main_dir_displacement != main_dir_ends \ + or main_dir_displacement != main_dir_barycenter: + logging.info('Ambiguity in orientation, you should use --axis') + axis = axis_name[main_dir_displacement] + logging.info('Orienting endpoints in the {} axis'.format(axis)) + axis_pos = axis_name.index(axis) + + if bool(k_means_centers[0][axis_pos] > + k_means_centers[1][axis_pos]) ^ bool(swap): + labels[labels == 1] = 3 + labels[labels == 2] = 1 + labels[labels == 3] = 2 + + for i in range(len(sft.streamlines)): + if ref_bundle: + res_centroid = set_number_of_points(centroid, 20) + res_streamlines = set_number_of_points(sft.streamlines[i], 20) + norm_direct = np.sum( + np.linalg.norm(res_centroid - res_streamlines, axis=0)) + norm_flip = np.sum( + np.linalg.norm(res_centroid - res_streamlines[::-1], + axis=0)) + if bool(norm_direct > norm_flip) ^ bool(swap): + sft.streamlines[i] = sft.streamlines[i][::-1] + for key in sft.data_per_point[i]: + sft.data_per_point[key][i] = \ + sft.data_per_point[key][i][::-1] + else: + # Bitwise XOR + if (bool(labels[tuple(sft.streamlines[i][0].astype(int))] > + labels[tuple(sft.streamlines[i][-1].astype(int))]) + ^ bool(swap)): + sft.streamlines[i] = sft.streamlines[i][::-1] + for key in sft.data_per_point[i]: + sft.data_per_point[key][i] = \ + sft.data_per_point[key][i][::-1] + sft.to_space(old_space) + sft.to_origin(old_origin) + + +def uniformize_bundle_sft_using_mask(sft, mask, swap=False): + """Uniformize the streamlines in the given tractogram so head is closer to + to a region of interest. + + Parameters + ---------- + sft: StatefulTractogram + The tractogram that contains the list of streamlines to be uniformized + mask: np.ndarray + Mask to use as a reference for the ROI. + swap: boolean, optional + Swap the orientation of streamlines + """ + + # barycenter = np.average(np.argwhere(mask), axis=0) + old_space = sft.space + old_origin = sft.origin + sft.to_vox() + sft.to_corner() + + tree = cKDTree(np.argwhere(mask)) + for i in range(len(sft.streamlines)): + head_dist = tree.query(sft.streamlines[i][0])[0] + tail_dist = tree.query(sft.streamlines[i][-1])[0] + if bool(head_dist > tail_dist) ^ bool(swap): + sft.streamlines[i] = sft.streamlines[i][::-1] + for key in sft.data_per_point[i]: + sft.data_per_point[key][i] = \ + sft.data_per_point[key][i][::-1] + + sft.to_space(old_space) + sft.to_origin(old_origin) def detect_ushape(sft, minU, maxU): @@ -57,94 +204,6 @@ def detect_ushape(sft, minU, maxU): return ids -def remove_loops_and_sharp_turns(streamlines, - max_angle, - use_qb=False, - qb_threshold=15., - qb_seed=0, - num_processes=1): - """ - Remove loops and sharp turns from a list of streamlines. - - Parameters - ---------- - streamlines: list of ndarray - The list of streamlines from which to remove loops and sharp turns. - max_angle: float - Maximal winding angle a streamline can have before - being classified as a loop. - use_qb: bool - Set to True if the additional QuickBundles pass is done. - This will help remove sharp turns. Should only be used on - bundled streamlines, not on whole-brain tractograms. - qb_threshold: float - Quickbundles distance threshold, only used if use_qb is True. - qb_seed: int - Seed to initialize randomness in QuickBundles - num_processes : int - Split the calculation to a pool of children processes. - - Returns - ------- - list: the ids of clean streamlines - Only the ids are returned so proper filtering can be done afterwards - """ - - streamlines_clean = [] - ids = [] - pool = Pool(num_processes) - - windings = pool.map(tm.winding, streamlines) - pool.close() - streamlines_clean = streamlines[np.array(windings) < max_angle] - ids = list(np.where(np.array(windings) < max_angle)[0]) - - if use_qb: - ids = [] - if len(streamlines_clean) > 1: - curvature = [] - - rng = np.random.RandomState(qb_seed) - clusters = qbx_and_merge(streamlines_clean, - [40, 30, 20, qb_threshold], - rng=rng, verbose=False) - - for cc in clusters.centroids: - curvature.append(tm.mean_curvature(cc)) - mean_curvature = sum(curvature)/len(curvature) - - for i in range(len(clusters.centroids)): - if tm.mean_curvature(clusters.centroids[i]) <= mean_curvature: - ids.extend(clusters[i].indices) - else: - logging.info("Impossible to use the use_qb option because " + - "not more than one streamline left from the\n" + - "input file.") - return ids - - -def get_streamlines_bounding_box(streamlines): - """ - Classify inliers and outliers from a list of streamlines. - Parameters - ---------- - streamlines: list of ndarray - The list of streamlines from which inliers and outliers are separated. - Returns - ------- - tuple: Minimum and maximum corner coordinate of the streamlines - bounding box - """ - box_min = np.array([np.inf, np.inf, np.inf]) - box_max = -np.array([np.inf, np.inf, np.inf]) - - for s in streamlines: - box_min = np.minimum(box_min, np.min(s, axis=0)) - box_max = np.maximum(box_max, np.max(s, axis=0)) - - return box_min, box_max - - def prune(streamlines, threshold, features): """ Discriminate streamlines based on a metrics, usually summary from function @@ -283,27 +342,3 @@ def remove_outliers(streamlines, threshold, nb_points=12, nb_samplings=30, return outliers_ids, inliers_ids - -def get_streamlines_centroid(streamlines, nb_points): - """ - Compute centroid from streamlines using QuickBundles. - - Parameters - ---------- - streamlines: list of ndarray - The list of streamlines from which we compute the centroid. - nb_points: int - Number of points defining the centroid streamline. - - Returns - ------- - List of length one, containing a np.ndarray of shape (nb_points, 3) - """ - resample_feature = ResampleFeature(nb_points=nb_points) - quick_bundle = QuickBundles( - threshold=np.inf, - metric=AveragePointwiseEuclideanMetric(resample_feature)) - clusters = quick_bundle.cluster(streamlines) - centroid_streamlines = clusters.centroids - - return centroid_streamlines diff --git a/scilpy/tractograms/dps_and_dpp_management.py b/scilpy/tractograms/dps_and_dpp_management.py index 216de74d0..8ea857958 100644 --- a/scilpy/tractograms/dps_and_dpp_management.py +++ b/scilpy/tractograms/dps_and_dpp_management.py @@ -1,6 +1,76 @@ # -*- coding: utf-8 -*- import numpy as np +from scilpy.viz.utils import clip_and_normalize_data_for_cmap + + +def add_data_as_color_dpp(sft, cmap, data, clip_outliers, min_range, max_range, + min_cmap, max_cmap, log, LUT): + """ + Normalizes data between 0 and 1 for an easier management with colormaps. + The real lower bound and upperbound are returned. + + Data can be clipped to (min_range, max_range) before normalization. + Alternatively, data can be kept as is, + + Parameters + ---------- + sft: StatefulTractogram + The tractogram + cmap: plt colormap + The colormap + data: np.ndarray + The data to convert to color. Expecting one value per point to add as + dpp. If instead data has one value per streamline, setting the same + color to all points of the streamline (as dpp). + clip_outliers: bool + See description of the following parameters in + clip_and_normalize_data_for_cmap. + min_range: float + Data values below min_range will be clipped. + max_range: float + Data values above max_range will be clipped. + min_cmap: float + Minimum value of the colormap. Most useful when min_range and max_range + are not set; to fix the colormap range without modifying the data. + max_cmap: float + Maximum value of the colormap. Idem. + log: bool + If True, apply a logarithmic scale to the data. + LUT: np.ndarray + If set, replaces the data values by the Look-Up Table values. In order, + the first value of the LUT is set everywhere where data==1, etc. + + Returns + ------- + sft: StatefulTractogram + The tractogram, with dpp 'color' added. + lbound: float + The lower bound of the associated colormap. + ubound: float + The upper bound of the associated colormap. + """ + values, lbound, ubound = clip_and_normalize_data_for_cmap( + data, clip_outliers, min_range, max_range, + min_cmap, max_cmap, log, LUT) + + color = cmap(values)[:, 0:3] * 255 + if len(color) == len(sft): + tmp = [np.tile([color[i][0], color[i][1], color[i][2]], + (len(sft.streamlines[i]), 1)) + for i in range(len(sft.streamlines))] + sft.data_per_point['color'] = tmp + elif len(color) == len(sft.streamlines._data): + sft.data_per_point['color'] = sft.streamlines + sft.data_per_point['color']._data = color + else: + raise ValueError("Error in the code... Colors do not have the right " + "shape. Expecting either one color per streamline " + "({}) or one per point ({}) but got {}." + .format(len(sft), len(sft.streamlines._data), + len(color))) + return sft, lbound, ubound + def convert_dps_to_dpp(sft, keys, overwrite=False): """ diff --git a/scilpy/tractograms/streamline_operations.py b/scilpy/tractograms/streamline_operations.py index 77056ae58..87558071c 100644 --- a/scilpy/tractograms/streamline_operations.py +++ b/scilpy/tractograms/streamline_operations.py @@ -1,10 +1,15 @@ # -*- coding: utf-8 -*- +import copy import logging +from multiprocessing import Pool import numpy as np import scipy.ndimage as ndi from dipy.io.stateful_tractogram import StatefulTractogram -from dipy.tracking.streamlinespeed import (length, set_number_of_points) +from dipy.segment.clustering import qbx_and_merge +from dipy.tracking import metrics as tm +from dipy.tracking.streamlinespeed import (length, set_number_of_points, + compress_streamlines) from scipy.interpolate import splev, splprep from scilpy.utils.util import rotation_around_vector_matrix @@ -87,6 +92,200 @@ def _get_point_on_line(first_point, second_point, vox_lower_corner): return first_point + ray * (t0 + t1) / 2. +def get_angles(sft): + """Color streamlines according to their length. + + Parameters + ---------- + sft: StatefulTractogram + The tractogram. + + Returns + ------- + angles: list[np.ndarray] + The angles per streamline, in degree. + """ + angles = [] + for i in range(len(sft.streamlines)): + dirs = np.diff(sft.streamlines[i], axis=0) + dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True) + cos_angles = np.sum(dirs[:-1, :] * dirs[1:, :], axis=1) + # Resolve numerical instability + cos_angles = np.minimum(np.maximum(-1.0, cos_angles), 1.0) + line_angles = [0.0] + list(np.arccos(cos_angles)) + [0.0] + angles.extend(line_angles) + + angles = np.rad2deg(angles) + + return angles + + +def get_values_along_length(sft): + """Get the streamlines' coordinate positions according to their length. + + Parameters + ---------- + sft: StatefulTractogram + The tractogram that contains the list of streamlines to be colored + + Returns + ------- + positions: list[list] + For each streamline, the linear distribution of its length. + """ + positions = [] + for i in range(len(sft.streamlines)): + positions.extend(list(np.linspace(0, 1, len(sft.streamlines[i])))) + + return positions + + +def compress_sft(sft, tol_error=0.01): + """ + Compress a stateful tractogram. Uses Dipy's compress_streamlines, but + deals with space better. + + Dipy's description: + The compression consists in merging consecutive segments that are + nearly collinear. The merging is achieved by removing the point the two + segments have in common. + + The linearization process [Presseau15] ensures that every point being + removed are within a certain margin (in mm) of the resulting streamline. + Recommendations for setting this margin can be found in [Presseau15] + (in which they called it tolerance error). + + The compression also ensures that two consecutive points won't be too far + from each other (precisely less or equal than *max_segment_length* mm). + This is a tradeoff to speed up the linearization process [Rheault15]. A + low value will result in a faster linearization but low compression, + whereas a high value will result in a slower linearization but high + compression. + + [Presseau C. et al., A new compression format for fiber tracking datasets, + NeuroImage, no 109, 73-83, 2015.] + + Parameters + ---------- + sft: StatefulTractogram + The sft to compress. + tol_error: float (optional) + Tolerance error in mm (default: 0.01). A rule of thumb is to set it + to 0.01mm for deterministic streamlines and 0.1mm for probabilitic + streamlines. + + Returns + ------- + compressed_sft: StatefulTractogram + """ + # Go to world space + orig_space = sft.space + sft.to_rasmm() + + # Compress streamlines + compressed_streamlines = compress_streamlines(sft.streamlines, + tol_error=tol_error) + if sft.data_per_point is not None and sft.data_per_point.keys(): + logging.warning("Initial StatefulTractogram contained data_per_point. " + "This information will not be carried in the final " + "tractogram.") + + compressed_sft = StatefulTractogram.from_sft( + compressed_streamlines, sft, + data_per_streamline=sft.data_per_streamline) + + # Return to original space + compressed_sft.to_space(orig_space) + + return compressed_sft + + +def cut_invalid_streamlines(sft): + """ Cut streamlines so their longest segment are within the bounding box. + This function keeps the data_per_point and data_per_streamline. + + Parameters + ---------- + sft: StatefulTractogram + The sft to remove invalid points from. + + Returns + ------- + new_sft : StatefulTractogram + New object with the invalid points removed from each streamline. + cutting_counter : int + Number of streamlines that were cut. + """ + if not len(sft): + return sft, 0 + + # Keep track of the streamlines' original space/origin + space = sft.space + origin = sft.origin + + sft.to_vox() + sft.to_corner() + + copy_sft = copy.deepcopy(sft) + epsilon = 0.001 + indices_to_remove, _ = copy_sft.remove_invalid_streamlines() + + new_streamlines = [] + new_data_per_point = {} + new_data_per_streamline = {} + for key in sft.data_per_point.keys(): + new_data_per_point[key] = [] + for key in sft.data_per_streamline.keys(): + new_data_per_streamline[key] = [] + + cutting_counter = 0 + for ind in range(len(sft.streamlines)): + # No reason to try to cut if all points are within the volume + if ind in indices_to_remove: + best_pos = [0, 0] + cur_pos = [0, 0] + for pos, point in enumerate(sft.streamlines[ind]): + if (point < epsilon).any() or \ + (point >= sft.dimensions - epsilon).any(): + cur_pos = [pos+1, pos+1] + if cur_pos[1] - cur_pos[0] > best_pos[1] - best_pos[0]: + best_pos = cur_pos + cur_pos[1] += 1 + + if not best_pos == [0, 0]: + new_streamlines.append( + sft.streamlines[ind][best_pos[0]:best_pos[1]-1]) + cutting_counter += 1 + for key in sft.data_per_streamline.keys(): + new_data_per_streamline[key].append( + sft.data_per_streamline[key][ind]) + for key in sft.data_per_point.keys(): + new_data_per_point[key].append( + sft.data_per_point[key][ind][ + best_pos[0]:best_pos[1]-1]) + else: + logging.warning('Streamlines entirely out of the volume.') + else: + new_streamlines.append(sft.streamlines[ind]) + for key in sft.data_per_streamline.keys(): + new_data_per_streamline[key].append( + sft.data_per_streamline[key][ind]) + for key in sft.data_per_point.keys(): + new_data_per_point[key].append(sft.data_per_point[key][ind]) + new_sft = StatefulTractogram.from_sft( + new_streamlines, sft, data_per_streamline=new_data_per_streamline, + data_per_point=new_data_per_point) + + # Move the streamlines back to the original space/origin + sft.to_space(space) + sft.to_origin(origin) + + new_sft.to_space(space) + new_sft.to_origin(origin) + + return new_sft, cutting_counter + + def filter_streamlines_by_length(sft, min_length=0., max_length=np.inf): """ Filter streamlines using minimum and max length. @@ -423,7 +622,8 @@ def generate_matched_points(sft): return matched_points -def parallel_transport_streamline(streamline, nb_streamlines, radius, rng=None): +def parallel_transport_streamline(streamline, nb_streamlines, radius, + rng=None): """ Generate new streamlines by parallel transport of the input streamline. See [0] and [1] for more details. @@ -508,3 +708,93 @@ def parallel_transport_streamline(streamline, nb_streamlines, radius, rng=None): new_streamlines.append(new_s) return new_streamlines + + +def remove_loops_and_sharp_turns(streamlines, + max_angle, + use_qb=False, + qb_threshold=15., + qb_seed=0, + num_processes=1): + """ + Remove loops and sharp turns from a list of streamlines. + + Parameters + ---------- + streamlines: list of ndarray + The list of streamlines from which to remove loops and sharp turns. + max_angle: float + Maximal winding angle a streamline can have before + being classified as a loop. + use_qb: bool + Set to True if the additional QuickBundles pass is done. + This will help remove sharp turns. Should only be used on + bundled streamlines, not on whole-brain tractograms. + qb_threshold: float + Quickbundles distance threshold, only used if use_qb is True. + qb_seed: int + Seed to initialize randomness in QuickBundles + num_processes : int + Split the calculation to a pool of children processes. + + Returns + ------- + list: the ids of clean streamlines + Only the ids are returned so proper filtering can be done afterwards + """ + + streamlines_clean = [] + ids = [] + pool = Pool(num_processes) + + windings = pool.map(tm.winding, streamlines) + pool.close() + streamlines_clean = streamlines[np.array(windings) < max_angle] + ids = list(np.where(np.array(windings) < max_angle)[0]) + + if use_qb: + ids = [] + if len(streamlines_clean) > 1: + curvature = [] + + rng = np.random.RandomState(qb_seed) + clusters = qbx_and_merge(streamlines_clean, + [40, 30, 20, qb_threshold], + rng=rng, verbose=False) + + for cc in clusters.centroids: + curvature.append(tm.mean_curvature(cc)) + mean_curvature = sum(curvature)/len(curvature) + + for i in range(len(clusters.centroids)): + if tm.mean_curvature(clusters.centroids[i]) <= mean_curvature: + ids.extend(clusters[i].indices) + else: + logging.info("Impossible to use the use_qb option because " + + "not more than one streamline left from the\n" + + "input file.") + return ids + + +def get_streamlines_bounding_box(streamlines): + """ + Classify inliers and outliers from a list of streamlines. + + Parameters + ---------- + streamlines: list of ndarray + The list of streamlines from which inliers and outliers are separated. + + Returns + ------- + tuple: Minimum and maximum corner coordinate of the streamlines + bounding box + """ + box_min = np.array([np.inf, np.inf, np.inf]) + box_max = -np.array([np.inf, np.inf, np.inf]) + + for s in streamlines: + box_min = np.minimum(box_min, np.min(s, axis=0)) + box_max = np.maximum(box_max, np.max(s, axis=0)) + + return box_min, box_max diff --git a/scilpy/tractograms/tests/test_streamline_operations.py b/scilpy/tractograms/tests/test_streamline_operations.py index 7986f57ac..fa54ae99a 100644 --- a/scilpy/tractograms/tests/test_streamline_operations.py +++ b/scilpy/tractograms/tests/test_streamline_operations.py @@ -62,6 +62,26 @@ def _setup_files(): return sft, rois +def test_angles(): + # toDo + pass + + +def test_get_values_along_length(): + # toDo + pass + + +def test_compress_sft(): + # toDo + pass + + +def test_cut_invalid_streamlines(): + # toDo + pass + + def test_filter_streamlines_by_length_max_length(): """ Test the filter_streamlines_by_length function with a max length. """ @@ -374,6 +394,11 @@ def test_smooth_line_spline(): assert dist_1 < dist_2 +def test_generate_matched_points(): + # toDo + pass + + def test_parallel_transport_streamline(): sft, _ = _setup_files() streamline = sft.streamlines[0] diff --git a/scilpy/tractograms/tests/test_tractogram_operations.py b/scilpy/tractograms/tests/test_tractogram_operations.py index eb259d24a..10868f3ac 100644 --- a/scilpy/tractograms/tests/test_tractogram_operations.py +++ b/scilpy/tractograms/tests/test_tractogram_operations.py @@ -201,3 +201,8 @@ def test_split_sft_randomly_per_cluster(): [112.168, 35.259, 59.419]) assert np.allclose(new_sft_list[1].streamlines[0][0], [112.266, 35.4188, 59.0421]) + + +def filter_tractogram_data(): + # toDo + pass diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index 48bb25b88..54b79f79b 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -25,8 +25,8 @@ from scipy.spatial import cKDTree from scilpy.tractograms.streamline_operations import smooth_line_gaussian, \ - resample_streamlines_step_size, parallel_transport_streamline -from scilpy.utils.streamlines import cut_invalid_streamlines + resample_streamlines_step_size, parallel_transport_streamline, \ + cut_invalid_streamlines MIN_NB_POINTS = 10 KEY_INDEX = np.concatenate((range(5), range(-1, -6, -1))) @@ -47,14 +47,14 @@ def shuffle_streamlines(sft, rng_seed=None): return shuffled_sft -def _get_axis_flip_vector(flip_axes): +def get_axis_flip_vector(flip_axes): """ - Create a flip vector from a list of axes + Create a flip vector from a list of axes. Parameters ---------- - flip_axis: list - List of axis you want to flip + flip_axes: list + List of axes you want to flip Returns ------- @@ -97,7 +97,7 @@ def flip_sft(sft, flip_axes): # Could return sft. But creating new SFT (or deep copy). flipped_streamlines = sft.streamlines else: - flip_vector = _get_axis_flip_vector(flip_axes) + flip_vector = get_axis_flip_vector(flip_axes) shift_vector = _get_shift_vector(sft) flipped_streamlines = [] @@ -879,6 +879,44 @@ def split_sft_randomly_per_cluster(orig_sft, chunk_sizes, seed, thresholds): return final_sfts +def filter_tractogram_data(tractogram, streamline_ids): + """ + Filter a tractogram according to streamline ids and keep the data. + + Parameters: + ----------- + tractogram: StatefulTractogram + Tractogram containing the data to be filtered. + streamline_ids: array_like + List of streamline ids the data corresponds to. + + Returns: + -------- + new_tractogram: Tractogram or StatefulTractogram + Returns a new tractogram with only the selected streamlines and data. + """ + + streamline_ids = np.asarray(streamline_ids, dtype=int) + + assert np.all( + np.in1d(streamline_ids, np.arange(len(tractogram.streamlines))) + ), "Received ids outside of streamline range" + + new_streamlines = tractogram.streamlines[streamline_ids] + new_data_per_streamline = tractogram.data_per_streamline[streamline_ids] + new_data_per_point = tractogram.data_per_point[streamline_ids] + + # Could have been nice to deepcopy the tractogram modify the attributes in + # place instead of creating a new one, but tractograms cant be subsampled + # if they have data. + + return StatefulTractogram.from_sft( + new_streamlines, + tractogram, + data_per_point=new_data_per_point, + data_per_streamline=new_data_per_streamline) + + OPERATIONS = { 'difference_robust': difference_robust, 'intersection_robust': intersection_robust, diff --git a/scilpy/utils/streamlines.py b/scilpy/utils/streamlines.py deleted file mode 100644 index 23e91094e..000000000 --- a/scilpy/utils/streamlines.py +++ /dev/null @@ -1,405 +0,0 @@ -# -*- coding: utf-8 -*- -import copy -import logging - -from dipy.io.stateful_tractogram import StatefulTractogram -from dipy.tracking.streamline import set_number_of_points -from dipy.tracking.streamlinespeed import compress_streamlines -import numpy as np -from scipy.spatial import cKDTree -from sklearn.cluster import KMeans - -from scilpy.io.utils import load_matrix_in_any_format -from scilpy.tractanalysis.features import get_streamlines_centroid -from scilpy.tractograms.streamline_and_mask_operations import \ - get_endpoints_density_map - - -def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False): - """Uniformize the streamlines in the given tractogram. - - Parameters - ---------- - sft: StatefulTractogram - The tractogram that contains the list of streamlines to be uniformized - axis: int, optional - Orient endpoints in the given axis - ref_bundle: streamlines - Orient endpoints the same way as this bundle (or centroid) - swap: boolean, optional - Swap the orientation of streamlines - """ - old_space = sft.space - old_origin = sft.origin - sft.to_vox() - sft.to_corner() - density = get_endpoints_density_map(sft, point_to_select=3) - indices = np.argwhere(density > 0) - kmeans = KMeans(n_clusters=2, random_state=0, copy_x=True, - n_init=20).fit(indices) - - labels = np.zeros(density.shape) - for i in range(len(kmeans.labels_)): - labels[tuple(indices[i])] = kmeans.labels_[i]+1 - - k_means_centers = kmeans.cluster_centers_ - main_dir_barycenter = np.argmax( - np.abs(k_means_centers[0] - k_means_centers[-1])) - - if len(sft.streamlines) > 0: - axis_name = ['x', 'y', 'z'] - if axis is None or ref_bundle is not None: - if ref_bundle is not None: - ref_bundle.to_vox() - ref_bundle.to_corner() - centroid = get_streamlines_centroid(ref_bundle.streamlines, - 20)[0] - else: - centroid = get_streamlines_centroid(sft.streamlines, 20)[0] - main_dir_ends = np.argmax(np.abs(centroid[0] - centroid[-1])) - main_dir_displacement = np.argmax( - np.abs(np.sum(np.gradient(centroid, axis=0), axis=0))) - - if main_dir_displacement != main_dir_ends \ - or main_dir_displacement != main_dir_barycenter: - logging.info('Ambiguity in orientation, you should use --axis') - axis = axis_name[main_dir_displacement] - logging.info('Orienting endpoints in the {} axis'.format(axis)) - axis_pos = axis_name.index(axis) - - if bool(k_means_centers[0][axis_pos] > - k_means_centers[1][axis_pos]) ^ bool(swap): - labels[labels == 1] = 3 - labels[labels == 2] = 1 - labels[labels == 3] = 2 - - for i in range(len(sft.streamlines)): - if ref_bundle: - res_centroid = set_number_of_points(centroid, 20) - res_streamlines = set_number_of_points(sft.streamlines[i], 20) - norm_direct = np.sum( - np.linalg.norm(res_centroid - res_streamlines, axis=0)) - norm_flip = np.sum( - np.linalg.norm(res_centroid - res_streamlines[::-1], axis=0)) - if bool(norm_direct > norm_flip) ^ bool(swap): - sft.streamlines[i] = sft.streamlines[i][::-1] - for key in sft.data_per_point[i]: - sft.data_per_point[key][i] = \ - sft.data_per_point[key][i][::-1] - else: - # Bitwise XOR - if bool(labels[tuple(sft.streamlines[i][0].astype(int))] > - labels[tuple(sft.streamlines[i][-1].astype(int))]) ^ bool(swap): - sft.streamlines[i] = sft.streamlines[i][::-1] - for key in sft.data_per_point[i]: - sft.data_per_point[key][i] = \ - sft.data_per_point[key][i][::-1] - sft.to_space(old_space) - sft.to_origin(old_origin) - - -def uniformize_bundle_sft_using_mask(sft, mask, swap=False): - """Uniformize the streamlines in the given tractogram so head is closer to - to a region of interest. - - Parameters - ---------- - sft: StatefulTractogram - The tractogram that contains the list of streamlines to be uniformized - mask: np.ndarray - Mask to use as a reference for the ROI. - swap: boolean, optional - Swap the orientation of streamlines - """ - - # barycenter = np.average(np.argwhere(mask), axis=0) - old_space = sft.space - old_origin = sft.origin - sft.to_vox() - sft.to_corner() - - tree = cKDTree(np.argwhere(mask)) - for i in range(len(sft.streamlines)): - head_dist = tree.query(sft.streamlines[i][0])[0] - tail_dist = tree.query(sft.streamlines[i][-1])[0] - if bool(head_dist > tail_dist) ^ bool(swap): - sft.streamlines[i] = sft.streamlines[i][::-1] - for key in sft.data_per_point[i]: - sft.data_per_point[key][i] = \ - sft.data_per_point[key][i][::-1] - - sft.to_space(old_space) - sft.to_origin(old_origin) - - -def clip_and_normalize_data_for_cmap(args, data): - if args.LUT: - LUT = load_matrix_in_any_format(args.LUT) - for i, val in enumerate(LUT): - data[data == i+1] = val - - if args.min_range is not None or args.max_range is not None: - data = np.clip(data, args.min_range, args.max_range) - - # get data values range - if args.min_cmap is not None: - lbound = args.min_cmap - else: - lbound = np.min(data) - if args.max_cmap is not None: - ubound = args.max_cmap - else: - ubound = np.max(data) - - if args.log: - data[data > 0] = np.log10(data[data > 0]) - - # normalize data between 0 and 1 - data -= lbound - data = data / ubound if ubound > 0 else data - return data, lbound, ubound - - -def get_color_streamlines_from_angle(sft, args): - """Color streamlines according to their length. - - Parameters - ---------- - sft: StatefulTractogram - The tractogram that contains the list of streamlines to be colored - args: NameSpace - The colormap options. - - Returns - ------- - color: np.ndarray - An array of shape (nb_streamlines, 3) containing the RGB values of - streamlines - lbound: float - Minimal value - ubound: float - Maximal value - """ - angles = [] - for i in range(len(sft.streamlines)): - dirs = np.diff(sft.streamlines[i], axis=0) - dirs /= np.linalg.norm(dirs, axis=-1, keepdims=True) - cos_angles = np.sum(dirs[:-1, :] * dirs[1:, :], axis=1) - # Resolve numerical instability - cos_angles = np.minimum(np.maximum(-1.0, cos_angles), 1.0) - line_angles = [0.0] + list(np.arccos(cos_angles)) + [0.0] - angles.extend(line_angles) - - angles = np.rad2deg(angles) - - return clip_and_normalize_data_for_cmap(args, angles) - - -def get_color_streamlines_along_length(sft, args): - """Color streamlines according to their length. - - Parameters - ---------- - sft: StatefulTractogram - The tractogram that contains the list of streamlines to be colored - args: NameSpace - The colormap options. - - Returns - ------- - color: np.ndarray - An array of shape (nb_streamlines, 3) containing the RGB values of - streamlines - lbound: int - Minimal value - ubound: int - Maximal value - """ - positions = [] - for i in range(len(sft.streamlines)): - positions.extend(list(np.linspace(0, 1, len(sft.streamlines[i])))) - - return clip_and_normalize_data_for_cmap(args, positions) - - -def filter_tractogram_data(tractogram, streamline_ids): - """Filter tractogram according to streamline ids and keep the data - - Parameters: - ----------- - tractogram: StatefulTractogram - Tractogram containing the data to be filtered - streamline_ids: array_like - List of streamline ids the data corresponds to - - Returns: - -------- - new_tractogram: Tractogram or StatefulTractogram - Returns a new tractogram with only the selected streamlines - and data - """ - - streamline_ids = np.asarray(streamline_ids, dtype=int) - - assert np.all( - np.in1d(streamline_ids, np.arange(len(tractogram.streamlines))) - ), "Received ids outside of streamline range" - - new_streamlines = tractogram.streamlines[streamline_ids] - new_data_per_streamline = tractogram.data_per_streamline[streamline_ids] - new_data_per_point = tractogram.data_per_point[streamline_ids] - - # Could have been nice to deepcopy the tractogram modify the attributes in - # place instead of creating a new one, but tractograms cant be subsampled - # if they have data - - return StatefulTractogram.from_sft( - new_streamlines, - tractogram, - data_per_point=new_data_per_point, - data_per_streamline=new_data_per_streamline) - - -def compress_sft(sft, tol_error=0.01): - """ - Compress a stateful tractogram. Uses Dipy's compress_streamlines, but - deals with space better. - - Dipy's description: - The compression consists in merging consecutive segments that are - nearly collinear. The merging is achieved by removing the point the two - segments have in common. - - The linearization process [Presseau15] ensures that every point being - removed are within a certain margin (in mm) of the resulting streamline. - Recommendations for setting this margin can be found in [Presseau15] - (in which they called it tolerance error). - - The compression also ensures that two consecutive points won't be too far - from each other (precisely less or equal than *max_segment_length* mm). - This is a tradeoff to speed up the linearization process [Rheault15]. A - low value will result in a faster linearization but low compression, - whereas a high value will result in a slower linearization but high - compression. - - [Presseau C. et al., A new compression format for fiber tracking datasets, - NeuroImage, no 109, 73-83, 2015.] - - Parameters - ---------- - sft: StatefulTractogram - The sft to compress. - tol_error: float (optional) - Tolerance error in mm (default: 0.01). A rule of thumb is to set it - to 0.01mm for deterministic streamlines and 0.1mm for probabilitic - streamlines. - - Returns - ------- - compressed_sft: StatefulTractogram - """ - # Go to world space - orig_space = sft.space - sft.to_rasmm() - - # Compress streamlines - compressed_streamlines = compress_streamlines(sft.streamlines, - tol_error=tol_error) - if sft.data_per_point is not None and sft.data_per_point.keys(): - logging.warning("Initial StatefulTractogram contained data_per_point. " - "This information will not be carried in the final " - "tractogram.") - - compressed_sft = StatefulTractogram.from_sft( - compressed_streamlines, sft, - data_per_streamline=sft.data_per_streamline) - - # Return to original space - compressed_sft.to_space(orig_space) - - return compressed_sft - - -def cut_invalid_streamlines(sft): - """ Cut streamlines so their longest segment are within the bounding box. - This function keeps the data_per_point and data_per_streamline. - - Parameters - ---------- - sft: StatefulTractogram - The sft to remove invalid points from. - - Returns - ------- - new_sft : StatefulTractogram - New object with the invalid points removed from each streamline. - cutting_counter : int - Number of streamlines that were cut. - """ - if not len(sft): - return sft, 0 - - # Keep track of the streamlines' original space/origin - space = sft.space - origin = sft.origin - - sft.to_vox() - sft.to_corner() - - copy_sft = copy.deepcopy(sft) - epsilon = 0.001 - indices_to_remove, _ = copy_sft.remove_invalid_streamlines() - - new_streamlines = [] - new_data_per_point = {} - new_data_per_streamline = {} - for key in sft.data_per_point.keys(): - new_data_per_point[key] = [] - for key in sft.data_per_streamline.keys(): - new_data_per_streamline[key] = [] - - cutting_counter = 0 - for ind in range(len(sft.streamlines)): - # No reason to try to cut if all points are within the volume - if ind in indices_to_remove: - best_pos = [0, 0] - cur_pos = [0, 0] - for pos, point in enumerate(sft.streamlines[ind]): - if (point < epsilon).any() or \ - (point >= sft.dimensions - epsilon).any(): - cur_pos = [pos+1, pos+1] - if cur_pos[1] - cur_pos[0] > best_pos[1] - best_pos[0]: - best_pos = cur_pos - cur_pos[1] += 1 - - if not best_pos == [0, 0]: - new_streamlines.append( - sft.streamlines[ind][best_pos[0]:best_pos[1]-1]) - cutting_counter += 1 - for key in sft.data_per_streamline.keys(): - new_data_per_streamline[key].append( - sft.data_per_streamline[key][ind]) - for key in sft.data_per_point.keys(): - new_data_per_point[key].append( - sft.data_per_point[key][ind][best_pos[0]:best_pos[1]-1]) - else: - logging.warning('Streamlines entirely out of the volume.') - else: - new_streamlines.append(sft.streamlines[ind]) - for key in sft.data_per_streamline.keys(): - new_data_per_streamline[key].append( - sft.data_per_streamline[key][ind]) - for key in sft.data_per_point.keys(): - new_data_per_point[key].append(sft.data_per_point[key][ind]) - new_sft = StatefulTractogram.from_sft(new_streamlines, sft, - data_per_streamline=new_data_per_streamline, - data_per_point=new_data_per_point) - - # Move the streamlines back to the original space/origin - sft.to_space(space) - sft.to_origin(origin) - - new_sft.to_space(space) - new_sft.to_origin(origin) - - return new_sft, cutting_counter diff --git a/scilpy/viz/utils.py b/scilpy/viz/utils.py index 0278ba8a8..769b61eb1 100644 --- a/scilpy/viz/utils.py +++ b/scilpy/viz/utils.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import matplotlib.pyplot as plt +import numpy as np from matplotlib import colors @@ -33,3 +34,123 @@ def get_colormap(name): return cmap return plt.colormaps.get_cmap(name) + + +def prepare_colorbar_figure(cmap, lbound, ubound, nb_values=255, nb_ticks=10, + horizontal=False, log=False,): + """ + Prepares a matplotlib figure of a colorbar. + + Parameters + ---------- + cmap: plt colormap + Ex, result from get_colormap(). + lbound: float + Lower bound + ubound: float + Upper bound + nb_values: int + Number of values. The cmap will be linearly divided between lbound and + ubound into nb_values values. Default: 255. + nb_ticks: int + The ticks on the colorbar can be set differently than the nb_values. + Default: 10. + horizontal: bool + If true, plot a horizontal cmap. + log: bool + If true, apply a logarithm scaling. + + Returns + ------- + fig: plt figure + The plt figure. + """ + gradient = cmap(np.linspace(0, 1, ))[:, 0:3] + + # TODO: Is there a better way to draw a gradient-filled rectangle? + width = int(nb_values * 0.1) + gradient = np.tile(gradient, (width, 1, 1)) + if not horizontal: + gradient = np.swapaxes(gradient, 0, 1) + + fig, ax = plt.subplots(1, 1) + ax.imshow(gradient, origin='lower') + + ticks_labels = ['{0:.3f}'.format(i) for i in + np.linspace(lbound, ubound, nb_ticks)] + + if log: + ticks_labels = ['log(' + t + ')' for t in ticks_labels] + + ticks = np.linspace(0, nb_values - 1, nb_ticks) + if not horizontal: + ax.set_yticks(ticks) + ax.set_yticklabels(ticks_labels) + ax.set_xticks([]) + else: + ax.set_xticks(ticks) + ax.set_xticklabels(ticks_labels) + ax.set_yticks([]) + return fig + + +def clip_and_normalize_data_for_cmap( + data, clip_outliers=False, min_range=None, max_range=None, + min_cmap=None, max_cmap=None, log=False, LUT=None): + """ + Normalizes data between 0 and 1 for an easier management with colormaps. + The real lower bound and upperbound are returned. + + Data can be clipped to (min_range, max_range) before normalization. + Alternatively, data can be kept as is, + + Parameters + ---------- + data: np.ndarray + The data. + clip_outliers: bool + If True, clips the data to the lowest and highest 5% quantile before + normalizing and before any other clipping. + min_range: float + Data values below min_range will be clipped. + max_range: float + Data values above max_range will be clipped. + min_cmap: float + Minimum value of the colormap. Most useful when min_range and max_range + are not set; to fix the colormap range without modifying the data. + max_cmap: float + Maximum value of the colormap. Idem. + log: bool + If True, apply a logarithmic scale to the data. + LUT: np.ndarray + If set, replaces the data values by the Look-Up Table values. In order, + the first value of the LUT is set everywhere where data==1, etc. + """ + if LUT is not None: + for i, val in enumerate(LUT): + data[data == i+1] = val + + # Clipping + if clip_outliers: + data = np.clip(data, np.quantile(data, 0.05), + np.quantile(data, 0.95)) + if min_range is not None or max_range is not None: + data = np.clip(data, min_range, max_range) + + # get data values range + if min_cmap is not None: + lbound = min_cmap + else: + lbound = np.min(data) + if max_cmap is not None: + ubound = max_cmap + else: + ubound = np.max(data) + + if log: + data[data > 0] = np.log10(data[data > 0]) + + # normalize data between 0 and 1 + data -= lbound + data = data / ubound if ubound > 0 else data + return data, lbound, ubound diff --git a/scripts/scil_bundle_compute_centroid.py b/scripts/scil_bundle_compute_centroid.py index 97af1869b..d4b726efa 100755 --- a/scripts/scil_bundle_compute_centroid.py +++ b/scripts/scil_bundle_compute_centroid.py @@ -19,7 +19,7 @@ assert_outputs_exist, add_verbose_arg, add_reference_arg) -from scilpy.tractanalysis.features import get_streamlines_centroid +from scilpy.tractanalysis.bundle_operations import get_streamlines_centroid def _build_arg_parser(): diff --git a/scripts/scil_bundle_label_map.py b/scripts/scil_bundle_label_map.py index b803fb7a5..9e35c0d1a 100755 --- a/scripts/scil_bundle_label_map.py +++ b/scripts/scil_bundle_label_map.py @@ -36,13 +36,13 @@ add_verbose_arg, assert_inputs_exist, assert_output_dirs_exist_and_empty) +from scilpy.tractanalysis.bundle_operations import uniformize_bundle_sft from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractanalysis.distance_to_centroid import min_dist_to_centroid from scilpy.tractograms.streamline_and_mask_operations import \ cut_outside_of_mask_streamlines from scilpy.tractograms.streamline_operations import \ resample_streamlines_num_points -from scilpy.utils.streamlines import uniformize_bundle_sft from scilpy.viz.utils import get_colormap diff --git a/scripts/scil_bundle_shape_measures.py b/scripts/scil_bundle_shape_measures.py index e6498751e..d07bc147e 100755 --- a/scripts/scil_bundle_shape_measures.py +++ b/scripts/scil_bundle_shape_measures.py @@ -51,15 +51,13 @@ add_reference_arg, assert_inputs_exist, assert_outputs_exist, link_bundles_and_reference, validate_nbr_processes, assert_headers_compatible) - +from scilpy.tractanalysis.bundle_operations import uniformize_bundle_sft from scilpy.tractanalysis.reproducibility_measures \ import (approximate_surface_node, compute_fractal_dimension) from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map - from scilpy.tractograms.streamline_and_mask_operations import \ get_endpoints_density_map, get_head_tail_density_maps -from scilpy.utils.streamlines import uniformize_bundle_sft EPILOG = """ References: diff --git a/scripts/scil_outlier_rejection.py b/scripts/scil_outlier_rejection.py index c2de9ba8d..fc123fed2 100755 --- a/scripts/scil_outlier_rejection.py +++ b/scripts/scil_outlier_rejection.py @@ -22,7 +22,7 @@ assert_inputs_exist, assert_outputs_exist, check_tracts_same_format) -from scilpy.tractanalysis.features import remove_outliers +from scilpy.tractanalysis.bundle_operations import remove_outliers def _build_arg_parser(): diff --git a/scripts/scil_tractogram_assign_custom_color.py b/scripts/scil_tractogram_assign_custom_color.py index ce8121407..5fb90806d 100755 --- a/scripts/scil_tractogram_assign_custom_color.py +++ b/scripts/scil_tractogram_assign_custom_color.py @@ -3,37 +3,45 @@ """ The script uses scalars from an anatomy, data_per_point or data_per_streamline -(e.g commit_weights) to visualize them on the streamlines. -Saves the RGB values in the data_per_point (color_x, color_y, color_z). +(e.g. commit_weights) to visualize them on the streamlines. +Saves the RGB values in the data_per_point 'color' with 3 values per point: +(color_x, color_y, color_z). If called with .tck, the output will always be .trk, because data_per_point has no equivalent in tck file. -The usage of --use_dps, --use_dpp and --from_anatomy is more complex. It maps -the raw values from these sources to RGB using a colormap. - --use_dps: total nbr of streamlines of the tractogram = len(streamlines) - --use_dpp: total nbr of points of the tractogram = len(streamlines._data) - +If used with a visualization software like MI-Brain +(https://github.com/imeka/mi-brain), the 'color' dps is applied by default at +loading time. + +COLORING METHOD +This script maps the raw values from these sources to RGB using a colormap. + --use_dpp: The data from each point is converted to a color. + --use_dps: The same color is applied to all points of the streamline. + --from_anatomy: The voxel's color is used for the points of the streamlines + crossing it. See also scil_tractogram_project_map_to_streamlines.py. You + can have more options to project maps to dpp, and then use --use_dpp here. + --along_profile: The data used here is each point's position in the + streamline. To have nice results, you should first uniformize head/tail. + See scil_tractogram_uniformize_endpoints.py. + --local_angle. + +COLORING OPTIONS A minimum and a maximum range can be provided to clip values. If the range of values is too large for intuitive visualization, a log transform can be applied. If the data provided from --use_dps, --use_dpp and --from_anatomy are integer labels, they can be mapped using a LookUp Table (--LUT). -The file provided as a LUT should be either .txt or .npy and if the -size is N=20, then the data provided should be between 1-20. - -Example: Use --from_anatomy with a voxel labels map (values from 1-20) with a -text file containing 20 p-values to map p-values to the bundle for -visualisation. +The file provided as a LUT should be either .txt or .npy and if the size is +N=20, then the data provided should be between 1-20. A custom colormap can be provided using --colormap. It should be a string containing a colormap name OR multiple Matplotlib named colors separated by -. The colormap used for mapping values to colors can be saved to a png/jpg image using the --out_colorbar option. -The script can also be used to color streamlines according to their length -using the --along_profile option. The streamlines must be uniformized. +See also: scil_tractogram_assign_uniform_color.py, for simplified options. Formerly: scil_assign_custom_color_to_tractogram.py """ @@ -54,12 +62,10 @@ add_reference_arg, add_verbose_arg, load_matrix_in_any_format) -from scilpy.utils.streamlines import get_color_streamlines_along_length, \ - get_color_streamlines_from_angle, clip_and_normalize_data_for_cmap -from scilpy.viz.utils import get_colormap - -COLORBAR_NB_VALUES = 255 -NB_TICKS = 10 +from scilpy.tractograms.dps_and_dpp_management import add_data_as_color_dpp +from scilpy.tractograms.streamline_operations import (get_values_along_length, + get_angles) +from scilpy.viz.utils import get_colormap, prepare_colorbar_figure def _build_arg_parser(): @@ -72,44 +78,46 @@ def _build_arg_parser(): p.add_argument('out_tractogram', help='Output tractogram (.trk or .tck).') - cbar_g = p.add_argument_group('Colorbar Options') + cbar_g = p.add_argument_group('Colorbar options') cbar_g.add_argument('--out_colorbar', help='Optional output colorbar (.png, .jpg or any ' - 'format supported by matplotlib).') + 'format \nsupported by matplotlib).') cbar_g.add_argument('--show_colorbar', action='store_true', help="Will show the colorbar. Must be used with " - "--out_colorbar to be effective.") + "--out_colorbar \nto be effective.") cbar_g.add_argument('--horizontal_cbar', action='store_true', help='Draw horizontal colorbar (vertical by default).') - g1 = p.add_argument_group(title='Coloring Methods') - p1 = g1.add_mutually_exclusive_group() + g1 = p.add_argument_group(title='Coloring method') + p1 = g1.add_mutually_exclusive_group(required=True) p1.add_argument('--use_dps', metavar='DPS_KEY', - help='Use the data_per_streamline (scalar) for coloring,\n' - 'linear scaling from min-max, e.g. commit_weights.') + help='Use the data_per_streamline (scalar) for coloring.') p1.add_argument('--use_dpp', metavar='DPP_KEY', - help='Use the data_per_point (scalar) for coloring,\n' - 'linear scaling from min-max.') - p1.add_argument('--load_dps', metavar='DPS_KEY', + help='Use the data_per_point (scalar) for coloring.') + p1.add_argument('--load_dps', metavar='DPS_FILE', help='Load data per streamline (scalar) for coloring') - p1.add_argument('--load_dpp', metavar='DPP_KEY', + p1.add_argument('--load_dpp', metavar='DPP_FILE', help='Load data per point (scalar) for coloring') p1.add_argument('--from_anatomy', metavar='FILE', help='Use the voxel data for coloring,\n' 'linear scaling from minmax.') p1.add_argument('--along_profile', action='store_true', help='Color streamlines according to each point position' - 'along its length.\nMust be uniformized head/tail.') + 'along its length.') p1.add_argument('--local_angle', action='store_true', help="Color streamlines according to the angle between " "each segment (in degree). \nAngles at first and " "last points are set to 0.") - g2 = p.add_argument_group(title='Coloring Options') + g2 = p.add_argument_group(title='Coloring options') g2.add_argument('--colormap', default='jet', help='Select the colormap for colored trk (dps/dpp) ' '[%(default)s].\nUse two Matplotlib named color separeted ' 'by a - to create your own colormap.') + g2.add_argument('--clip_outliers', action='store_true', + help="If set, we will clip the outliers (first and last " + "5%% quantile). Strongly suggested if your data " + "comes from COMMIT!") g2.add_argument('--min_range', type=float, help='Set the minimum value when using dps/dpp/anatomy.') g2.add_argument('--max_range', type=float, @@ -134,44 +142,12 @@ def _build_arg_parser(): return p -def save_colorbar(cmap, lbound, ubound, args): - gradient = cmap(np.linspace(0, 1, COLORBAR_NB_VALUES))[:, 0:3] - - # TODO: Is there a better way to draw a gradient-filled rectangle? - width = int(COLORBAR_NB_VALUES * 0.1) - gradient = np.tile(gradient, (width, 1, 1)) - if not args.horizontal_cbar: - gradient = np.swapaxes(gradient, 0, 1) - - _, ax = plt.subplots(1, 1) - ax.imshow(gradient, origin='lower') - - ticks_labels = ['{0:.3f}'.format(i) for i in - np.linspace(lbound, ubound, NB_TICKS)] - - if args.log: - ticks_labels = ['log(' + t + ')' for t in ticks_labels] - - ticks = np.linspace(0, COLORBAR_NB_VALUES - 1, NB_TICKS) - if not args.horizontal_cbar: - ax.set_yticks(ticks) - ax.set_yticklabels(ticks_labels) - ax.set_xticks([]) - else: - ax.set_xticks(ticks) - ax.set_xticklabels(ticks_labels) - ax.set_yticks([]) - - plt.savefig(args.out_colorbar, bbox_inches='tight') - if args.show_colorbar: - plt.show() - - def main(): parser = _build_arg_parser() args = parser.parse_args() logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + # Verifications assert_inputs_exist(parser, args.in_tractogram, args.reference) assert_outputs_exist(parser, args, args.out_tractogram, optional=args.out_colorbar) @@ -180,74 +156,72 @@ def main(): logging.warning('Colorbar output not supplied. Ignoring ' '--horizontal_cbar.') - sft = load_tractogram_with_reference(parser, args, args.in_tractogram) + if (args.use_dps is not None and + args.use_dps in ['commit_weights', 'commit2_weights'] and not + args.clip_outliers): + logging.warning("You seem to be using commit weights. They typically " + "have outliers values. We suggest using " + "--clip_outliers.") - if args.LUT: - LUT = load_matrix_in_any_format(args.LUT) - if np.any(sft.streamlines._lengths < len(LUT)): - logging.warning('Some streamlines have fewer point than the size ' - 'of the provided LUT.\nConsider using ' - 'scil_tractogram_resample_nb_points.py') + # Loading + sft = load_tractogram_with_reference(parser, args, args.in_tractogram) + LUT = load_matrix_in_any_format(args.LUT) if args.LUT else None cmap = get_colormap(args.colormap) - if args.use_dps or args.use_dpp or args.load_dps or args.load_dpp: - if args.use_dps: - data = np.squeeze(sft.data_per_streamline[args.use_dps]) - # I believe it works well for gaussian distribution, but - # COMMIT has very weird outliers values - if args.use_dps == 'commit_weights' \ - or args.use_dps == 'commit2_weights': - data = np.clip(data, np.quantile(data, 0.05), - np.quantile(data, 0.95)) - elif args.use_dpp: - tmp = [np.squeeze(sft.data_per_point[args.use_dpp][s]) for s in - range(len(sft))] - data = np.hstack(tmp) - elif args.load_dps: - data = np.squeeze(load_matrix_in_any_format(args.load_dps)) - if len(data) != len(sft): - parser.error('Wrong dps size!') - else: # args.load_dpp - data = np.squeeze(load_matrix_in_any_format(args.load_dpp)) - if len(data) != len(sft.streamlines._data): - parser.error('Wrong dpp size!') - values, lbound, ubound = clip_and_normalize_data_for_cmap(args, data) - elif args.from_anatomy: - data = nib.load(args.from_anatomy).get_fdata() - data, lbound, ubound = clip_and_normalize_data_for_cmap(args, data) + # Loading data. Depending on the type of loading, format data now to a 1D + # array (one value per point or per streamline) + if args.use_dps: + if args.use_dps not in sft.data_per_streamline.keys(): + parser.error("DPS key {} not found in the loaded tractogram's " + "data_per_streamline.".format(args.use_dps)) + data = np.squeeze(sft.data_per_streamline[args.use_dps]) + elif args.use_dpp: + if args.use_dpp not in sft.data_per_point.keys(): + parser.error("DPP key {} not found in the loaded tractogram's " + "data_per_point.".format(args.use_dpp)) + data = np.hstack( + [np.squeeze(sft.data_per_point[args.use_dpp][s]) for s in + range(len(sft))]) + elif args.load_dps: + data = np.squeeze(load_matrix_in_any_format(args.load_dps)) + if len(data) != len(sft): + parser.error('Wrong dps size! Expected one value per streamline ' + '({}) but found {} values.' + .format(len(sft), len(data))) + elif args.load_dpp or args.from_anatomy: sft.to_vox() - values = map_coordinates(data, sft.streamlines._data.T, order=0) + concat_points = np.vstack(sft.streamlines).T + expected_shape = len(concat_points) sft.to_rasmm() + if args.load_dpp: + data = np.squeeze(load_matrix_in_any_format(args.load_dpp)) + if len(data) != expected_shape: + parser.error('Wrong dpp size! Expected a total of {} points, ' + 'but got {}'.format(expected_shape, len(data))) + else: # args.from_anatomy: + data = nib.load(args.from_anatomy).get_fdata() + data = map_coordinates(data, concat_points, order=0) elif args.along_profile: - values, lbound, ubound = get_color_streamlines_along_length( - sft, args) - elif args.local_angle: - values, lbound, ubound = get_color_streamlines_from_angle( - sft, args) - else: - parser.error('No coloring method specified.') - - color = cmap(values)[:, 0:3] * 255 - if len(color) == len(sft): - tmp = [np.tile([color[i][0], color[i][1], color[i][2]], - (len(sft.streamlines[i]), 1)) - for i in range(len(sft.streamlines))] - sft.data_per_point['color'] = tmp - elif len(color) == len(sft.streamlines._data): - sft.data_per_point['color'] = sft.streamlines - sft.data_per_point['color']._data = color - else: - raise ValueError("Error in the code... Colors do not have the right " - "shape. (this is our fault). Expecting either one" - "color per streamline ({}) or one per point ({}) but " - "got {}.".format(len(sft), len(sft.streamlines._data), - len(color))) + data = get_values_along_length(sft) + else: # args.local_angle: + data = get_angles(sft) + + # Processing + sft, lbound, ubound = add_data_as_color_dpp( + sft, cmap, data, args.clip_outliers, args.min_range, args.max_range, + args.min_cmap, args.max_cmap, args.log, LUT) + + # Saving save_tractogram(sft, args.out_tractogram) - # output colormap if args.out_colorbar: - save_colorbar(cmap, lbound, ubound, args) + _ = prepare_colorbar_figure( + cmap, lbound, ubound, + horizontal=args.horizontal_cbar, log=args.log) + plt.savefig(args.out_colorbar, bbox_inches='tight') + if args.show_colorbar: + plt.show() if __name__ == '__main__': diff --git a/scripts/scil_tractogram_detect_loops.py b/scripts/scil_tractogram_detect_loops.py index 921fc696e..3ed25f655 100755 --- a/scripts/scil_tractogram_detect_loops.py +++ b/scripts/scil_tractogram_detect_loops.py @@ -38,8 +38,9 @@ assert_outputs_exist, check_tracts_same_format, validate_nbr_processes) -from scilpy.utils.streamlines import filter_tractogram_data -from scilpy.tractanalysis.features import remove_loops_and_sharp_turns +from scilpy.tractograms.tractogram_operations import filter_tractogram_data +from scilpy.tractograms.streamline_operations import \ + remove_loops_and_sharp_turns def _build_arg_parser(): diff --git a/scripts/scil_tractogram_extract_ushape.py b/scripts/scil_tractogram_extract_ushape.py index be98a694b..5dc8a28b5 100755 --- a/scripts/scil_tractogram_extract_ushape.py +++ b/scripts/scil_tractogram_extract_ushape.py @@ -28,7 +28,7 @@ assert_inputs_exist, assert_outputs_exist, check_tracts_same_format) -from scilpy.tractanalysis.features import detect_ushape +from scilpy.tractanalysis.bundle_operations import detect_ushape def _build_arg_parser(): diff --git a/scripts/scil_tractogram_filter_by_anatomy.py b/scripts/scil_tractogram_filter_by_anatomy.py index f1a6372d0..9be85afc7 100755 --- a/scripts/scil_tractogram_filter_by_anatomy.py +++ b/scripts/scil_tractogram_filter_by_anatomy.py @@ -68,12 +68,10 @@ validate_nbr_processes, assert_headers_compatible) from scilpy.image.labels import get_data_as_labels from scilpy.segment.streamlines import filter_grid_roi -from scilpy.tractanalysis.features import remove_loops_and_sharp_turns from scilpy.tractograms.streamline_operations import \ - filter_streamlines_by_length + filter_streamlines_by_length, remove_loops_and_sharp_turns from scilpy.tractograms.tractogram_operations import \ - perform_tractogram_operation_on_sft -from scilpy.utils.streamlines import filter_tractogram_data + perform_tractogram_operation_on_sft, filter_tractogram_data EPILOG = """ diff --git a/scripts/scil_tractogram_fix_trk.py b/scripts/scil_tractogram_fix_trk.py index c4302bbb8..6f924d031 100755 --- a/scripts/scil_tractogram_fix_trk.py +++ b/scripts/scil_tractogram_fix_trk.py @@ -63,8 +63,8 @@ assert_outputs_exist) from scilpy.tractograms.tractogram_operations import (flip_sft, transform_warp_sft, - _get_axis_flip_vector) -from scilpy.utils.streamlines import cut_invalid_streamlines + get_axis_flip_vector) +from scilpy.tractograms.streamline_operations import cut_invalid_streamlines softwares = ['dsi_studio', 'startrack'] @@ -149,7 +149,7 @@ def main(): # Startrack flips the TRK flip_axis = ['x'] new_sft.to_vox() - new_sft.streamlines._data -= _get_axis_flip_vector(flip_axis) + new_sft.streamlines._data -= get_axis_flip_vector(flip_axis) new_sft = flip_sft(new_sft, flip_axis) new_sft.to_rasmm() @@ -169,7 +169,7 @@ def main(): sft_fix = StatefulTractogram(sft.streamlines, args.in_dsi_fa, Space.VOXMM) sft_fix.to_vox() - sft_fix.streamlines._data -= _get_axis_flip_vector(flip_axis) + sft_fix.streamlines._data -= get_axis_flip_vector(flip_axis) sft_flip = flip_sft(sft_fix, flip_axis) diff --git a/scripts/scil_tractogram_remove_invalid.py b/scripts/scil_tractogram_remove_invalid.py index db72c2f70..336121301 100755 --- a/scripts/scil_tractogram_remove_invalid.py +++ b/scripts/scil_tractogram_remove_invalid.py @@ -23,7 +23,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, add_reference_arg, assert_inputs_exist, assert_outputs_exist) -from scilpy.utils.streamlines import cut_invalid_streamlines +from scilpy.tractograms.streamline_operations import cut_invalid_streamlines def _build_arg_parser(): diff --git a/scripts/scil_tractogram_segment_bundles_for_connectivity.py b/scripts/scil_tractogram_segment_bundles_for_connectivity.py index 11755ad78..ed2e76615 100755 --- a/scripts/scil_tractogram_segment_bundles_for_connectivity.py +++ b/scripts/scil_tractogram_segment_bundles_for_connectivity.py @@ -57,12 +57,12 @@ assert_outputs_exist, assert_output_dirs_exist_and_empty, validate_nbr_processes) -from scilpy.tractanalysis.features import (remove_outliers, - remove_loops_and_sharp_turns) +from scilpy.tractanalysis.bundle_operations import remove_outliers from scilpy.tractanalysis.tools import (compute_connectivity, extract_longest_segments_from_profile) from scilpy.tractograms.uncompress import uncompress - +from scilpy.tractograms.streamline_operations import \ + remove_loops_and_sharp_turns from scilpy.tractograms.streamline_and_mask_operations import \ compute_streamline_segment diff --git a/scripts/scil_tractogram_uniformize_endpoints.py b/scripts/scil_tractogram_uniformize_endpoints.py index fcacfc82c..d148dee0c 100755 --- a/scripts/scil_tractogram_uniformize_endpoints.py +++ b/scripts/scil_tractogram_uniformize_endpoints.py @@ -32,8 +32,8 @@ add_verbose_arg, assert_outputs_exist, assert_inputs_exist, assert_headers_compatible) -from scilpy.utils.streamlines import (uniformize_bundle_sft, - uniformize_bundle_sft_using_mask) +from scilpy.tractanalysis.bundle_operations import \ + uniformize_bundle_sft, uniformize_bundle_sft_using_mask def _build_arg_parser(): diff --git a/scripts/tests/test_tractogram_assign_custom_color.py b/scripts/tests/test_tractogram_assign_custom_color.py index 5e71a415c..1b9d0df58 100644 --- a/scripts/tests/test_tractogram_assign_custom_color.py +++ b/scripts/tests/test_tractogram_assign_custom_color.py @@ -11,6 +11,11 @@ fetch_data(get_testing_files_dict(), keys=['tractometry.zip']) tmp_dir = tempfile.TemporaryDirectory() +in_bundle = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM.trk') + +# toDo for more coverage: add a LUT in test data, use option --LUT +# toDo. get a dpp / dps file to load, use options --load_dpp, --load_dps + def test_help_option(script_runner): ret = script_runner.run('scil_tractogram_assign_custom_color.py', @@ -18,13 +23,28 @@ def test_help_option(script_runner): assert ret.success -def test_execution_tractometry(script_runner): +def test_execution_from_anat(script_runner): os.chdir(os.path.expanduser(tmp_dir.name)) - in_bundle = os.path.join(SCILPY_HOME, 'tractometry', - 'IFGWM.trk') in_anat = os.path.join(SCILPY_HOME, 'tractometry', 'IFGWM_labels_map.nii.gz') + ret = script_runner.run('scil_tractogram_assign_custom_color.py', in_bundle, 'colored.trk', '--from_anatomy', - in_anat) + in_anat, '--out_colorbar', 'test_colorbar.png') + assert ret.success + + +def test_execution_along_profile(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + + ret = script_runner.run('scil_tractogram_assign_custom_color.py', + in_bundle, 'colored2.trk', '--along_profile') + assert ret.success + + +def test_execution_from_angle(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + + ret = script_runner.run('scil_tractogram_assign_custom_color.py', + in_bundle, 'colored3.trk', '--local_angle') assert ret.success