Skip to content

Commit

Permalink
Merge pull request #1009 from AntoineTheb/atheb/trim_streamlines
Browse files Browse the repository at this point in the history
ENH: rework `scil_tractogram_cut_streamlines` + `scil_labels_from_mask`
  • Loading branch information
arnaudbore authored Jul 22, 2024
2 parents 54fa376 + a9b75fa commit 7a98922
Show file tree
Hide file tree
Showing 10 changed files with 855 additions and 183 deletions.
62 changes: 62 additions & 0 deletions scilpy/image/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os

import numpy as np
from scipy import ndimage as ndi
from scipy.spatial import cKDTree


Expand Down Expand Up @@ -67,6 +68,67 @@ def get_binary_mask_from_labels(atlas, label_list):
return mask


def get_labels_from_mask(mask_data, labels=None, background_label=0):
"""
Get labels from a binary mask which contains multiple blobs. Each blob
will be assigned a label, by default starting from 1. Background will
be assigned the background_label value.
Parameters
----------
mask_data: np.ndarray
The mask data.
labels: list, optional
Labels to assign to each blobs in the mask. Excludes the background
label.
background_label: int
Label for the background.
Returns
-------
label_map: np.ndarray
The labels.
"""
# Get the number of structures and assign labels to each blob
label_map, nb_structures = ndi.label(mask_data)
# Assign labels to each blob if provided
if labels:
# Only keep the first nb_structures labels if the number of labels
# provided is greater than the number of blobs in the mask.
if len(labels) > nb_structures:
logging.warning("Number of labels ({}) does not match the number "
"of blobs in the mask ({}). Only the first {} "
"labels will be used.".format(
len(labels), nb_structures, nb_structures))
# Cannot assign fewer labels than the number of blobs in the mask.
elif len(labels) < nb_structures:
raise ValueError("Number of labels ({}) is less than the number of"
" blobs in the mask ({}).".format(
len(labels), nb_structures))

# Copy the label map to avoid scenarios where the label list contains
# labels that are already present in the label map
custom_label_map = label_map.copy()
# Assign labels to each blob
for idx, label in enumerate(labels[:nb_structures]):
custom_label_map[label_map == idx + 1] = label
label_map = custom_label_map

logging.info('Assigned labels {} to the mask.'.format(
np.unique(label_map[label_map != background_label])))

if background_label != 0 and background_label in label_map:
logging.warning("Background label {} corresponds to a label "
"already in the map. This will cause issues.".format(
background_label))

# Assign background label
if background_label:
label_map[label_map == 0] = background_label

return label_map


def get_lut_dir():
"""
Return LUT directory in scilpy repository
Expand Down
52 changes: 50 additions & 2 deletions scilpy/image/tests/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import pytest

from scilpy.image.labels import (combine_labels, dilate_labels,
get_data_as_labels, get_lut_dir,
remove_labels, split_labels)
get_data_as_labels, get_labels_from_mask,
get_lut_dir, remove_labels, split_labels)
from scilpy.tests.arrays import ref_in_labels, ref_out_labels


Expand Down Expand Up @@ -132,6 +132,54 @@ def test_get_data_as_labels_float():
_ = get_data_as_labels(img)


def test_get_labels_from_mask():
""" Test get_labels_from_mask with default labels. """
# ref_out_labels contains disjoint blobs with values 2,4,6
data = deepcopy(ref_out_labels)
data[data == 2] = 1
data[data == 4] = 2
data[data == 6] = 3
mask = data.astype(bool)

labels = get_labels_from_mask(mask)

assert_equal(labels, data)


def test_get_labels_from_mask_custom_labels_raises():
""" Test get_labels_from_mask with custom labels. """
# ref_out_labels contains disjoint blobs with values 2,4,6
data = deepcopy(ref_out_labels)
mask = data.astype(bool)
labels = get_labels_from_mask(mask, [2, 4, 6, 8])

assert np.unique(labels).size == 4 # including background


def test_get_labels_from_mask_custom_labels():
""" Test get_labels_from_mask with custom labels. """
# ref_out_labels contains disjoint blobs with values 2,4,6
data = deepcopy(ref_out_labels)
mask = data.astype(bool)

labels = get_labels_from_mask(mask, [2, 4, 6])

assert_equal(labels, data)


def test_get_labels_from_mask_custom_background():
""" test get_labels_from_mask with custom background. """
# ref_out_labels contains disjoint blobs with values 2,4,6
data = deepcopy(ref_out_labels)
mask = data.copy().astype(bool)

data[data == 0] = 9

labels = get_labels_from_mask(mask, [2, 4, 6], background_label=9)

assert_equal(labels, data)


def test_get_lut_dir():
lut_dir = get_lut_dir()
assert os.path.isdir(lut_dir)
Expand Down
4 changes: 2 additions & 2 deletions scilpy/segment/tractogram_from_roi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
from dipy.io.streamline import save_tractogram
from dipy.tracking.utils import length as compute_length

from scilpy.image.utils import \
split_mask_blobs_kmeans
from scilpy.io.image import get_data_as_mask
from scilpy.io.streamlines import load_tractogram_with_reference
from scilpy.segment.streamlines import filter_grid_roi, filter_grid_roi_both
from scilpy.tractograms.streamline_operations import \
remove_loops_and_sharp_turns
from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map

from scilpy.tractograms.streamline_and_mask_operations import \
split_mask_blobs_kmeans
from scilpy.tractograms.streamline_operations import \
filter_streamlines_by_total_length_per_dim
from scilpy.utils.filenames import split_name_with_nii
Expand Down
Loading

0 comments on commit 7a98922

Please sign in to comment.