diff --git a/scilpy/io/streamlines.py b/scilpy/io/streamlines.py
index eebf54a79..20ce48654 100644
--- a/scilpy/io/streamlines.py
+++ b/scilpy/io/streamlines.py
@@ -7,9 +7,12 @@
 from dipy.io.streamline import load_tractogram
 import nibabel as nib
+from dipy.io.utils import is_header_compatible
 from nibabel.streamlines.array_sequence import ArraySequence
 import numpy as np
 
+from scilpy.io.utils import load_matrix_in_any_format
+
 
 def check_tracts_same_format(parser, tractogram_1, tractogram_2):
     """
@@ -115,6 +118,145 @@ def load_tractogram_with_reference(parser, args, filepath, arg_name=None):
     return sft
 
 
+def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
+                                            parser, args):
+    """
+    Verifies the compatibility of a reference sft with a list of files.
+
+    Parameters
+    ----------
+    ref_sft: StatefulTractogram
+        A tractogram to be used as reference.
+    files_to_verify: List[str]
+        List of files that should be compatible with the reference sft. Files
+        can be either other tractograms or nifti files (ex: masks).
+    parser: argument parser
+        Will raise an error if a file is not compatible.
+    args: Namespace
+        Should contain args.reference if any file is a .tck, and possibly
+        args.bbox_check (set to True by default).
+    """
+    save_ref = args.reference
+
+    for file in files_to_verify:
+        if file is not None:
+            _, ext = os.path.splitext(file)
+            if ext in ['.trk', '.tck', '.fib', '.vtk', '.dpy']:
+                # Cheating ref because it may send a lot of warnings if
+                # loading many trk with ref (the reference was maybe added
+                # only for some of these files)
+                if ext == '.trk':
+                    args.reference = None
+                else:
+                    args.reference = save_ref
+                mask = load_tractogram_with_reference(parser, args, file)
+            else:  # should be a nifti file.
+                mask = file
+            compatible = is_header_compatible(ref_sft, mask)
+            if not compatible:
+                parser.error("Reference tractogram incompatible with {}"
+                             .format(file))
+
+
+def load_dps_files_as_dps(parser, dps_files, sft, keys=None, overwrite=False):
+    """
+    Load dps information. They must be scalar values.
+
+    Parameters
+    ----------
+    parser: argparse.ArgumentParser
+    dps_files: list[str]
+        Either .npy or .txt files.
+    sft: StatefulTractogram
+    keys: list[str]
+        If None, use the filenames as keys.
+    overwrite: bool
+        If True, allow overwriting an existing dps key.
+
+    Returns
+    -------
+    sft: StatefulTractogram
+        The modified SFT. (Note that it is modified in-place even if the
+        returned variable is not used!)
+    new_keys: list[str]
+        Added keys.
+    """
+    if keys is not None and len(keys) != len(dps_files):
+        parser.error("You must provide one key name per dps file.")
+
+    new_keys = []
+    for i, file in enumerate(dps_files):
+        if keys is None:
+            name = os.path.basename(file)
+            key, ext = os.path.splitext(name)
+        else:
+            key = keys[i]
+
+        if key in sft.data_per_streamline and not overwrite:
+            parser.error("Key {} already exists in your tractogram's dps. "
+                         "You must allow overwriting keys."
+                         .format(key))
+
+        data = np.squeeze(load_matrix_in_any_format(file))
+        if len(data) != len(sft):
+            parser.error('Wrong dps size in file {}. Expected one value per '
+                         'streamline ({}) but got {} values!'
+                         .format(file, len(sft), len(data)))
+
+        new_keys.append(key)
+        sft.data_per_streamline[key] = data
+    return sft, new_keys
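+
+
+# A minimal usage sketch for load_dps_files_as_dps (editor's illustration;
+# the file name and key below are hypothetical):
+#
+#     sft, new_keys = load_dps_files_as_dps(parser, ['weights.txt'], sft,
+#                                           keys=['commit_weights'])
+#     weights = sft.data_per_streamline['commit_weights']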
+
+
+def load_dpp_files_as_dpp(parser, dpp_files, sft, keys=None, overwrite=False):
+    """
+    Load dpp information. They must be scalar values.
+
+    Parameters
+    ----------
+    parser: argparse.ArgumentParser
+    dpp_files: list[str]
+        Either .npy or .txt files.
+    sft: StatefulTractogram
+    keys: list[str]
+        If None, use the filenames as keys.
+    overwrite: bool
+        If True, allow overwriting an existing dpp key.
+
+    Returns
+    -------
+    sft: StatefulTractogram
+        The modified SFT. (Note that it is modified in-place even if the
+        returned variable is not used!)
+    new_keys: list[str]
+        Added keys.
+    """
+    if keys is not None and len(keys) != len(dpp_files):
+        parser.error("You must provide one key name per dpp file.")
+
+    new_keys = []
+    for i, file in enumerate(dpp_files):
+        if keys is None:
+            name = os.path.basename(file)
+            key, ext = os.path.splitext(name)
+        else:
+            key = keys[i]
+
+        if key in sft.data_per_point and not overwrite:
+            parser.error("Key {} already exists in your tractogram's dpp. "
+                         "You must allow overwriting keys."
+                         .format(key))
+
+        data = np.squeeze(load_matrix_in_any_format(file))
+        if len(data) != len(sft.streamlines._data):
+            parser.error('Wrong dpp size in file {}. Expected one value per '
+                         'point in your tractogram ({}) but got {}!'
+                         .format(file, len(sft.streamlines._data), len(data)))
+        new_keys.append(key)
+        sft.data_per_point[key] = data
+    return sft, new_keys
+
+
 def streamlines_to_memmap(input_streamlines, strs_dtype='float32'):
     """
diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py
index ed687312d..14e504b47 100644
--- a/scilpy/io/utils.py
+++ b/scilpy/io/utils.py
@@ -19,7 +19,6 @@
 from scipy.io import loadmat
 import six
 
-from scilpy.io.streamlines import load_tractogram_with_reference
 from scilpy.gradients.bvec_bval_tools import DEFAULT_B0_THRESHOLD
 from scilpy.utils.filenames import split_name_with_nii
 
@@ -593,46 +592,6 @@ def assert_roi_radii_format(parser):
     return roi_radii
 
 
-def verify_compatibility_with_reference_sft(ref_sft, files_to_verify,
-                                            parser, args):
-    """
-    Verifies the compatibility of a reference sft with a list of files.
-
-    Params
-    ------
-    ref_sft: StatefulTractogram
-        A tractogram to be used as reference.
-    files_to_verify: List[str]
-        List of files that should be compatible with the reference sft. Files
-        can be either other tractograms or nifti files (ex: masks).
-    parser: argument parser
-        Will raise an error if a file is not compatible.
-    args: Namespace
-        Should contain a args.reference if any file is a .tck, and possibly a
-        args.bbox_check (set to True by default).
-    """
-    save_ref = args.reference
-
-    for file in files_to_verify:
-        if file is not None:
-            _, ext = os.path.splitext(file)
-            if ext in ['.trk', '.tck', '.fib', '.vtk', '.dpy']:
-                # Cheating ref because it may send a lot of warning if loading
-                # many trk with ref (reference was maybe added only for some
-                # of these files)
-                if ext == '.trk':
-                    args.reference = None
-                else:
-                    args.reference = save_ref
-                mask = load_tractogram_with_reference(parser, args, file)
-            else:  # should be a nifti file.
-                mask = file
-            compatible = is_header_compatible(ref_sft, mask)
-            if not compatible:
-                parser.error("Reference tractogram incompatible with {}"
-                             .format(file))
-
-
 def is_header_compatible_multiple_files(parser, list_files,
                                         verbose_all_compatible=False,
                                         reference=None):
diff --git a/scilpy/tractograms/dps_and_dpp_management.py b/scilpy/tractograms/dps_and_dpp_management.py
index 1f85ef2dc..216de74d0 100644
--- a/scilpy/tractograms/dps_and_dpp_management.py
+++ b/scilpy/tractograms/dps_and_dpp_management.py
@@ -2,9 +2,37 @@
 import numpy as np
 
 
+def convert_dps_to_dpp(sft, keys, overwrite=False):
+    """
+    Copy the value of the data_per_streamline to each point of the
+    streamline, as data_per_point. The dps key is removed and added as a dpp
+    key.
+
+    Parameters
+    ----------
+    sft: StatefulTractogram
+    keys: List[str]
+        The list of dps keys to convert to dpp.
+    overwrite: bool
+        If true, allow continuing even if the key already existed as dpp.
+    """
+    for key in keys:
+        if key not in sft.data_per_streamline:
+            raise ValueError("Dps key {} not found!".format(key))
+        if key in sft.data_per_point and not overwrite:
+            raise ValueError("Dpp key {} already existed. Please allow "
+                             "overwriting.".format(key))
+        sft.data_per_point[key] = [[val]*len(s) for val, s in
+                                   zip(sft.data_per_streamline[key],
+                                       sft.streamlines)]
+        del sft.data_per_streamline[key]
+
+    return sft
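+
+
+# Editor's sketch of convert_dps_to_dpp on hypothetical data: a dps value of
+# 5.0 attached to a 3-point streamline becomes the dpp [[5.0], [5.0], [5.0]],
+# i.e. the per-streamline value is repeated at every point:
+#
+#     sft.data_per_streamline['w'] = [[5.0]]   # one streamline
+#     sft = convert_dps_to_dpp(sft, ['w'])
+#     sft.data_per_point['w'][0]               # -> [[5.0], [5.0], [5.0]]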
+
+
 def project_map_to_streamlines(sft, map_volume, endpoints_only=False):
     """
-    Projects a map onto the points of streamlines.
+    Projects a map onto the points of streamlines. The result is a
+    data_per_point.
 
     Parameters
     ----------
@@ -12,18 +40,16 @@
         Input tractogram.
     map_volume: DataVolume
         Input map.
-
-    Optional:
-    ---------
-    endpoints_only: bool
+    endpoints_only: bool, optional
         If True, will only project the map_volume onto the endpoints of the
         streamlines (all values along streamlines set to zero). If False,
         will project the map_volume onto all points of the streamlines.
 
     Returns
     -------
-    streamline_data:
-        map_volume projected to each point of the streamlines.
+    streamline_data: List
+        The values that could now be associated to a data_per_point key.
+        The map_volume projected to each point of the streamlines.
     """
     if len(map_volume.data.shape) == 4:
         dimension = map_volume.data.shape[3]
@@ -63,9 +89,62 @@
     return streamline_data
 
 
-def perform_streamline_operation_per_point(op_name, sft, dpp_name='metric',
-                                           endpoints_only=False):
-    """Peforms an operation per point for all streamlines.
+def project_dpp_to_map(sft, dpp_key, sum_lines=False, endpoints_only=False):
+    """
+    Saves the values of a data_per_point key to the underlying voxels.
+    Averages the values of the streamlines crossing each voxel and returns
+    the resulting 3D map. The streamlines are not preprocessed here: you
+    should probably first uncompress your streamlines to have smoother maps.
+
+    Parameters
+    ----------
+    sft: StatefulTractogram
+    dpp_key: str
+        The data_per_point key to project to a map.
+    sum_lines: bool
+        Do not average values of streamlines that cross a same voxel; sum
+        them instead.
+    endpoints_only: bool
+        If true, only project the streamline's endpoints.
+
+    Returns
+    -------
+    the_map: np.ndarray
+        The 3D resulting map.
+    """
+    sft.to_vox()
+
+    # Using to_corner, if we simply floor the coordinates of the point, we
+    # find the voxel where it is.
+    sft.to_corner()
+
+    # count: could also use compute_tract_counts_map.
+    count = np.zeros(sft.dimensions)
+    the_map = np.zeros(sft.dimensions)
+    for s in range(len(sft)):
+        if endpoints_only:
+            points = [0, -1]
+        else:
+            points = range(len(sft.streamlines[s]))
+
+        for p in points:
+            x, y, z = sft.streamlines[s][p, :].astype(int)  # Or floor
+            count[x, y, z] += 1
+            the_map[x, y, z] += sft.data_per_point[dpp_key][s][p]
+
+    if not sum_lines:
+        count = np.maximum(count, 1e-6)  # Avoid division by 0
+        the_map /= count
+
+    return the_map
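+
+
+# Hedged illustration of project_dpp_to_map (hypothetical numbers and key):
+# if two streamlines each drop one point in voxel (x, y, z), with dpp values
+# 2.0 and 4.0, then the_map[x, y, z] is 3.0 by default (average) or 6.0 with
+# sum_lines=True. Typical use:
+#
+#     the_map = project_dpp_to_map(sft, 'fa')
+#     # then e.g. nib.save(nib.Nifti1Image(the_map, sft.affine), out_file)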
+def perform_operation_on_dpp(op_name, sft, dpp_name, endpoints_only=False):
+    """
+    Performs an operation on the data per point for all streamlines (mean,
+    sum, min, max, correlation). The operation is applied on each point
+    individually, and thus makes sense if the data_per_point at each point
+    is a vector. The result is a new data_per_point.
 
     Parameters
     ----------
@@ -77,12 +156,13 @@
     dpp_name: str
         The name of the data per point to be used in the operation.
     endpoints_only: bool
-        If True, will only perform operation on endpoints
+        If True, will only perform operation on endpoints. Values at other
+        points will be set to NaN.
 
     Returns
     -------
-    new_sft: StatefulTractogram
-        sft with data per streamline resulting from the operation.
+    new_data_per_point: list
+        The values that could now be associated to a new data_per_point key.
     """
 
     # Performing operation
@@ -104,13 +184,15 @@
         new_data_per_point.append(
             np.reshape(this_data_per_point, (len(this_data_per_point), 1)))
 
-    # Extracting streamlines
     return new_data_per_point
 
 
-def perform_operation_per_streamline(op_name, sft, dpp_name='metric',
-                                     endpoints_only=False):
-    """Performs an operation across all data points for each streamline.
+def perform_operation_dpp_to_dps(op_name, sft, dpp_name, endpoints_only=False):
+    """
+    Converts dpp to dps, using a chosen operation.
+
+    Performs an operation across the data_per_point of each streamline
+    (mean, sum, min, max, correlation). The result is a data_per_streamline.
 
     Parameters
     ----------
@@ -122,12 +204,14 @@
     dpp_name: str
         The name of the data per point to be used in the operation.
     endpoints_only: bool
-        If True, will only perform operation on endpoints
+        If True, will only perform operation on endpoints. Other points will
+        be ignored in the operation.
 
     Returns
     -------
-    new_sft: StatefulTractogram
-        sft with data per streamline resulting from the operation.
+    new_data_per_streamline: list
+        The values that could now be associated to a new data_per_streamline
+        key.
     """
     # Performing operation
     call_op = OPERATIONS[op_name]
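+
+
+# Editor's sketch of the dpp -> dps conversion above, with a hypothetical
+# dpp key 'fa' holding one scalar per point:
+#
+#     means = perform_operation_dpp_to_dps('mean', sft, 'fa')
+#     sft.data_per_streamline['fa_mean'] = means  # one value per streamline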
diff --git a/scilpy/tractograms/tests/test_dps_and_dpp_management.py b/scilpy/tractograms/tests/test_dps_and_dpp_management.py
index 1fc2a7e88..7d57e5a69 100644
--- a/scilpy/tractograms/tests/test_dps_and_dpp_management.py
+++ b/scilpy/tractograms/tests/test_dps_and_dpp_management.py
@@ -1,18 +1,32 @@
+# -*- coding: utf-8 -*-
+
+
+def test_convert_dps_to_dpp():
+    pass
+
+
 def test_project_map_to_streamlines():
     # toDo
     pass
 
 
-def test_perform_streamline_operation_per_point():
+def test_project_dpp_to_map():
+    pass
+
+
+def test_perform_operation_on_dpp():
     # toDo
     pass
 
 
-def test_perform_operation_per_streamline():
+def test_perform_operation_dpp_to_dps():
     # toDo
     pass
 
 
+def test_pairwise_streamline_operation_on_endpoints():
+    pass
+
+
 def test_perform_streamline_operation_on_endpoints():
     # toDo
     pass
diff --git a/scripts/legacy/scil_project_streamlines_to_map.py b/scripts/legacy/scil_project_streamlines_to_map.py
new file mode 100755
index 000000000..52f4bf48c
--- /dev/null
+++ b/scripts/legacy/scil_project_streamlines_to_map.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from scilpy.io.deprecator import deprecate_script
+from scripts.scil_tractogram_project_streamlines_to_map import main as new_main
+
+
+DEPRECATION_MSG = """
+This script has been renamed scil_tractogram_project_streamlines_to_map.py.
+Please change your existing pipelines accordingly.
+"""
+
+
+@deprecate_script("scil_project_streamlines_to_map.py",
+                  DEPRECATION_MSG, '1.7.0')
+def main():
+    new_main()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/legacy/tests/test_legacy_scripts.py b/scripts/legacy/tests/test_legacy_scripts.py
index 52d08518c..809926cb8 100644
--- a/scripts/legacy/tests/test_legacy_scripts.py
+++ b/scripts/legacy/tests/test_legacy_scripts.py
@@ -123,7 +123,7 @@
     "scil_prepare_topup_command.py",
     "scil_print_connectivity_filenames.py",
     "scil_print_header.py",
-    "scil_project_streamlines_to_map.py",
+    "scil_tractogram_project_streamlines_to_map.py",
     "scil_recognize_multi_bundles.py",
     "scil_recognize_single_bundle.py",
     "scil_register_tractogram.py",
diff --git a/scripts/scil_bundle_mean_std.py b/scripts/scil_bundle_mean_std.py
index 91a453072..d672cfdfb 100755
--- a/scripts/scil_bundle_mean_std.py
+++ b/scripts/scil_bundle_mean_std.py
@@ -28,11 +28,11 @@
 from scilpy.image.labels import get_data_as_labels
 from scilpy.utils.filenames import split_name_with_nii
 
-from scilpy.io.streamlines import load_tractogram_with_reference
+from scilpy.io.streamlines import (load_tractogram_with_reference,
+                                   verify_compatibility_with_reference_sft)
 from scilpy.io.utils import (add_json_args, add_reference_arg,
                              add_verbose_arg,
-                             assert_inputs_exist, assert_outputs_exist,
-                             verify_compatibility_with_reference_sft)
+                             assert_inputs_exist, assert_outputs_exist)
 from scilpy.utils.metrics_tools import get_bundle_metrics_mean_std, \
     get_bundle_metrics_mean_std_per_point
diff --git a/scripts/scil_project_streamlines_to_map.py b/scripts/scil_project_streamlines_to_map.py
deleted file mode 100755
index 24b1adaf3..000000000
--- a/scripts/scil_project_streamlines_to_map.py
+++ /dev/null
@@ -1,249 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""
-Projects metrics onto the endpoints of streamlines. The idea is to visualize
-the cortical areas affected by metrics (assuming streamlines start/end in
-the cortex).
-
-This script can project data from maps (--in_metrics), from data_per_point
-(dpp) or data_per_streamline (dps): --load_dpp and --load_dps require an array
-from a file (must be the right shape), --use_dpp and --use_dps work only for
-.trk file and the key must exist in the metadata.
-
-The default options will take data from endpoints and project it to endpoints.
---from_wm will use data from whole streamlines.
---to_wm will project the data to whole streamline coverage.
-This creates 4 combinations of data source and projection.
-"""
-
-import argparse
-import logging
-import os
-
-import nibabel as nib
-from nibabel.streamlines import ArraySequence
-import numpy as np
-
-from scilpy.io.image import assert_same_resolution
-from scilpy.io.streamlines import load_tractogram_with_reference
-from scilpy.io.utils import (add_overwrite_arg,
-                             add_verbose_arg,
-                             assert_inputs_exist,
-                             assert_output_dirs_exist_and_empty,
-                             add_reference_arg,
-                             load_matrix_in_any_format)
-from scilpy.utils.filenames import split_name_with_nii
-from scilpy.tractanalysis.streamlines_metrics import \
-    compute_tract_counts_map
-from scilpy.tractograms.uncompress import uncompress
-
-
-def _build_arg_parser():
-    p = argparse.ArgumentParser(
-        description=__doc__,
-        formatter_class=argparse.RawTextHelpFormatter)
-
-    p.add_argument('in_bundle',
-                   help='Fiber bundle file.')
-    p.add_argument('out_folder',
-                   help='Folder where to save endpoints metric.')
-
-    p1 = p.add_mutually_exclusive_group(required=True)
-    p1.add_argument('--in_metrics', nargs='+', default=[],
-                    help='Nifti metric(s) to compute statistics on.')
-    p1.add_argument('--use_dps', metavar='DPS_KEY', nargs='+',
-                    help='Use the data_per_streamline (scalar) from file, '
-                         'e.g. commit_weights.')
-    p1.add_argument('--use_dpp', metavar='DPP_KEY', nargs='+', default=[],
-                    help='Use the data_per_point (scalar) from file.')
-    p1.add_argument('--load_dps', metavar='DPS_KEY', nargs='+', default=[],
-                    help='Load data per streamline (scalar) .txt or .npy.')
-    p1.add_argument('--load_dpp', metavar='DPP_KEY', nargs='+', default=[],
-                    help='Load data per point (scalar) from .txt or .npy.')
-
-    p.add_argument('--from_wm', action='store_true',
-                   help='Project metrics from whole streamlines coverage.')
-    p.add_argument('--to_wm', action='store_true',
-                   help='Project metrics into streamlines coverage.')
-
-    add_reference_arg(p)
-    add_verbose_arg(p)
-    add_overwrite_arg(p)
-
-    return p
-
-
-def _compute_streamline_mean(cur_ind, cur_min, cur_max, data):
-    # From the precomputed indices, compute the binary map
-    # and use it to weight the metric data for this specific streamline.
-    cur_range = tuple(cur_max - cur_min)
-
-    if len(cur_ind) == 2:
-        streamline_density = np.zeros(cur_range, dtype=int)
-        streamline_density[cur_ind[:, 0], cur_ind[:, 1]] = 1
-    else:
-        streamline_density = compute_tract_counts_map(ArraySequence([cur_ind]),
-                                                      cur_range)
-    streamline_data = data[cur_min[0]:cur_max[0],
-                           cur_min[1]:cur_max[1],
-                           cur_min[2]:cur_max[2]]
-    streamline_average = np.average(streamline_data,
-                                    weights=streamline_density)
-    return streamline_average
-
-
-def _process_streamlines(streamlines, just_endpoints):
-    # Compute the bounding boxes and indices for all streamlines.
-    # just_endpoints will get the indices of the endpoints only for the
-    # usecase of projecting GM metrics into the WM.
-    mins = []
-    maxs = []
-    offset_streamlines = []
-
-    # Offset the streamlines to compute the indices only in the bounding box.
-    # Reduces memory use later on.
-    for idx, s in enumerate(streamlines):
-        mins.append(np.min(s.astype(int), 0))
-        maxs.append(np.max(s.astype(int), 0) + 1)
-        if just_endpoints:
-            s = np.stack((s[0, :], s[-1, :]), axis=0)
-        offset_streamlines.append((s - mins[-1]).astype(np.float32))
-
-    offset_streamlines = ArraySequence(offset_streamlines)
-
-    if not just_endpoints:
-        indices = uncompress(offset_streamlines)
-    else:
-        indices = ArraySequence()
-        indices._offsets = offset_streamlines._offsets
-        indices._lengths = offset_streamlines._lengths
-        indices._data = np.floor(offset_streamlines._data).astype(int)
-
-    return mins, maxs, indices
-
-
-def _project_metrics(curr_metric_map, count, orig_s, streamline_mean,
-                     just_endpoints):
-    if just_endpoints:
-        xyz = orig_s[0, :].astype(int)
-        curr_metric_map[xyz[0], xyz[1], xyz[2]] += streamline_mean
-        count[xyz[0], xyz[1], xyz[2]] += 1
-
-        xyz = orig_s[-1, :].astype(int)
-        curr_metric_map[xyz[0], xyz[1], xyz[2]] += streamline_mean
-        count[xyz[0], xyz[1], xyz[2]] += 1
-    else:
-        for x, y, z in orig_s[:].astype(int):
-            curr_metric_map[x, y, z] += streamline_mean
-            count[x, y, z] += 1
-
-
-def _pick_data(args, sft):
-    if args.use_dps or args.load_dps:
-        if args.use_dps:
-            for dps in args.use_dps:
-                if dps not in sft.data_per_streamline:
-                    raise IOError('DPS key not in the sft: {}'.format(dps))
-            name = args.use_dps
-            data = [sft.data_per_streamline[dps] for dps in args.use_dps]
-        else:
-            name = args.load_dps
-            data = [load_matrix_in_any_format(dps) for dps in args.load_dps]
-        for i in range(len(data)):
-            if len(data[i]) != len(sft):
-                raise IOError('DPS length does not match the SFT: {}'
-                              .format(name[i]))
-    elif args.use_dpp or args.load_dpp:
-        if args.use_dpp:
-            name = args.use_dpp
-            data = [sft.data_per_point[dpp]._data for dpp in args.use_dpp]
-        else:
-            name = args.load_dpp
-            data = [load_matrix_in_any_format(dpp) for dpp in args.load_dpp]
-        for i in range(len(data)):
-            if len(data[i]) != len(sft.streamlines._data):
-                raise IOError('DPP length does not match the SFT: {}'
-                              .format(name[i]))
-    return zip(name, data)
-
-
-def main():
-    parser = _build_arg_parser()
-    args = parser.parse_args()
-    logging.getLogger().setLevel(logging.getLevelName(args.verbose))
-
-    assert_inputs_exist(parser, [args.in_bundle], args.in_metrics +
-                        args.load_dps + args.load_dpp)
-    assert_output_dirs_exist_and_empty(parser, args,
-                                       args.out_folder,
-                                       create_dir=True)
-
-    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
-    sft.to_vox()
-    sft.to_corner()
-
-    if len(sft.streamlines) == 0:
-        logging.warning('Empty bundle file {}. Skipping'.format(args.bundle))
-        return
-
-    mins, maxs, indices = _process_streamlines(sft.streamlines,
-                                               not args.from_wm)
-
-    if args.in_metrics:
-        assert_same_resolution(args.in_metrics)
-        metrics = [nib.load(metric) for metric in args.in_metrics]
-        for metric in metrics:
-            data = metric.get_fdata(dtype=np.float32)
-            curr_metric_map = np.zeros(metric.shape)
-            count = np.zeros(metric.shape)
-            for cur_min, cur_max, cur_ind, orig_s in zip(mins, maxs, indices,
-                                                         sft.streamlines):
-                streamline_mean = _compute_streamline_mean(cur_ind,
-                                                           cur_min,
-                                                           cur_max,
-                                                           data)
-
-                _project_metrics(curr_metric_map, count, orig_s,
-                                 streamline_mean, not args.to_wm)
-            curr_metric_map[count != 0] /= count[count != 0]
-            metric_fname, ext = split_name_with_nii(
-                os.path.basename(metric.get_filename()))
-            nib.save(nib.Nifti1Image(curr_metric_map, metric.affine,
-                                     metric.header),
-                     os.path.join(args.out_folder,
-                                  '{}_endpoints_metric{}'.format(metric_fname,
-                                                                 ext)))
-    else:
-        for fname, data in _pick_data(args, sft):
-            curr_metric_map = np.zeros(sft.dimensions)
-            count = np.zeros(sft.dimensions)
-
-            for j in range(len(sft.streamlines)):
-                if args.use_dps or args.load_dps:
-                    streamline_mean = np.mean(data[j])
-                else:
-                    tmp_data = ArraySequence()
-                    tmp_data._data = data
-                    tmp_data._offsets = sft.streamlines._offsets
-                    tmp_data._lengths = sft.streamlines._lengths
-
-                    if not args.to_wm:
-                        streamline_mean = (np.mean(tmp_data[j][-1])
-                                           + np.mean(tmp_data[j][0])) / 2
-                    else:
-                        streamline_mean = np.mean(tmp_data[j])
-
-                _project_metrics(curr_metric_map, count, sft.streamlines[j],
-                                 streamline_mean, not args.to_wm)
-
-            curr_metric_map[count != 0] /= count[count != 0]
-            metric_fname, _ = os.path.splitext(os.path.basename(fname))
-            nib.save(nib.Nifti1Image(curr_metric_map, sft.affine),
-                     os.path.join(args.out_folder,
-                                  '{}_endpoints_metric{}'.format(metric_fname,
-                                                                 '.nii.gz')))
-
-
-if __name__ == '__main__':
-    main()
diff --git a/scripts/scil_score_tractogram.py b/scripts/scil_score_tractogram.py
index 32d0dc7ba..77301db54 100755
--- a/scripts/scil_score_tractogram.py
+++ b/scripts/scil_score_tractogram.py
@@ -78,7 +78,8 @@
 from dipy.io.streamline import save_tractogram
 from dipy.io.utils import is_header_compatible
 
-from scilpy.io.streamlines import load_tractogram_with_reference
+from scilpy.io.streamlines import (load_tractogram_with_reference,
+                                   verify_compatibility_with_reference_sft)
 from scilpy.io.utils import (add_bbox_arg,
                              add_overwrite_arg,
                              add_json_args,
@@ -86,7 +87,6 @@
                              add_verbose_arg,
                              assert_inputs_exist,
                              assert_output_dirs_exist_and_empty,
-                             verify_compatibility_with_reference_sft,
                              assert_outputs_exist)
 from scilpy.segment.tractogram_from_roi import (compute_masks_from_bundles,
                                                 compute_endpoint_masks,
diff --git a/scripts/scil_tractogram_dpp_math.py b/scripts/scil_tractogram_dpp_math.py
index bedaace47..54e3a6a07 100755
--- a/scripts/scil_tractogram_dpp_math.py
+++ b/scripts/scil_tractogram_dpp_math.py
@@ -13,24 +13,23 @@
 - In dpp mode, the operation is performed on each point separately, resulting
   in a new dpp.
 
-If endpoints_only and dpp mode is set the operation will only
-be calculated at the streamline endpoints the rest of the
-values along the streamline will be NaN
+If endpoints_only and dpp mode are set, the operation will only be calculated
+at the streamline endpoints; the rest of the values along the streamline will
+be NaN.
 
-If endpoints_only and dps mode is set operation will be calculated
-across the data at the endpoints and stored as a
-single value (or array in the 4D case) per streamline.
+If endpoints_only and dps mode are set, the operation will be calculated
+across the data at the endpoints and stored as a single value (or array in
+the 4D case) per streamline.
 
 Endpoint only operation:
-correlation: correlation calculated between arrays extracted from
-streamline endpoints (data must be multivalued per point) and dps
-mode must be set.
+correlation: correlation calculated between arrays extracted from streamline
+endpoints (the data must be multivalued per point); dps mode must be set.
 """
 
 import argparse
 import logging
 
-from dipy.io.streamline import save_tractogram, StatefulTractogram
+from dipy.io.streamline import save_tractogram
 
 from scilpy.io.streamlines import load_tractogram_with_reference
 from scilpy.io.utils import (add_bbox_arg,
@@ -41,8 +40,8 @@
                              assert_outputs_exist)
 from scilpy.tractograms.dps_and_dpp_management import (
     perform_pairwise_streamline_operation_on_endpoints,
-    perform_streamline_operation_per_point,
-    perform_operation_per_streamline)
+    perform_operation_on_dpp,
+    perform_operation_dpp_to_dps)
 
 
 def _build_arg_parser():
@@ -65,12 +64,12 @@
                    'streamline. Set to dpp if the operation is to be \n'
                    'performed on each point separately resulting in a \n'
                    'single value per point.')
-    p.add_argument('--in_dpp_name', nargs='+', required=True,
+    p.add_argument('--in_dpp_name', nargs='+', required=True, metavar='key',
                    help='Name or list of names of the data_per_point for \n'
                         'operation to be performed on. If more than one dpp \n'
                         'is selected, the same operation will be applied \n'
                         'separately to each one.')
-    p.add_argument('--out_name', nargs='+', required=True,
+    p.add_argument('--out_name', nargs='+', required=True, metavar='key',
                    help='Name of the resulting data_per_point or \n'
                         'data_per_streamline to be saved in the output \n'
                         'tractogram. If more than one --in_dpp_name was used, \n'
@@ -88,9 +87,9 @@
                         'keys will be saved.')
     p.add_argument('--overwrite_dpp_dps', action='store_true',
                    help='If set, if --keep_all_dpp_dps is set and some \n'
-                        '--out_dpp_name keys already existed in your \n'
-                        ' data_per_point or data_per_streamline, allow \n'
-                        ' overwriting old data_per_point.')
+                        '--out_name keys already existed in your \n'
+                        'data_per_point or data_per_streamline, allow \n'
+                        'overwriting old data_per_point.')
 
     add_reference_arg(p)
     add_verbose_arg(p)
@@ -115,8 +114,7 @@
     sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
 
     if len(sft.streamlines) == 0:
-        logging.info("Input tractogram contains no streamlines. Exiting.")
-        return
+        parser.error("Input tractogram contains no streamlines. Exiting.")
 
     if len(args.in_dpp_name) != len(args.out_name):
         parser.error('The number of in_dpp_names and out_names must be '
@@ -130,9 +128,8 @@
     for in_dpp_name in args.in_dpp_name:
         # Check to see if the data per point exists.
        if in_dpp_name not in sft.data_per_point:
-            logging.info('Data per point {} not found in input tractogram.'
+            parser.error('Data per point {} not found in input tractogram.'
                          .format(in_dpp_name))
-            return
 
         # warning if dpp mode and data in single number per point
         data_shape = sft.data_per_point[in_dpp_name][0].shape
@@ -143,20 +140,17 @@
 
         # Check if first data_per_point is multivalued
         if args.operation == 'correlation' and data_shape[0] == 1:
-            logging.info('Correlation operation requires multivalued data per '
+            parser.error('Correlation operation requires multivalued data per '
                          'point. Exiting.')
-            return
 
         if args.operation == 'correlation' and args.mode == 'dpp':
-            logging.info('Correlation operation requires dps mode. Exiting.')
-            return
+            parser.error('Correlation operation requires dps mode. Exiting.')
 
         if not args.overwrite_dpp_dps:
             if in_dpp_name in args.out_name:
-                logging.info('out_name {} already exists in input tractogram. '
+                parser.error('out_name {} already exists in input tractogram. '
                              'Set overwrite_dpp_dps or choose a different '
                              'out_name. Exiting.'.format(in_dpp_name))
-                return
 
     data_per_point = {}
     data_per_streamline = {}
@@ -176,7 +170,7 @@
                 'Performing {} on data from each streamine point '
                 'and saving as new dpp {}'.format(
                     args.operation, out_name))
-            new_dpp = perform_streamline_operation_per_point(
+            new_dpp = perform_operation_on_dpp(
                 args.operation, sft, in_dpp_name, args.endpoints_only)
             data_per_point[out_name] = new_dpp
         elif args.mode == 'dps':
@@ -184,7 +178,7 @@
             logging.info(
                 'Performing {} across each streamline and saving resulting '
                 'data per streamline {}'.format(args.operation, out_name))
-            new_data_per_streamline = perform_operation_per_streamline(
+            new_data_per_streamline = perform_operation_dpp_to_dps(
                 args.operation, sft, in_dpp_name, args.endpoints_only)
             data_per_streamline[out_name] = new_data_per_streamline
diff --git a/scripts/scil_tractogram_project_map_to_streamlines.py b/scripts/scil_tractogram_project_map_to_streamlines.py
index 027515781..eb0f7370b 100755
--- a/scripts/scil_tractogram_project_map_to_streamlines.py
+++ b/scripts/scil_tractogram_project_map_to_streamlines.py
@@ -8,6 +8,20 @@
 project it onto the points of streamlines. If the image is 4D, the data is
 stored as a list of 1D arrays per streamline. If the image is 3D, the data is
 stored as a list of values per streamline.
+
+See also scil_tractogram_project_streamlines_to_map.py for the reverse action.
+
+* Note that the data from your maps will be projected only on the coordinates
+of the points of your streamlines. Data underlying the whole segments between
+two consecutive points is not used. If your streamlines are strongly
+compressed, or if they have a very big step size, the result may reflect your
+map poorly. You may use scil_tractogram_resample.py to upsample your
+streamlines first.
+* Hint: The streamlines themselves are not modified here, only their dpp. To
+avoid multiplying data on disk, you could use the following arguments to save
+the new dpp in your current tractogram:
+>> scil_tractogram_project_map_to_streamlines.py $in_bundle $in_bundle
+   --keep_all_dpp -f
 """
 
 import argparse
diff --git a/scripts/scil_tractogram_project_streamlines_to_map.py b/scripts/scil_tractogram_project_streamlines_to_map.py
new file mode 100755
index 000000000..d128f3c3a
--- /dev/null
+++ b/scripts/scil_tractogram_project_streamlines_to_map.py
@@ -0,0 +1,236 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""
+Projects metrics onto the underlying voxels of the streamlines. This script
+can project data from data_per_point (dpp) or data_per_streamline (dps) to
+maps.
+
+You can choose to project data from all points of the streamlines, or from
+the endpoints only. The idea then is to visualize the cortical areas affected
+by metrics (assuming streamlines start/end in the cortex).
+
+See also scil_tractogram_project_map_to_streamlines.py for the reverse action.
+
+How the data is loaded:
+    - From dps: uses the same value for each point of the streamline.
+    - From dpp: one value per point.
+
+How the data is used:
+    1. Average all points of the streamline to get a mean value, set this
+       value to all points.
+    2. Average the two endpoints and get their mean value, set this value to
+       all points.
+    3. Keep each point individually.
+
+How the data is projected to a map:
+    A. Using each point.
+    B. Using the endpoints only.
+
+For more complex operations than the average per streamline, see
+scil_tractogram_dpp_math.py.
+"""
+
+import argparse
+import logging
+import os
+
+import nibabel as nib
+import numpy as np
+
+from scilpy.io.streamlines import (load_dpp_files_as_dpp,
+                                   load_dps_files_as_dps,
+                                   load_tractogram_with_reference)
+from scilpy.io.utils import (add_overwrite_arg, add_reference_arg,
+                             add_verbose_arg, assert_inputs_exist,
+                             assert_outputs_exist)
+from scilpy.tractograms.dps_and_dpp_management import (
+    convert_dps_to_dpp, perform_operation_dpp_to_dps, project_dpp_to_map)
+from scilpy.utils.filenames import split_name_with_nii
+
+
+def _build_arg_parser():
+    p = argparse.ArgumentParser(
+        description=__doc__,
+        formatter_class=argparse.RawTextHelpFormatter)
+
+    p.add_argument('in_bundle',
+                   help='Fiber bundle file.')
+    p.add_argument('out_prefix',
+                   help='Folder + prefix to save endpoints metric(s). We '
+                        'will save \none nifti file per dpp/dps key given.\n'
+                        'Ex: my_path/subjX_bundleY_ with --use_dpp key1 '
+                        'will output \nmy_path/subjX_bundleY_key1.nii.gz')
+
+    p1 = p.add_argument_group(
+        description='Where to get the statistics from. (Choose one)')
+    p1 = p1.add_mutually_exclusive_group(required=True)
+    p1.add_argument('--use_dps', metavar='key', nargs='+',
+                    help='Use the data_per_streamline from the tractogram.\n'
+                         'It must be a .trk.')
+    p1.add_argument('--use_dpp', metavar='key', nargs='+', default=[],
+                    help='Use the data_per_point from the tractogram. \n'
+                         'It must be a .trk.')
+    p1.add_argument('--load_dps', metavar='file', nargs='+', default=[],
+                    help='Load data per streamline (scalar) .txt or .npy.\n'
+                         'Must load an array with the right shape.')
+    p1.add_argument('--load_dpp', metavar='file', nargs='+', default=[],
+                    help='Load data per point (scalar) from .txt or .npy.\n'
+                         'Must load an array with the right shape.')
+
+    p2 = p.add_argument_group(description='Processing choices. (Choose one)')
+    p2 = p2.add_mutually_exclusive_group(required=True)
+    p2.add_argument('--mean_endpoints', action='store_true',
+                    help="Use one single value per streamline: the mean "
+                         "of the two \nendpoints.")
+    p2.add_argument('--mean_streamline', action='store_true',
+                    help='Use one single value per streamline: '
+                         'the mean of all \npoints of the streamline.')
+    p2.add_argument('--point_by_point', action='store_true',
+                    help="Directly project the streamlines values onto the "
+                         "map.\n")
+
+    p3 = p.add_argument_group(
+        description='Where to send the statistics. (Choose one)')
+    p3 = p3.add_mutually_exclusive_group(required=True)
+    p3.add_argument('--to_endpoints', action='store_true',
+                    help="Project metrics onto a mask of the endpoints.")
+    p3.add_argument('--to_wm', action='store_true',
+                    help='Project metrics into streamlines coverage.')
+
+    add_reference_arg(p)
+    add_verbose_arg(p)
+    add_overwrite_arg(p)
+
+    return p
+
+
+def _load_dpp_dps(args, parser, sft):
+    # In all cases, only one of the values below can be set at a time.
+    dps_to_use = None
+    dpp_to_use = None
+
+    # 1. With options --use_dps, --use_dpp: check that the dps / dpp key is
+    #    found.
+    # 2. With options --load_dps, --load_dpp: load them now into the SFT and
+    #    check that they fit with the data.
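+    #    (e.g., a hypothetical '--load_dps weights.txt' must hold len(sft)
+    #    scalars, which are added under the dps key 'weights'.)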
+    if args.use_dps:
+        dps_to_use = args.use_dps
+        possible_dps = list(sft.data_per_streamline.keys())
+        for key in args.use_dps:
+            if key not in possible_dps:
+                parser.error('DPS key ({}) not found in your tractogram!'
+                             .format(key))
+    elif args.use_dpp:
+        dpp_to_use = args.use_dpp
+        possible_dpp = list(sft.data_per_point.keys())
+        for key in args.use_dpp:
+            if key not in possible_dpp:
+                parser.error('DPP key ({}) not found in your tractogram!'
+                             .format(key))
+    elif args.load_dps:
+        logging.info("Loading dps from file.")
+
+        # It does not matter if we overwrite: we are not saving the
+        # resulting sft.
+        sft, dps_to_use = load_dps_files_as_dps(parser, args.load_dps, sft,
+                                                overwrite=True)
+    else:  # args.load_dpp:
+        # Loading dpp for all points even if we won't use them all, to make
+        # sure that the loaded files have the correct shape.
+        logging.info("Loading dpp from file.")
+        sft, dpp_to_use = load_dpp_files_as_dpp(parser, args.load_dpp, sft,
+                                                overwrite=True)
+
+    # Verify that we have singular values (ex: not colors).
+    # Remove unused keys to save memory.
+    all_keys = list(sft.data_per_point.keys())
+    for key in all_keys:
+        if dpp_to_use is not None and key in dpp_to_use:
+            d0 = sft.data_per_point[key][0][0]
+            if len(d0) > 1:
+                raise ValueError(
+                    "Expecting scalar values as data_per_point. Got data of "
+                    "shape {} for key {}.".format(d0.shape, key))
+        else:
+            del sft.data_per_point[key]
+
+    all_keys = list(sft.data_per_streamline.keys())
+    for key in all_keys:
+        if dps_to_use is not None and key in dps_to_use:
+            d0 = sft.data_per_streamline[key][0]
+            if len(d0) > 1:
+                raise ValueError(
+                    "Expecting scalar values as data_per_streamline. Got "
+                    "data of shape {} for key {}.".format(d0.shape, key))
+        else:
+            del sft.data_per_streamline[key]
+
+    return sft, dps_to_use, dpp_to_use
+
+
+def main():
+    parser = _build_arg_parser()
+    args = parser.parse_args()
+    logging.getLogger().setLevel(logging.getLevelName(args.verbose))
+
+    # -------- General checks ----------
+    assert_inputs_exist(parser, [args.in_bundle],
+                        args.load_dps + args.load_dpp)
+
+    # Find all final output files (one per metric).
+    if args.load_dps or args.load_dpp:
+        files = args.load_dps or args.load_dpp
+        metrics_names = []
+        for file in files:
+            # Prepare dpp key from filename.
+            name = os.path.basename(file)
+            name, ext = split_name_with_nii(name)
+            metrics_names.append(name)
+    else:
+        metrics_names = args.use_dpp or args.use_dps
+    out_files = [args.out_prefix + m + '.nii.gz' for m in metrics_names]
+    assert_outputs_exist(parser, args, out_files)
+
+    # -------- Loading streamlines and checking compatibility ----------
+    logging.info("Loading tractogram {}".format(args.in_bundle))
+    sft = load_tractogram_with_reference(parser, args, args.in_bundle)
+    sft.to_vox()
+    sft.to_corner()
+
+    if len(sft.streamlines) == 0:
+        logging.warning('Empty bundle file {}. Skipping.'
+                        .format(args.in_bundle))
+        return
+
+    # -------- Loading dps / dpp ----------
+    sft, dps_to_use, dpp_to_use = _load_dpp_dps(args, parser, sft)
+
+    # Convert dps to dpp. It is easier to manage all the remaining options
+    # without multiplying the if - else calls.
+    if dps_to_use is not None:
+        # Then dpp_to_use is None, and the sft contains no dpp key.
+        # We can overwrite.
+        sft = convert_dps_to_dpp(sft, dps_to_use, overwrite=True)
+        all_keys = dps_to_use
+    else:
+        all_keys = dpp_to_use
+
+    # -------- Formatting values ----------
+    # In the case where we average the dpp, average it now and pretend it's
+    # a dps, then re-copy it to all dpp.
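+    # (E.g., a hypothetical dpp of [2., 4., 6.] becomes the dps 4. with
+    # --mean_streamline, then the dpp [4., 4., 4.] again.)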
+    if args.mean_streamline or args.mean_endpoints:
+        logging.info("Averaging values for all streamlines.")
+        for key in all_keys:
+            sft.data_per_streamline[key] = perform_operation_dpp_to_dps(
+                'mean', sft, key, endpoints_only=args.mean_endpoints)
+        sft = convert_dps_to_dpp(sft, all_keys, overwrite=True)
+
+    # -------- Projection and saving ----------
+    for key in all_keys:
+        logging.info("Projecting streamlines metric {} to a map".format(key))
+        the_map = project_dpp_to_map(sft, key,
+                                     endpoints_only=args.to_endpoints)
+
+        out_file = args.out_prefix + key + '.nii.gz'
+        logging.info("Saving file {}".format(out_file))
+        nib.save(nib.Nifti1Image(the_map, sft.affine), out_file)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/tests/test_project_streamlines_to_map.py b/scripts/tests/test_project_streamlines_to_map.py
deleted file mode 100644
index cdb35e7e4..000000000
--- a/scripts/tests/test_project_streamlines_to_map.py
+++ /dev/null
@@ -1,42 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import os
-import tempfile
-
-from scilpy.io.fetcher import get_testing_files_dict, fetch_data, get_home
-
-
-# If they already exist, this only takes 5 seconds (check md5sum)
-fetch_data(get_testing_files_dict(), keys=['tractometry.zip'])
-tmp_dir = tempfile.TemporaryDirectory()
-
-
-def test_help_option(script_runner):
-    ret = script_runner.run('scil_project_streamlines_to_map.py', '--help')
-    assert ret.success
-
-
-def test_execution_tractometry_default(script_runner):
-    os.chdir(os.path.expanduser(tmp_dir.name))
-    in_bundle = os.path.join(get_home(), 'tractometry',
-                             'IFGWM_uni.trk')
-    in_ref = os.path.join(get_home(), 'tractometry',
-                          'mni_masked.nii.gz')
-    ret = script_runner.run('scil_project_streamlines_to_map.py', in_bundle,
-                            'out_def/', '--in_metrics', in_ref)
-
-    assert ret.success
-
-
-def test_execution_tractometry_wm(script_runner):
-    os.chdir(os.path.expanduser(tmp_dir.name))
-    in_bundle = os.path.join(get_home(), 'tractometry',
-                             'IFGWM_uni.trk')
-    in_ref = os.path.join(get_home(), 'tractometry',
-                          'mni_masked.nii.gz')
-    ret = script_runner.run('scil_project_streamlines_to_map.py', in_bundle,
-                            'out_wm/', '--in_metrics', in_ref,
-                            '--to_wm', '--from_wm')
-
-    assert ret.success
diff --git a/scripts/tests/test_tractogram_dpp_math.py b/scripts/tests/test_tractogram_dpp_math.py
index 69387b53d..60a9bb6f7 100644
--- a/scripts/tests/test_tractogram_dpp_math.py
+++ b/scripts/tests/test_tractogram_dpp_math.py
@@ -61,7 +61,7 @@ def test_execution_tractogram_point_math_mean_4D_correlation(script_runner):
                             'correlation',
                             fodf_on_bundle,
                             'fodf_correlation_on_streamlines.trk',
-                            '--mode', 'dpp',
+                            '--mode', 'dps',
                             '--in_dpp_name', 'fodf',
                             '--out_name', 'fodf_correlation')
 
diff --git a/scripts/tests/test_tractogram_project_streamlines_to_map.py b/scripts/tests/test_tractogram_project_streamlines_to_map.py
new file mode 100644
index 000000000..48cdc5a1b
--- /dev/null
+++ b/scripts/tests/test_tractogram_project_streamlines_to_map.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+
+from scilpy.io.fetcher import get_testing_files_dict, fetch_data, get_home
+
+
+# If they already exist, this only takes 5 seconds (check md5sum)
+fetch_data(get_testing_files_dict(), keys=['tractometry.zip'])
+tmp_dir = tempfile.TemporaryDirectory()
+
+
+def test_help_option(script_runner):
+    ret = script_runner.run('scil_tractogram_project_streamlines_to_map.py',
+                            '--help')
+    assert ret.success
+
+
+def test_execution_dpp(script_runner):
+    os.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(get_home(), 'tractometry', 'IFGWM_uni.trk')
+    in_mni = os.path.join(get_home(), 'tractometry', 'mni_masked.nii.gz')
+    in_bundle_with_dpp = 'IFGWM_uni_with_dpp.trk'
+
+    # Create our test data with dpp: add metrics as dpp.
+    # Or get a tractogram that already has some dpp in the test data.
+    script_runner.run('scil_tractogram_project_map_to_streamlines.py',
+                      in_bundle, in_bundle_with_dpp, '-f',
+                      '--in_maps', in_mni, '--out_dpp_name', 'some_metric')
+
+    # Tests with dpp.
+    ret = script_runner.run('scil_tractogram_project_streamlines_to_map.py',
+                            in_bundle_with_dpp, 'project_dpp_',
+                            '--use_dpp', 'some_metric', '--point_by_point',
+                            '--to_endpoints')
+    assert ret.success
+
+    ret = script_runner.run('scil_tractogram_project_streamlines_to_map.py',
+                            in_bundle_with_dpp, 'project_mean_to_endpoints_',
+                            '--use_dpp', 'some_metric', '--mean_streamline',
+                            '--to_endpoints')
+    assert ret.success
+
+    ret = script_runner.run('scil_tractogram_project_streamlines_to_map.py',
+                            in_bundle_with_dpp, 'project_end_to_wm_',
+                            '--use_dpp', 'some_metric', '--mean_endpoints',
+                            '--to_wm')
+    assert ret.success
+
+
+def test_execution_dps(script_runner):
+    os.chdir(os.path.expanduser(tmp_dir.name))
+    in_bundle = os.path.join(get_home(), 'tractometry', 'IFGWM_uni.trk')
+    in_mni = os.path.join(get_home(), 'tractometry', 'mni_masked.nii.gz')
+    in_bundle_with_dpp = 'IFGWM_uni_with_dpp.trk'
+    in_bundle_with_dps = 'IFGWM_uni_with_dps.trk'
+
+    # Create our test data with dps: add metrics as dps.
+    # Or get a tractogram that already has some dps in the test data.
+    script_runner.run('scil_tractogram_project_map_to_streamlines.py',
+                      in_bundle, in_bundle_with_dpp, '-f',
+                      '--in_maps', in_mni, '--out_dpp_name', 'some_metric')
+    script_runner.run('scil_tractogram_dpp_math.py', 'min',
+                      in_bundle_with_dpp, in_bundle_with_dps,
+                      '--in_dpp_name', 'some_metric',
+                      '--out_name', 'some_metric_dps', '--mode', 'dps',
+                      '--keep_all_dpp_dps')
+
+    # Tests with dps.
+    ret = script_runner.run('scil_tractogram_project_streamlines_to_map.py',
+                            in_bundle_with_dps, 'project_dps_',
+                            '--use_dps', 'some_metric_dps',
+                            '--point_by_point', '--to_wm')
+    assert ret.success
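
Usage sketch tying the new helpers together (the file and key names are
hypothetical; `parser` and `args` come from an argparse setup such as the one
in the new script, which performs essentially these steps internally):

    from scilpy.io.streamlines import (load_dps_files_as_dps,
                                       load_tractogram_with_reference)
    from scilpy.tractograms.dps_and_dpp_management import (
        convert_dps_to_dpp, project_dpp_to_map)

    sft = load_tractogram_with_reference(parser, args, 'bundle.trk')
    sft, keys = load_dps_files_as_dps(parser, ['weights.txt'], sft)
    sft = convert_dps_to_dpp(sft, keys, overwrite=True)
    the_map = project_dpp_to_map(sft, keys[0], endpoints_only=True)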