From 45a4be1d11285b592fb38e71a1aa27ac14551f66 Mon Sep 17 00:00:00 2001 From: Fabian Isensee Date: Sat, 20 Apr 2024 13:39:52 +0200 Subject: [PATCH] wip --- .../resampling/default_resampling.py | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/nnunetv2/preprocessing/resampling/default_resampling.py b/nnunetv2/preprocessing/resampling/default_resampling.py index 333eb76d9..a2fb132fb 100644 --- a/nnunetv2/preprocessing/resampling/default_resampling.py +++ b/nnunetv2/preprocessing/resampling/default_resampling.py @@ -124,7 +124,7 @@ def resample_data_or_seg_to_shape(data: Union[torch.Tensor, np.ndarray], def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], List[float], np.ndarray], is_seg: bool = False, axis: Union[None, int] = None, order: int = 3, - do_separate_z: bool = False, order_z: int = 0): + do_separate_z: bool = False, order_z: int = 0, dtype_out = None): """ separate_z=True will resample with order 0 along z :param data: @@ -145,9 +145,11 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L else: resize_fn = resize kwargs = {'mode': 'edge', 'anti_aliasing': False} - dtype_data = data.dtype shape = np.array(data[0].shape) new_shape = np.array(new_shape) + if dtype_out is None: + dtype_out = data.dtype + reshaped_final = np.zeros((data.shape[0], *new_shape), dtype=dtype_out) if np.any(shape != new_shape): data = data.astype(float) if do_separate_z: @@ -161,22 +163,20 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L else: new_shape_2d = new_shape[:-1] - reshaped_final_data = [] for c in range(data.shape[0]): - reshaped_data = [] + reshaped_here = np.zeros((data.shape[1], *new_shape_2d)) for slice_id in range(shape[axis]): if axis == 0: - reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs)) + reshaped_here[slice_id] = resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs) elif axis == 1: - reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs)) + reshaped_here[slice_id] = resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs) else: - reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs)) - reshaped_data = np.stack(reshaped_data, axis) + reshaped_here[slice_id] = resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs) if shape[axis] != new_shape[axis]: # The following few lines are blatantly copied and modified from sklearn's resize() rows, cols, dim = new_shape[0], new_shape[1], new_shape[2] - orig_rows, orig_cols, orig_dim = reshaped_data.shape + orig_rows, orig_cols, orig_dim = reshaped_here.shape row_scale = float(orig_rows) / rows col_scale = float(orig_cols) / cols @@ -189,28 +189,20 @@ def resample_data_or_seg(data: np.ndarray, new_shape: Union[Tuple[float, ...], L coord_map = np.array([map_rows, map_cols, map_dims]) if not is_seg or order_z == 0: - reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z, - mode='nearest')[None]) + reshaped_final[c] = map_coordinates(reshaped_here, coord_map, order=order_z, mode='nearest')[None] else: - unique_labels = np.sort(pd.unique(reshaped_data.ravel())) # np.unique(reshaped_data) - reshaped = np.zeros(new_shape, dtype=dtype_data) - + unique_labels = np.sort(pd.unique(reshaped_here.ravel())) # np.unique(reshaped_data) for i, cl in enumerate(unique_labels): - reshaped_multihot = np.round( - map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z, - mode='nearest')) - reshaped[reshaped_multihot > 0.5] = cl - reshaped_final_data.append(reshaped[None]) + reshaped_final[c][np.round( + map_coordinates((reshaped_here == cl).astype(float), coord_map, order=order_z, + mode='nearest')) > 0.5] = cl else: - reshaped_final_data.append(reshaped_data[None]) - reshaped_final_data = np.vstack(reshaped_final_data) + reshaped_final[c] = reshaped_here else: # print("no separate z, order", order) - reshaped = [] for c in range(data.shape[0]): - reshaped.append(resize_fn(data[c], new_shape, order, **kwargs)[None]) - reshaped_final_data = np.vstack(reshaped) - return reshaped_final_data.astype(dtype_data) + reshaped_final[c] = resize_fn(data[c], new_shape, order, **kwargs) + return reshaped_final else: # print("no resampling necessary") return data