-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase_data_loader.py
205 lines (176 loc) · 12 KB
/
base_data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
from typing import Union, Tuple
from batchgenerators.dataloading.data_loader import DataLoader
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
from nnunetv2.utilities.label_handling.label_handling import LabelManager
class nnUNetDataLoaderBase(DataLoader):
    """Base class for nnU-Net data loaders that sample fixed-size patches from an nnUNetDataset.

    It stores the patch geometry (``patch_size``, ``final_patch_size``, ``need_to_pad``), the
    foreground-oversampling policy and label information taken from a ``LabelManager``, and
    implements the patch bounding-box sampling (``get_bbox`` / ``get_bbox_from_last_layer_mask``).
    No batch-generation method is defined here; presumably subclasses provide it.
    """

    def __init__(self,
                 data: nnUNetDataset,
                 batch_size: int,
                 patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 final_patch_size: Union[List[int], Tuple[int, ...], np.ndarray],
                 label_manager: LabelManager,
                 oversample_foreground_percent: float = 0.0,
                 sampling_probabilities: Union[List[int], Tuple[int, ...], np.ndarray] = None,
                 pad_sides: Union[List[int], Tuple[int, ...], np.ndarray] = None,
                 probabilistic_oversampling: bool = False,
                 oversample_from_last_layer_mask: bool = False):
        """Set up patch geometry, the oversampling policy and label information.

        Args:
            data: dataset to sample from; must be an ``nnUNetDataset``.
            batch_size: number of samples per batch.
            patch_size: spatial size of the patch cut from each case.
            final_patch_size: spatial size the network will get; ``patch_size - final_patch_size``
                becomes the base of ``need_to_pad``.
            label_manager: source of ``all_labels`` (-> ``annotated_classes_key``) and
                ``has_ignore_label`` (-> ``has_ignore``).
            oversample_foreground_percent: fraction of each batch (or per-sample probability when
                ``probabilistic_oversampling`` is set) that is forced to contain foreground.
            sampling_probabilities: forwarded to the base ``DataLoader`` and also stored.
            pad_sides: optional extra per-axis padding added to ``need_to_pad``.
            probabilistic_oversampling: choose ``_probabilistic_oversampling`` instead of
                ``_oversample_last_XX_percent`` as ``get_do_oversample``.
            oversample_from_last_layer_mask: stored flag; not read in this class (presumably
                used by subclasses to pick ``get_bbox_from_last_layer_mask`` -- confirm).
        """
        # NOTE(review): the positional args are assumed to be batchgenerators' DataLoader
        # (num_threads_in_multithreaded=1, seed_for_shuffle=None, return_incomplete=True,
        # shuffle=False, infinite=True, sampling_probabilities) -- confirm against that signature.
        super().__init__(data, batch_size, 1, None, True, False, True, sampling_probabilities)
        assert isinstance(data, nnUNetDataset), 'nnUNetDataLoaderBase only supports dictionaries as data'
        self.indices = list(data.keys())

        self.oversample_foreground_percent = oversample_foreground_percent
        self.final_patch_size = final_patch_size
        self.patch_size = patch_size
        self.list_of_keys = list(self._data.keys())
        # need_to_pad is how much we pad the data so that patches of final_patch_size (what the
        # network gets) sampled from it can also cover the border of the images.
        self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
        if pad_sides is not None:
            if not isinstance(pad_sides, np.ndarray):
                pad_sides = np.array(pad_sides)
            # need_to_pad is an int array and numpy refuses to store a float ufunc result in it
            # in place, so float pad_sides used to raise here. Truncate like need_to_pad above.
            self.need_to_pad += pad_sides.astype(int)
        self.num_channels = None
        self.pad_sides = pad_sides
        self.data_shape, self.seg_shape = self.determine_shapes()
        self.sampling_probabilities = sampling_probabilities
        self.annotated_classes_key = tuple(label_manager.all_labels)
        self.has_ignore = label_manager.has_ignore_label
        self.get_do_oversample = self._probabilistic_oversampling if probabilistic_oversampling \
            else self._oversample_last_XX_percent
        self.oversample_from_last_layer_mask = oversample_from_last_layer_mask

    def _oversample_last_XX_percent(self, sample_idx: int) -> bool:
        """Return True if sample ``sample_idx`` of a minibatch must be guaranteed to contain foreground.

        The trailing samples of the batch are oversampled: True iff
        ``sample_idx >= round(batch_size * (1 - oversample_foreground_percent))``.
        """
        return not sample_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))

    def _probabilistic_oversampling(self, sample_idx: int) -> bool:
        """Return True with probability ``oversample_foreground_percent``; ``sample_idx`` is unused."""
        return np.random.uniform() < self.oversample_foreground_percent

    def determine_shapes(self):
        """Return ``(data_shape, seg_shape)`` of a full batch.

        Loads the first case to read the channel counts of data and segmentation; the spatial
        part of both shapes is ``patch_size``.
        """
        data, seg, _ = self._data.load_case(self.indices[0])
        num_color_channels = data.shape[0]
        data_shape = (self.batch_size, num_color_channels, *self.patch_size)
        seg_shape = (self.batch_size, seg.shape[0], *self.patch_size)
        return data_shape, seg_shape

    def _bbox_lower_bound_range(self, data_shape):
        """Return ``(lbs, ubs)``: the per-axis range allowed for the bbox lower corner.

        The corner may lie anywhere in ``[-need_to_pad // 2, shape + ceil(need_to_pad / 2) - patch]``,
        so patches can extend past the image on both sides. If the image plus ``need_to_pad`` is
        still smaller than the patch along an axis, the padding for that axis is enlarged.
        """
        need_to_pad = self.need_to_pad.copy()
        dim = len(data_shape)
        for d in range(dim):
            if need_to_pad[d] + data_shape[d] < self.patch_size[d]:
                need_to_pad[d] = self.patch_size[d] - data_shape[d]
        lbs = [- need_to_pad[i] // 2 for i in range(dim)]
        ubs = [data_shape[i] + need_to_pad[i] // 2 + need_to_pad[i] % 2 - self.patch_size[i] for i in range(dim)]
        return lbs, ubs

    @staticmethod
    def _random_bbox_lbs(lbs, ubs):
        """Draw a uniformly random lower corner with ``lbs[i] <= corner[i] <= ubs[i]``."""
        return [np.random.randint(lbs[i], ubs[i] + 1) for i in range(len(lbs))]

    def _bbox_from_lbs(self, bbox_lbs):
        """Return ``(bbox_lbs, bbox_ubs)`` where ``bbox_ubs = bbox_lbs + patch_size`` per axis."""
        bbox_ubs = [bbox_lbs[i] + self.patch_size[i] for i in range(len(bbox_lbs))]
        return bbox_lbs, bbox_ubs

    def _eligible_classes(self, class_locations: dict) -> list:
        """List the keys of ``class_locations`` that have at least one location.

        The ``annotated_classes_key`` region is dropped when any other eligible key remains.
        """
        eligible = [k for k in class_locations.keys() if len(class_locations[k]) > 0]
        # Element-wise flags instead of ``annotated_classes_key in eligible``: comparisons are only
        # made against tuple keys, avoiding numpy's ambiguous-truth-value ValueError.
        is_annotated_key = [k == self.annotated_classes_key if isinstance(k, tuple) else False for k in eligible]
        if any(is_annotated_key) and len(eligible) > 1:
            eligible.pop(np.where(is_annotated_key)[0][0])
        return eligible

    def _select_foreground_class(self, class_locations: Union[dict, None],
                                 overwrite_class: Union[int, Tuple[int, ...], None], verbose: bool):
        """Pick the class/region to centre a foreground-forced patch on, or None if there is none.

        ``overwrite_class`` wins when it has locations; otherwise an eligible key is drawn uniformly.
        """
        assert class_locations is not None, 'if force_fg is set class_locations cannot be None'
        if overwrite_class is not None:
            assert overwrite_class in class_locations.keys(), 'desired class ("overwrite_class") does not ' \
                                                              'have class_locations (missing key)'
        # Preprocessing already stored the locations per class, which saves a np.unique here.
        eligible = self._eligible_classes(class_locations)
        if len(eligible) == 0:
            # Only happens if the image contains no foreground voxels at all.
            if verbose:
                print('case does not contain any foreground classes')
            return None
        if overwrite_class is None or overwrite_class not in eligible:
            return eligible[np.random.choice(len(eligible))]
        return overwrite_class

    def _select_annotated_class(self, class_locations: dict):
        """Return ``annotated_classes_key``, or None (with a warning) if it has no locations."""
        selected_class = self.annotated_classes_key
        if len(class_locations[selected_class]) == 0:
            # No annotated pixels in this case. Not good, but it cannot be skipped here.
            print('Warning! No annotated pixels in image!')
            return None
        return selected_class

    def _bbox_lbs_from_voxels(self, voxels, lbs, ubs):
        """Centre the patch on a random entry of ``voxels``, or fall back to a random corner.

        Each entry's coordinate ``i + 1`` is used for spatial axis ``i`` (entry index 0 is skipped;
        presumably a channel index -- confirm). The corner is clipped from below to ``lbs`` only.
        """
        if voxels is not None and len(voxels) > 0:
            selected_voxel = voxels[np.random.choice(len(voxels))]
            return [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(len(lbs))]
        return self._random_bbox_lbs(lbs, ubs)

    def get_bbox(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],
                 overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False):
        """Sample the bounding box of one patch.

        Without ``force_fg`` and without an ignore label the box is placed uniformly at random.
        With ``force_fg`` it is centred on a random location of a randomly chosen eligible class
        (or ``overwrite_class`` if that class has locations). With an ignore label but no
        ``force_fg`` it is centred on an annotated location. If no suitable location exists, it
        falls back to random placement. The 2D loader must select the slice and restrict
        ``class_locations`` to it before calling this.

        Args:
            data_shape: spatial shape of the case; its length is the number of spatial axes.
            force_fg: require the patch to be centred on a foreground class/region.
            class_locations: maps class label or region tuple to a sequence of stored locations.
            overwrite_class: preferred class/region when ``force_fg`` is set; must be a key of
                ``class_locations``.
            verbose: print a note when a case has no foreground classes.

        Returns:
            ``(bbox_lbs, bbox_ubs)`` lists with one entry per spatial axis;
            ``bbox_ubs - bbox_lbs == patch_size``. The box may extend outside the image.
        """
        lbs, ubs = self._bbox_lower_bound_range(data_shape)
        if not force_fg and not self.has_ignore:
            return self._bbox_from_lbs(self._random_bbox_lbs(lbs, ubs))

        if force_fg:
            selected_class = self._select_foreground_class(class_locations, overwrite_class, verbose)
        else:
            # has_ignore and not force_fg: any annotated location will do.
            selected_class = self._select_annotated_class(class_locations)

        voxels = class_locations[selected_class] if selected_class is not None else None
        return self._bbox_from_lbs(self._bbox_lbs_from_voxels(voxels, lbs, ubs))

    def get_bbox_from_last_layer_mask(self, data_shape: np.ndarray, force_fg: bool, class_locations: Union[dict, None],
                                      class_locations_data: Union[dict, None], overwrite_class: Union[int, Tuple[int, ...]] = None, verbose: bool = False):
        """Sample a patch bounding box like ``get_bbox``, also using extra locations for centring.

        The selected class is chosen from ``class_locations`` exactly as in ``get_bbox``. On the
        ``force_fg`` path, if the selected class also has locations in ``class_locations_data``,
        the centre voxel is drawn from the union of both location sets. ``class_locations_data``
        presumably comes from the last data channel (per the method name) -- confirm with callers.
        A None ``class_locations_data`` is treated as empty.

        Returns:
            ``(bbox_lbs, bbox_ubs)`` as in ``get_bbox``.
        """
        lbs, ubs = self._bbox_lower_bound_range(data_shape)
        if not force_fg and not self.has_ignore:
            return self._bbox_from_lbs(self._random_bbox_lbs(lbs, ubs))

        if force_fg:
            selected_class = self._select_foreground_class(class_locations, overwrite_class, verbose)
            eligible_in_data = self._eligible_classes(class_locations_data) if class_locations_data is not None else []
        else:
            selected_class = self._select_annotated_class(class_locations)
            # The original code read the data-layer candidate list here without ever defining it
            # on this path (UnboundLocalError). Like get_bbox, this path uses no data-layer merge.
            eligible_in_data = []

        voxels = class_locations[selected_class] if selected_class is not None else None
        if voxels is not None and selected_class in eligible_in_data:
            voxels = np.vstack((voxels, class_locations_data[selected_class]))
        return self._bbox_from_lbs(self._bbox_lbs_from_voxels(voxels, lbs, ubs))