Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: remove sft._data usage part 1 - tractogram coloring scripts + more #1105

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 80 additions & 59 deletions scilpy/tractograms/dps_and_dpp_management.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,105 @@
# -*- coding: utf-8 -*-
import numpy as np

from scilpy.viz.color import clip_and_normalize_data_for_cmap
from nibabel.streamlines import ArraySequence


def add_data_as_color_dpp(sft, cmap, data, clip_outliers=False, min_range=None,
max_range=None, min_cmap=None, max_cmap=None,
log=False, LUT=None):
def get_data_as_arraysequence(data, ref_sft):
""" Get data in the same shape as a reference StatefulTractogram's
streamlines, so it can be used to set data_per_point or
data_per_streamline. The data may represent one value per streamline or one
value per point. The function will return an ArraySequence with the same
shape as the streamlines.

Parameters
----------
data: np.ndarray
The data to convert to ArraySequence.
ref_sft: StatefulTractogram
The reference StatefulTractogram containing the streamlines.

Returns
-------
data_as_arraysequence: ArraySequence
The data as an ArraySequence.
"""
Normalizes data between 0 and 1 for an easier management with colormaps.
The real lower bound and upperbound are returned.
# Check if data has the right shape, either one value per streamline or one
# value per point.
if data.shape[0] == ref_sft._get_streamline_count():
# Two consective if statements to handle both 1D and 2D arrays
# and turn them into lists of lists of lists.
# Check if the data is a vector or a scalar.
if len(data.shape) == 1:
data = data[:, None]
# ArraySequence expects a list of lists of lists, so we need to add
# an extra dimension.
if len(data.shape) == 2:
data = data[:, None, :]
data_as_arraysequence = ArraySequence(data)

elif data.shape[0] == ref_sft._get_point_count():
# Split the data into a list of arrays, one per streamline.
# np.split takes the indices at which to split the array, so use
# np.cumsum to get the indices of the end of each streamline.
data_split = np.split(
data, np.cumsum(ref_sft.streamlines._lengths)[:-1])
# Create an ArraySequence from the list of arrays.
data_as_arraysequence = ArraySequence(data_split)
else:
raise ValueError("Data has the wrong shape. Expecting either one value"
" per streamline ({}) or one per point ({}) but got "
"{}."
.format(len(ref_sft), len(ref_sft.streamlines._data),
data.shape[0]))
return data_as_arraysequence

Data can be clipped to (min_range, max_range) before normalization.
Alternatively, data can be kept as is, but the colormap be fixed to
(min_cmap, max_cmap).

def add_data_as_color_dpp(sft, color):
"""
Ensures the color data is in the right shape and adds it to the
data_per_point of the StatefulTractogram. The color data can be either one
color per streamline or one color per point. The function will return the
StatefulTractogram with the color data added.

Parameters
----------
sft: StatefulTractogram
The tractogram
cmap: plt colormap
The colormap. Ex, see scilpy.viz.utils.get_colormap().
data: np.ndarray or list[list] or list[np.ndarray]
The data to convert to color. Expecting one value per point to add as
dpp. If instead data has one value per streamline, setting the same
color to all points of the streamline (as dpp).
Either a vector numpy array (all streamlines concatenated), or a list
of arrays per streamline.
clip_outliers: bool
See description of the following parameters in
clip_and_normalize_data_for_cmap.
min_range: float
Data values below min_range will be clipped.
max_range: float
Data values above max_range will be clipped.
min_cmap: float
Minimum value of the colormap. Most useful when min_range and max_range
are not set; to fix the colormap range without modifying the data.
max_cmap: float
Maximum value of the colormap. Idem.
log: bool
If True, apply a logarithmic scale to the data.
LUT: np.ndarray
If set, replaces the data values by the Look-Up Table values. In order,
the first value of the LUT is set everywhere where data==1, etc.
color: ArraySequence
The color data.

Returns
-------
sft: StatefulTractogram
The tractogram, with dpp 'color' added.
lbound: float
The lower bound of the associated colormap.
ubound: float
The upper bound of the associated colormap.
"""
# If data is a list of lists, merge.
if isinstance(data[0], list) or isinstance(data[0], np.ndarray):
data = np.hstack(data)

values, lbound, ubound = clip_and_normalize_data_for_cmap(
data, clip_outliers, min_range, max_range,
min_cmap, max_cmap, log, LUT)

# Important: values are in float after clip_and_normalize.
color = np.asarray(cmap(values)[:, 0:3]) * 255
if len(color) == len(sft):
tmp = [np.tile([color[i][0], color[i][1], color[i][2]],

if len(color) == sft._get_streamline_count():
if color.common_shape != (3,):
raise ValueError("Colors do not have the right shape. Expecting "
"RBG values, but got values of shape {}.".format(
color.common_shape))

tmp = [np.tile([color[i][0][0], color[i][0][1], color[i][0][2]],
(len(sft.streamlines[i]), 1))
for i in range(len(sft.streamlines))]
sft.data_per_point['color'] = tmp
elif len(color) == len(sft.streamlines._data):
sft.data_per_point['color'] = sft.streamlines
sft.data_per_point['color']._data = color

elif len(color) == sft._get_point_count():

if color.common_shape != (3,):
raise ValueError("Colors do not have the right shape. Expecting "
"RBG values, but got values of shape {}.".format(
color.common_shape))

sft.data_per_point['color'] = color
else:
raise ValueError("Error in the code... Colors do not have the right "
"shape. Expecting either one color per streamline "
"({}) or one per point ({}) but got {}."
.format(len(sft), len(sft.streamlines._data),
len(color)))
return sft, lbound, ubound
raise ValueError("Colors do not have the right shape. Expecting either"
" one color per streamline ({}) or one per point ({})"
" but got {}.".format(sft._get_streamline_count(),
sft._get_point_count(),
color.total_nb_rows))
return sft


def convert_dps_to_dpp(sft, keys, overwrite=False):
Expand Down
118 changes: 94 additions & 24 deletions scilpy/tractograms/tests/test_dps_and_dpp_management.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# -*- coding: utf-8 -*-
import nibabel as nib
import numpy as np
import pytest

from dipy.io.stateful_tractogram import StatefulTractogram, Space, Origin

from scilpy.image.volume_space_management import DataVolume
from scilpy.tests.utils import nan_array_equal
from scilpy.tractograms.dps_and_dpp_management import (
get_data_as_arraysequence,
add_data_as_color_dpp, convert_dps_to_dpp, project_map_to_streamlines,
project_dpp_to_map, perform_operation_on_dpp, perform_operation_dpp_to_dps,
perform_correlation_on_endpoints)
Expand All @@ -27,45 +30,112 @@ def _get_small_sft():
return fake_sft


def test_add_data_as_color_dpp():
lut = get_lookup_table('viridis')
def test_get_data_as_arraysequence_dpp():
fake_sft = _get_small_sft()

some_data = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5])

# Test 1: One value per point.
array_seq = get_data_as_arraysequence(some_data, fake_sft)

assert fake_sft._get_point_count() == array_seq.total_nb_rows


def test_get_data_as_arraysequence_dps():
fake_sft = _get_small_sft()

some_data = np.asarray([2, 20])

# Test: One value per streamline.
array_seq = get_data_as_arraysequence(some_data, fake_sft)
assert fake_sft._get_streamline_count() == array_seq.total_nb_rows


def test_get_data_as_arraysequence_dps_2D():
fake_sft = _get_small_sft()

some_data = np.asarray([[2], [20]])

# Test: One value per streamline.
array_seq = get_data_as_arraysequence(some_data, fake_sft)
assert fake_sft._get_streamline_count() == array_seq.total_nb_rows


def test_get_data_as_arraysequence_error():
fake_sft = _get_small_sft()

some_data = np.asarray([2, 20, 200, 0.1])

# Test: Too many values per streamline, not enough per point.
with pytest.raises(ValueError):
_ = get_data_as_arraysequence(some_data, fake_sft)


# Important. cmap(1) != cmap(1.0)
lowest_color = np.asarray(lut(0.0)[0:3]) * 255
highest_color = np.asarray(lut(1.0)[0:3]) * 255
def test_add_data_as_dpp_1_per_point():

fake_sft = _get_small_sft()
cmap = get_lookup_table('jet')

# Not testing the clipping options. Will be tested through viz.utils tests

# Test 1: One value per point.
# Lowest cmap color should be first point of second streamline.
AntoineTheb marked this conversation as resolved.
Show resolved Hide resolved
some_data = [[2, 20, 200], [0.1, 0.3, 22, 5]]
colored_sft, lbound, ubound = add_data_as_color_dpp(
fake_sft, lut, some_data)
values = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5])
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)

array_seq = get_data_as_arraysequence(color, fake_sft)
colored_sft = add_data_as_color_dpp(
fake_sft, array_seq)
assert len(colored_sft.data_per_streamline.keys()) == 0
assert list(colored_sft.data_per_point.keys()) == ['color']
assert lbound == 0.1
assert ubound == 200
assert np.array_equal(colored_sft.data_per_point['color'][1][0, :],
lowest_color)
assert np.array_equal(colored_sft.data_per_point['color'][0][2, :],
highest_color)


def test_add_data_as_dpp_1_per_streamline():

fake_sft = _get_small_sft()
cmap = get_lookup_table('jet')

# Test 2: One value per streamline
# Lowest cmap color should be every point in first streamline
AntoineTheb marked this conversation as resolved.
Show resolved Hide resolved
some_data = np.asarray([4, 5])
colored_sft, lbound, ubound = add_data_as_color_dpp(
fake_sft, lut, some_data)
values = np.asarray([4, 5])
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)
array_seq = get_data_as_arraysequence(color, fake_sft)

colored_sft = add_data_as_color_dpp(
fake_sft, array_seq)

assert len(colored_sft.data_per_streamline.keys()) == 0
assert list(colored_sft.data_per_point.keys()) == ['color']
assert lbound == 4
assert ubound == 5
# Lowest cmap color should be first point of second streamline.
# Same value for all points.
colors_first_line = colored_sft.data_per_point['color'][0]
assert np.array_equal(colors_first_line[0, :], lowest_color)
assert np.all(colors_first_line[1:, :] == colors_first_line[0, :])

AntoineTheb marked this conversation as resolved.
Show resolved Hide resolved

def test_add_data_as_color_error_common_shape():

fake_sft = _get_small_sft()

# Test: One value per streamline
# Should fail because the values aren't RGB values
values = np.asarray([4, 5])
array_seq = get_data_as_arraysequence(values, fake_sft)

with pytest.raises(ValueError):
_ = add_data_as_color_dpp(
fake_sft, array_seq)


def test_add_data_as_color_error_number():

fake_sft = _get_small_sft()
cmap = get_lookup_table('jet')

# Test: One value per streamline
# Should fail because the values aren't RGB values
values = np.asarray([2, 20, 200, 0.1, 0.3, 22, 5])
array_seq = get_data_as_arraysequence(values, fake_sft)
color = (np.asarray(cmap(values)[:, 0:3]) * 255).astype(np.uint8)
color = color[:-2] # Remove last streamline colors
with pytest.raises(ValueError):
_ = add_data_as_color_dpp(
fake_sft, array_seq)


def test_convert_dps_to_dpp():
Expand Down
11 changes: 6 additions & 5 deletions scilpy/viz/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_lookup_table(name):
name_list = name.split('-')
colors_list = [mcolors.to_rgba(color)[0:3] for color in name_list]
cmap = mcolors.LinearSegmentedColormap.from_list('CustomCmap',
colors_list)
colors_list)
return cmap

return plt.colormaps.get_cmap(name)
Expand Down Expand Up @@ -283,10 +283,10 @@ def prepare_colorbar_figure(cmap, lbound, ubound, nb_values=255, nb_ticks=10,
return fig


def ambiant_occlusion(sft, colors, factor=4):
def ambient_occlusion(sft, colors, factor=4):
"""
Apply ambiant occlusion to a set of colors based on point density
around each points.
around each points.
AntoineTheb marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
Expand All @@ -296,14 +296,14 @@ def ambiant_occlusion(sft, colors, factor=4):
The original colors to modify.
factor : float
The factor of occlusion (how density will affect the saturation).

Returns
-------
np.ndarray
The modified colors.
"""

pts = sft.streamlines._data
pts = sft.streamlines.get_data()
hsv = mcolors.rgb_to_hsv(colors)

tree = KDTree(pts)
Expand All @@ -324,6 +324,7 @@ def ambiant_occlusion(sft, colors, factor=4):

return mcolors.hsv_to_rgb(hsv)


def generate_local_coloring(sft):
"""
Generate a coloring based on the local orientation of the streamlines.
Expand Down
2 changes: 1 addition & 1 deletion scripts/scil_bundle_diameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def main():
counter = 0
labels_dict = {label: ([], []) for label in unique_labels}
pts_labels = map_coordinates(data_labels,
sft.streamlines._data.T-0.5,
sft.streamlines.get_data().T-0.5,
order=0)
# For each label, all positions and directions are needed to get
# a tube estimation per label.
Expand Down
Loading
Loading