diff --git a/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py b/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py index 52d2fc0a3..93e8f5edb 100644 --- a/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py +++ b/nnunetv2/training/data_augmentation/custom_transforms/region_based_training.py @@ -22,17 +22,13 @@ def __init__(self, regions: Union[List, Tuple], def __call__(self, **data_dict): seg = data_dict.get(self.seg_key) - num_regions = len(self.regions) if seg is not None: - seg_shp = seg.shape - output_shape = list(seg_shp) - output_shape[1] = num_regions - region_output = np.zeros(output_shape, dtype=seg.dtype) - for b in range(seg_shp[0]): - for region_id, region_source_labels in enumerate(self.regions): - if not isinstance(region_source_labels, (list, tuple)): - region_source_labels = (region_source_labels, ) - for label_value in region_source_labels: - region_output[b, region_id][seg[b, self.seg_channel] == label_value] = 1 - data_dict[self.output_key] = region_output + b, c, *shape = seg.shape + region_output = np.zeros((b, len(self.regions), *shape), dtype=bool) + for region_id, region_labels in enumerate(self.regions): + if not isinstance(region_labels, (list, tuple)): + region_labels = (region_labels, ) + for label_value in region_labels: + region_output[:, region_id] |= (seg[:, self.seg_channel] == label_value) + data_dict[self.output_key] = region_output.astype(np.uint8, copy=False) return data_dict