Merge pull request #911 from EmmaRenauld/project_streamlines_part2
Refactor project_streamlines_to_map
arnaudbore authored Feb 26, 2024
2 parents bcf63aa + 39a9f2c commit f79ffa3
Showing 15 changed files with 638 additions and 390 deletions.
142 changes: 142 additions & 0 deletions scilpy/io/streamlines.py
@@ -7,9 +7,12 @@

from dipy.io.streamline import load_tractogram
import nibabel as nib
from dipy.io.utils import is_header_compatible
from nibabel.streamlines.array_sequence import ArraySequence
import numpy as np

from scilpy.io.utils import load_matrix_in_any_format


def check_tracts_same_format(parser, tractogram_1, tractogram_2):
"""
@@ -115,6 +118,145 @@ def load_tractogram_with_reference(parser, args, filepath, arg_name=None):
return sft


def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
parser, args):
"""
Verifies the compatibility of a reference sft with a list of files.
    Parameters
    ----------
ref_sft: StatefulTractogram
A tractogram to be used as reference.
files_to_verify: List[str]
List of files that should be compatible with the reference sft. Files
can be either other tractograms or nifti files (ex: masks).
parser: argument parser
Will raise an error if a file is not compatible.
args: Namespace
        Should contain args.reference if any file is a .tck, and possibly
        args.bbox_check (set to True by default).
"""
save_ref = args.reference

for file in files_to_verify:
if file is not None:
_, ext = os.path.splitext(file)
if ext in ['.trk', '.tck', '.fib', '.vtk', '.dpy']:
                # Temporarily discard the reference: loading many .trk files
                # with a reference set may emit many warnings (the reference
                # may have been intended only for some of these files).
if ext == '.trk':
args.reference = None
else:
args.reference = save_ref
mask = load_tractogram_with_reference(parser, args, file)
else: # should be a nifti file.
mask = file
compatible = is_header_compatible(ref_sft, mask)
if not compatible:
parser.error("Reference tractogram incompatible with {}"
.format(file))
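
A minimal usage sketch (file names hypothetical; the parser must also define whatever flags load_tractogram_with_reference expects, e.g. --reference):

import argparse

from scilpy.io.streamlines import (load_tractogram_with_reference,
                                   verify_compatibility_with_reference_sft)

parser = argparse.ArgumentParser()
parser.add_argument('in_tractogram')
parser.add_argument('--reference')
args = parser.parse_args(['bundle.trk'])
args.bbox_check = True  # Expected by the loader in some scripts.

sft = load_tractogram_with_reference(parser, args, args.in_tractogram)

# Exits through parser.error() if any file's header is incompatible.
verify_compatibility_with_reference_sft(
    sft, ['mask.nii.gz', 'other_bundle.trk'], parser, args)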


def load_dps_files_as_dps(parser, dps_files, sft, keys=None, overwrite=False):
"""
    Load data_per_streamline (dps) information. Values must be scalar: one
    value per streamline.
Parameters
----------
parser: parser
dps_files: list[str]
Either .npy or .txt files.
sft: StatefulTractogram
keys: list[str]
If None, use the filenames as keys.
overwrite: bool
If True, allow overwriting an existing dps key.
Returns
-------
sft: StatefulTractogram
The modified SFT. (Note that it is modified in-place even if the
returned variable is not used!)
new_keys: list[str]
Added keys.
"""
if keys is not None and len(keys) != len(dps_files):
parser.error("You must provide one key name per dps file.")

new_keys = []
for i, file in enumerate(dps_files):
if keys is None:
name = os.path.basename(file)
key, ext = os.path.splitext(name)
else:
key = keys[i]

if key in sft.data_per_streamline and not overwrite:
parser.error("Key {} already exists in your tractogram's dps. "
"You must allow overwriting keys."
.format(key))

data = np.squeeze(load_matrix_in_any_format(file))
if len(data) != len(sft):
parser.error('Wrong dps size in file {}. Expected one value per '
'streamline ({}) but got {} values!'
.format(file, len(sft), len(data)))

new_keys.append(key)
sft.data_per_streamline[key] = data
return sft, new_keys
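
Continuing the sketch above (parser and sft as before; file and key names hypothetical), each dps file must hold exactly one scalar per streamline:

import numpy as np

# One value per streamline, attached under the key 'fa_mean'.
np.savetxt('fa_mean.txt', np.random.rand(len(sft)))
sft, new_keys = load_dps_files_as_dps(parser, ['fa_mean.txt'], sft,
                                      keys=['fa_mean'])
print(new_keys)  # ['fa_mean']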


def load_dpp_files_as_dpp(parser, dpp_files, sft, keys=None, overwrite=False):
"""
    Load data_per_point (dpp) information. Values must be scalar: one value
    per point.
Parameters
----------
parser: parser
dpp_files: list[str]
Either .npy or .txt files.
sft: StatefulTractogram
keys: list[str]
If None, use the filenames as keys.
overwrite: bool
If True, allow overwriting an existing dpp key.
Returns
-------
sft: StatefulTractogram
The modified SFT. (Note that it is modified in-place even if the
returned variable is not used!)
new_keys: list[str]
Added keys.
"""
if keys is not None and len(keys) != len(dpp_files):
parser.error("You must provide one key name per dps file.")

new_keys = []
for i, file in enumerate(dpp_files):
if keys is None:
name = os.path.basename(file)
key, ext = os.path.splitext(name)
else:
key = keys[i]

        if key in sft.data_per_point and not overwrite:
parser.error("Key {} already exists in your tractogram's dpp. "
"You must allow overwriting keys."
.format(key))

data = np.squeeze(load_matrix_in_any_format(file))
if len(data) != len(sft.streamlines._data):
parser.error('Wrong dpp size in file {}. Expected one value per '
'point in your tractogram ({}) but got {}!'
.format(file, len(sft.streamlines._data), len(data)))
new_keys.append(key)
sft.data_per_point[key] = data
return sft, new_keys
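
Same sketch for dpp (names hypothetical): the file must hold one value per point over the whole tractogram, i.e. len(sft.streamlines._data) values:

import numpy as np

n_points = len(sft.streamlines._data)
np.savetxt('ones.txt', np.ones(n_points))
sft, new_keys = load_dpp_files_as_dpp(parser, ['ones.txt'], sft,
                                      keys=['ones'])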


def streamlines_to_memmap(input_streamlines,
strs_dtype='float32'):
"""
41 changes: 0 additions & 41 deletions scilpy/io/utils.py
@@ -19,7 +19,6 @@
from scipy.io import loadmat
import six

from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.gradients.bvec_bval_tools import DEFAULT_B0_THRESHOLD
from scilpy.utils.filenames import split_name_with_nii

@@ -593,46 +592,6 @@ def assert_roi_radii_format(parser):
return roi_radii


def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
parser, args):
"""
Verifies the compatibility of a reference sft with a list of files.
Params
------
ref_sft: StatefulTractogram
A tractogram to be used as reference.
files_to_verify: List[str]
List of files that should be compatible with the reference sft. Files
can be either other tractograms or nifti files (ex: masks).
parser: argument parser
Will raise an error if a file is not compatible.
args: Namespace
Should contain a args.reference if any file is a .tck, and possibly a
args.bbox_check (set to True by default).
"""
save_ref = args.reference

for file in files_to_verify:
if file is not None:
_, ext = os.path.splitext(file)
if ext in ['.trk', '.tck', '.fib', '.vtk', '.dpy']:
# Cheating ref because it may send a lot of warning if loading
# many trk with ref (reference was maybe added only for some
# of these files)
if ext == '.trk':
args.reference = None
else:
args.reference = save_ref
mask = load_tractogram_with_reference(parser, args, file)
else: # should be a nifti file.
mask = file
compatible = is_header_compatible(ref_sft, mask)
if not compatible:
parser.error("Reference tractogram incompatible with {}"
.format(file))


def is_header_compatible_multiple_files(parser, list_files,
verbose_all_compatible=False,
reference=None):
124 changes: 104 additions & 20 deletions scilpy/tractograms/dps_and_dpp_management.py
@@ -2,28 +2,54 @@
import numpy as np


def convert_dps_to_dpp(sft, keys, overwrite=False):
"""
Copy the value of the data_per_streamline to each point of the
streamline, as data_per_point. The dps key is removed and added as dpp key.
Parameters
----------
sft: StatefulTractogram
    keys: List[str]
The list of dps keys to convert to dpp.
overwrite: bool
        If True, allow overwriting if the key already exists as dpp.
"""
for key in keys:
if key not in sft.data_per_streamline:
raise ValueError("Dps key {} not found!".format(key))
if key in sft.data_per_point and not overwrite:
raise ValueError("Dpp key {} already existed. Please allow "
"overwriting.".format(key))
sft.data_per_point[key] = [[val]*len(s) for val, s in
zip(sft.data_per_streamline[key],
sft.streamlines)]
del sft.data_per_streamline[key]

return sft
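
A short sketch (key name hypothetical): broadcast one scalar per streamline to all of its points.

import numpy as np

sft.data_per_streamline['seed_id'] = np.arange(len(sft))[:, None]
sft = convert_dps_to_dpp(sft, ['seed_id'])
assert 'seed_id' not in sft.data_per_streamline
assert len(sft.data_per_point['seed_id'][0]) == len(sft.streamlines[0])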


def project_map_to_streamlines(sft, map_volume, endpoints_only=False):
"""
Projects a map onto the points of streamlines.
Projects a map onto the points of streamlines. The result is a
data_per_point.
Parameters
----------
sft: StatefulTractogram
Input tractogram.
map_volume: DataVolume
Input map.
Optional:
---------
endpoints_only: bool
endpoints_only: bool, optional
If True, will only project the map_volume onto the endpoints of the
streamlines (all values along streamlines set to zero). If False,
will project the map_volume onto all points of the streamlines.
Returns
-------
streamline_data:
map_volume projected to each point of the streamlines.
streamline_data: List
The values that could now be associated to a data_per_point key.
The map_volume projected to each point of the streamlines.
"""
if len(map_volume.data.shape) == 4:
dimension = map_volume.data.shape[3]
@@ -63,9 +89,62 @@ def project_map_to_streamlines(sft, map_volume, endpoints_only=False):
return streamline_data
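
A sketch of the typical use (variable and key names hypothetical; construction of the DataVolume from a NIfTI image is omitted):

# fa_volume: a DataVolume wrapping an FA map in the same space as sft.
dpp_fa = project_map_to_streamlines(sft, fa_volume)
sft.data_per_point['fa'] = dpp_fa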


def perform_streamline_operation_per_point(op_name, sft, dpp_name='metric',
endpoints_only=False):
"""Peforms an operation per point for all streamlines.
def project_dpp_to_map(sft, dpp_key, sum_lines=False, endpoints_only=False):
"""
    Projects the values of a data_per_point key to the underlying voxels,
    averaging (or summing) the values of all streamlines crossing each voxel.
    Returns the resulting 3D map.
The streamlines are not preprocessed here. You should probably first
uncompress your streamlines to have smoother maps.
Parameters
----------
sft: StatefulTractogram
dpp_key: str
The data_per_point key to project to a map.
sum_lines: bool
        Do not average the values of streamlines crossing the same voxel;
        sum them instead.
endpoints_only: bool
If true, only project the streamline's endpoints.
Returns
-------
the_map: np.ndarray
The 3D resulting map.
"""
sft.to_vox()

    # After to_corner, flooring a point's coordinates gives the voxel that
    # contains it.
sft.to_corner()

# count: could also use compute_tract_counts_map.
count = np.zeros(sft.dimensions)
the_map = np.zeros(sft.dimensions)
for s in range(len(sft)):
if endpoints_only:
points = [0, -1]
else:
points = range(len(sft.streamlines[s]))

for p in points:
x, y, z = sft.streamlines[s][p, :].astype(int) # Or floor
count[x, y, z] += 1
the_map[x, y, z] += sft.data_per_point[dpp_key][s][p]

if not sum_lines:
count = np.maximum(count, 1e-6) # Avoid division by 0
the_map /= count

return the_map
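
Continuing the sketch above (the 'fa' dpp key is assumed to exist), the resulting map can be saved with nibabel:

import nibabel as nib
import numpy as np

fa_map = project_dpp_to_map(sft, 'fa', sum_lines=False)
nib.save(nib.Nifti1Image(fa_map.astype(np.float32), sft.affine),
         'fa_map.nii.gz')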


def perform_operation_on_dpp(op_name, sft, dpp_name, endpoints_only=False):
"""
    Performs an operation on the data per point for all streamlines (mean, sum,
    min, max, correlation). The operation is applied to each point individually,
and thus makes sense if the data_per_point at each point is a vector. The
result is a new data_per_point.
Parameters
----------
@@ -77,12 +156,13 @@ def perform_streamline_operation_per_point(op_name, sft, dpp_name='metric',
dpp_name: str
The name of the data per point to be used in the operation.
endpoints_only: bool
If True, will only perform operation on endpoints
If True, will only perform operation on endpoints. Values at other
points will be set to NaN.
Returns
-------
new_sft: StatefulTractogram
sft with data per streamline resulting from the operation.
new_data_per_point: list
The values that could now be associated to a new data_per_point key.
"""

# Performing operation
@@ -104,13 +184,15 @@ def perform_operation_per_streamline(op_name, sft, dpp_name='metric',
new_data_per_point.append(
np.reshape(this_data_per_point, (len(this_data_per_point), 1)))

# Extracting streamlines
return new_data_per_point
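
A sketch (key names hypothetical): reduce a vector-valued dpp, e.g. three values per point, to a single value per point.

new_dpp = perform_operation_on_dpp('mean', sft, 'rgb_dpp')
sft.data_per_point['rgb_mean'] = new_dpp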


def perform_operation_per_streamline(op_name, sft, dpp_name='metric',
endpoints_only=False):
"""Performs an operation across all data points for each streamline.
def perform_operation_dpp_to_dps(op_name, sft, dpp_name, endpoints_only=False):
"""
Converts dpp to dps, using a chosen operation.
    Performs an operation across the data_per_point values of each streamline
    (mean, sum, min, max, correlation). The result is a data_per_streamline.
Parameters
----------
@@ -122,12 +204,14 @@ def perform_operation_per_streamline(op_name, sft, dpp_name='metric',
dpp_name: str
The name of the data per point to be used in the operation.
endpoints_only: bool
If True, will only perform operation on endpoints
If True, will only perform operation on endpoints. Other points will be
ignored in the operation.
Returns
-------
new_sft: StatefulTractogram
sft with data per streamline resulting from the operation.
new_data_per_streamline: list
The values that could now be associated to a new data_per_streamline
key.
"""
# Performing operation
call_op = OPERATIONS[op_name]
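
A sketch (key names hypothetical): collapse the per-point values of each streamline into a single per-streamline value, stored as dps.

fa_per_line = perform_operation_dpp_to_dps('mean', sft, 'fa')
sft.data_per_streamline['fa_mean'] = fa_per_line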
(12 more changed files not shown.)
