diff --git a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py index e87ff8f92..be31857b3 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py +++ b/nnunetv2/training/nnUNetTrainer/variants/data_augmentation/nnUNetTrainerDAOrd0.py @@ -102,3 +102,56 @@ def get_dataloaders(self): max(1, allowed_num_processes // 2), 3, None, True, 0.02) return mt_gen_train, mt_gen_val + + +class nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer): + def get_dataloaders(self): + """ + changed order_resampling_data, order_resampling_seg + """ + # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether + # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be + patch_size = self.configuration_manager.patch_size + dim = len(patch_size) + + # needed for deep supervision: how much do we need to downscale the segmentation targets for the different + # outputs? + deep_supervision_scales = self._get_deep_supervision_scales() + + rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \ + self.configure_rotation_dummyDA_mirroring_and_inital_patch_size() + + # Deactivate mirroring data augmentation + mirror_axes = None + self.inference_allowed_mirroring_axes = None + + # training pipeline + tr_transforms = self.get_training_transforms( + patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug, + order_resampling_data=3, order_resampling_seg=0, + use_mask_for_norm=self.configuration_manager.use_mask_for_norm, + is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + # validation pipeline + val_transforms = self.get_validation_transforms(deep_supervision_scales, + is_cascaded=self.is_cascaded, + foreground_labels=self.label_manager.all_labels, + regions=self.label_manager.foreground_regions if + self.label_manager.has_regions else None, + ignore_label=self.label_manager.ignore_label) + + dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim) + + allowed_num_processes = get_allowed_n_proc_DA() + if allowed_num_processes == 0: + mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms) + mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms) + else: + mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms, + allowed_num_processes, 6, None, True, 0.02) + mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms, + max(1, allowed_num_processes // 2), 3, None, True, 0.02) + + return mt_gen_train, mt_gen_val