diff --git a/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py b/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py index 93e8f5edb..eb45d26f2 100644 --- a/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py +++ b/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py @@ -26,9 +26,7 @@ def __call__(self, **data_dict): b, c, *shape = seg.shape region_output = np.zeros((b, len(self.regions), *shape), dtype=bool) for region_id, region_labels in enumerate(self.regions): - if not isinstance(region_labels, (list, tuple)): - region_labels = (region_labels, ) - for label_value in region_labels: - region_output[:, region_id] |= (seg[:, self.seg_channel] == label_value) + region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels) data_dict[self.output_key] = region_output.astype(np.uint8, copy=False) return data_dict +