Skip to content

Commit

Permalink
add more info + possibility to choose label ids
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudbore committed Mar 13, 2024
1 parent 94c77c8 commit 152c7b1
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions scripts/scil_tractogram_cut_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
import numpy as np
import scipy.ndimage as ndi

from scilpy.io.image import get_data_as_mask, get_data_as_labels
from scilpy.image.labels import get_data_as_labels
from scilpy.io.image import get_data_as_mask
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
add_reference_arg,
Expand Down Expand Up @@ -69,6 +70,9 @@ def _build_arg_parser():
p.add_argument('out_tractogram',
help='Output tractogram file.')

p.add_argument('--label_ids', nargs=2, type=int,
help='List of labels indices to use to cut '
'streamlines (2 values).')
p.add_argument('--resample', dest='step_size', type=float, default=None,
help='Resample streamlines to a specific step-size in mm '
'[%(default)s].')
Expand Down Expand Up @@ -135,21 +139,24 @@ def main():
' the tractogram and label.')

label_data = get_data_as_labels(label_img)
unique_vals = np.unique(label_data[label_data != 0])

if len(unique_vals) == 2:
label_data_1 = np.copy(label_data)
mask = label_data_1 != unique_vals[0]
label_data_1[mask] = 0
if args.label_ids:
unique_vals = args.label_ids
else:
unique_vals = np.unique(label_data[label_data != 0])
if len(unique_vals) != 2:
parser.error('More than two values in the label file.')

label_data_2 = np.copy(label_data)
mask = label_data_2 != unique_vals[1]
label_data_2[mask] = 0
label_data_1 = np.copy(label_data)
mask = label_data_1 != unique_vals[0]
label_data_1[mask] = 0

new_sft = cut_between_mask_two_blobs_streamlines(sft, label_data_1,
label_data_2)
else:
parser.error('More than wo values in the label file.')
label_data_2 = np.copy(label_data)
mask = label_data_2 != unique_vals[1]
label_data_2[mask] = 0

new_sft = cut_between_mask_two_blobs_streamlines(sft, label_data_1,
label_data_2)

if len(new_sft) == 0:
logging.warning('No streamline intersected the provided mask. '
Expand Down

0 comments on commit 152c7b1

Please sign in to comment.