diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
index 426bbf047..8a091689b 100644
--- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
+++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -647,6 +647,9 @@ def get_dataloaders(self):
                                            transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
                                            num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
                                            wait_time=0.02)
+        # let's get this party started
+        _ = next(mt_gen_train)
+        _ = next(mt_gen_val)
         return mt_gen_train, mt_gen_val
 
     def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
@@ -819,6 +822,10 @@ def set_deep_supervision_enabled(self, enabled: bool):
                 mod.decoder.deep_supervision = enabled
 
     def on_train_start(self):
+        # dataloaders must be instantiated here (instead of __init__) because they need access to the training data
+        # which may not be present when doing inference
+        self.dataloader_train, self.dataloader_val = self.get_dataloaders()
+
         if not self.was_initialized:
             self.initialize()
 
@@ -840,10 +847,6 @@ def on_train_start(self):
         if self.is_ddp:
             dist.barrier()
 
-        # dataloaders must be instantiated here because they need access to the training data which may not be present
-        # when doing inference
-        self.dataloader_train, self.dataloader_val = self.get_dataloaders()
-
         # copy plans and dataset.json so that they can be used for restoring everything we need for inference
         save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
         save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
@@ -1284,15 +1287,19 @@ def run_training(self):
             self.on_train_epoch_start()
             train_outputs = []
+            st = time()
             for batch_id in range(self.num_iterations_per_epoch):
                 train_outputs.append(self.train_step(next(self.dataloader_train)))
+            print('train time', time() - st)
             self.on_train_epoch_end(train_outputs)
 
             with torch.no_grad():
                 self.on_validation_epoch_start()
                 val_outputs = []
+                st = time()
                 for batch_id in range(self.num_val_iterations_per_epoch):
                     val_outputs.append(self.validation_step(next(self.dataloader_val)))
+                print('val time', time() - st)
                 self.on_validation_epoch_end(val_outputs)
 
             self.on_epoch_end()
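
For context on why the warm-up lines help: batchgenerators-style multi-threaded augmenters spin up their background workers lazily, so the first next() call is the one that pays the pipeline start-up cost. Pulling one batch from each generator at creation time, combined with moving get_dataloaders() to the top of on_train_start, lets the workers start and fill their caches while the rest of the setup (initialization, the DDP barrier, saving plans.json and dataset.json) runs. Below is a minimal, self-contained sketch of the same pattern; BackgroundBatchGenerator and slow_batch are invented for illustration, and a thread stands in for the augmenter's worker processes. This is not nnU-Net code.

import queue
import threading
import time


class BackgroundBatchGenerator:
    # Toy stand-in for a multi-threaded augmenter (hypothetical, not nnU-Net's API):
    # a daemon thread fills a bounded queue with batches; consumers pull via next().
    def __init__(self, produce_batch, num_cached=3):
        self._queue = queue.Queue(maxsize=num_cached)
        self._produce_batch = produce_batch
        threading.Thread(target=self._fill, daemon=True).start()

    def _fill(self):
        while True:
            self._queue.put(self._produce_batch())  # blocks while the cache is full

    def __next__(self):
        return self._queue.get()


def slow_batch():
    time.sleep(0.5)  # simulate expensive loading + augmentation
    return 'batch'


gen = BackgroundBatchGenerator(slow_batch)
_ = next(gen)    # warm-up: this call pays the pipeline start-up cost
time.sleep(2.0)  # stands in for the rest of on_train_start()
t0 = time.time()
_ = next(gen)    # served from the cache, returns almost immediately
print(f'first training batch after {time.time() - t0:.3f}s')

Without the warm-up, the start-up cost would instead land inside the first timed epoch, which is exactly the kind of overhead the new 'train time' / 'val time' prints in run_training are meant to surface.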