Skip to content

Commit

Permalink
add possibility to use a label mask with two values
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudbore committed Mar 13, 2024
1 parent 546189e commit 94c77c8
Showing 1 changed file with 77 additions and 34 deletions.
111 changes: 77 additions & 34 deletions scripts/scil_tractogram_cut_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@

"""
Filters streamlines and only keeps the parts of streamlines within or
between the ROIs. The script accepts a single input mask, the mask has either
1 entity/blob or 2 entities/blobs (does not support disconnected voxels).
between the ROIs. Two options are available.
Input mask:
The mask has either 1 entity/blob or
2 entities/blobs (does not support disconnected voxels).
The option --biggest_blob can help if you have such a scenario.
The 1 entity scenario will 'trim' the streamlines so their longest segment is
Expand All @@ -13,7 +17,13 @@
The 2 entities scenario will cut streamlines so their segment are within the
bounding box or going from binary mask #1 to binary mask #2.
Both scenarios will erase data_per_point and data_per_streamline.
Input label:
The label MUST contain 2 labels different from zero.
Label values could be anything.
The script will cut streamlines going from label 1 to label 2.
Both inputs and scenarios will erase data_per_point and data_per_streamline.
Formerly: scil_cut_streamlines.py
"""
Expand All @@ -29,7 +39,7 @@
import numpy as np
import scipy.ndimage as ndi

from scilpy.io.image import get_data_as_mask
from scilpy.io.image import get_data_as_mask, get_data_as_labels
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.io.utils import (add_overwrite_arg,
add_reference_arg,
Expand All @@ -47,8 +57,15 @@ def _build_arg_parser():
formatter_class=argparse.RawTextHelpFormatter)
p.add_argument('in_tractogram',
help='Input tractogram file.')
p.add_argument('in_mask',
help='Binary mask containing either 1 or 2 blobs.')

g1 = p.add_argument_group('Mandatory mask options',
'Choose between mask or label input.')
g2 = g1.add_mutually_exclusive_group(required=True)
g2.add_argument('--mask',
help='Binary mask containing either 1 or 2 blobs.')
g2.add_argument('--label',
help='Label containing 2 blobs.')

p.add_argument('out_tractogram',
help='Output tractogram file.')

Expand All @@ -72,41 +89,67 @@ def main():
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

assert_inputs_exist(parser, [args.in_tractogram, args.in_mask])
assert_inputs_exist(parser, [args.in_tractogram, args.mask])
assert_outputs_exist(parser, args, args.out_tractogram)

sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
if args.step_size is not None:
sft = resample_streamlines_step_size(sft, args.step_size)

mask_img = nib.load(args.in_mask)
binary_mask = get_data_as_mask(mask_img)

if not is_header_compatible(sft, mask_img):
parser.error('Incompatible header between the tractogram and mask.')

bundle_disjoint, _ = ndi.label(binary_mask)
unique, count = np.unique(bundle_disjoint, return_counts=True)
if args.biggest_blob:
val = unique[np.argmax(count[1:])+1]
binary_mask[bundle_disjoint != val] = 0
unique = [0, val]
if len(unique) == 2:
logging.info('The provided mask has 1 entity '
'cut_outside_of_mask_streamlines function selected.')
new_sft = cut_outside_of_mask_streamlines(sft, binary_mask)
elif len(unique) == 3:
logging.info('The provided mask has 2 entity '
'cut_between_mask_two_blobs_streamlines '
'function selected.')
new_sft = cut_between_mask_two_blobs_streamlines(sft, binary_mask)

if args.mask:
mask_img = nib.load(args.mask)
binary_mask = get_data_as_mask(mask_img)

if not is_header_compatible(sft, mask_img):
parser.error('Incompatible header between the tractogram'
' and mask.')

bundle_disjoint, _ = ndi.label(binary_mask)
unique, count = np.unique(bundle_disjoint, return_counts=True)
if args.biggest_blob:
val = unique[np.argmax(count[1:])+1]
binary_mask[bundle_disjoint != val] = 0
unique = [0, val]
if len(unique) == 2:
logging.info('The provided mask has 1 entity '
'cut_outside_of_mask_streamlines function selected.')
new_sft = cut_outside_of_mask_streamlines(sft, binary_mask)
elif len(unique) == 3:
logging.info('The provided mask has 2 entity '
'cut_between_mask_two_blobs_streamlines '
'function selected.')
new_sft = cut_between_mask_two_blobs_streamlines(sft, binary_mask)

else:
logging.warning('The provided mask has MORE THAN 2 entity '
'cut_between_mask_two_blobs_streamlines function '
'selected. This may cause problems with '
'the outputed streamlines.'
' Please inspect the output carefully.')
new_sft = cut_between_mask_two_blobs_streamlines(sft, binary_mask)
else:
logging.warning('The provided mask has MORE THAN 2 entity '
'cut_between_mask_two_blobs_streamlines function '
'selected. This may cause problems with the outputed '
'streamlines. Please inspect the output carefully.')
new_sft = cut_between_mask_two_blobs_streamlines(sft, binary_mask)
label_img = nib.load(args.label)

if not is_header_compatible(sft, label_img):
parser.error('Incompatible header between'
' the tractogram and label.')

label_data = get_data_as_labels(label_img)
unique_vals = np.unique(label_data[label_data != 0])

if len(unique_vals) == 2:
label_data_1 = np.copy(label_data)
mask = label_data_1 != unique_vals[0]
label_data_1[mask] = 0

label_data_2 = np.copy(label_data)
mask = label_data_2 != unique_vals[1]
label_data_2[mask] = 0

new_sft = cut_between_mask_two_blobs_streamlines(sft, label_data_1,
label_data_2)
else:
parser.error('More than wo values in the label file.')

if len(new_sft) == 0:
logging.warning('No streamline intersected the provided mask. '
Expand Down

0 comments on commit 94c77c8

Please sign in to comment.