From 60b23b4d0535f60ca6127e1bd62d180572d55a54 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Fri, 26 Apr 2024 10:54:35 +0200 Subject: [PATCH] further optimize ConvertSegmentationToRegionsTransform --- .../custom_transforms/region_based_training.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py b/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py index 93e8f5edb..eb45d26f2 100644 --- a/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py +++ b/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py @@ -26,9 +26,7 @@ def __call__(self, **data_dict): b, c, *shape = seg.shape region_output = np.zeros((b, len(self.regions), *shape), dtype=bool) for region_id, region_labels in enumerate(self.regions): - if not isinstance(region_labels, (list, tuple)): - region_labels = (region_labels, ) - for label_value in region_labels: - region_output[:, region_id] |= (seg[:, self.seg_channel] == label_value) + region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels) data_dict[self.output_key] = region_output.astype(np.uint8, copy=False) return data_dict +