From 4667cf0c44efff81909e416c429a3d96f38927e7 Mon Sep 17 00:00:00 2001
From: Fabian Isensee <f.isensee@dkfz.de>
Date: Tue, 9 Apr 2024 20:28:34 +0200
Subject: [PATCH] let dataloaders start working immediately, initialize them as
 soon as possible

---
 nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
index 426bbf047..8a091689b 100644
--- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
+++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -647,6 +647,9 @@ def get_dataloaders(self):
                                            transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
                                            num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
                                            wait_time=0.02)
+        # let's get this party started: pulling one batch up front spins up the background worker processes
+        _ = next(mt_gen_train)
+        _ = next(mt_gen_val)
         return mt_gen_train, mt_gen_val
 
     def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
@@ -819,6 +822,10 @@ def set_deep_supervision_enabled(self, enabled: bool):
         mod.decoder.deep_supervision = enabled
 
     def on_train_start(self):
+        # dataloaders must be instantiated here (instead of __init__) because they need access to the training data
+        # which may not be present when doing inference
+        self.dataloader_train, self.dataloader_val = self.get_dataloaders()
+
         if not self.was_initialized:
             self.initialize()
 
@@ -840,10 +847,6 @@ def on_train_start(self):
         if self.is_ddp:
             dist.barrier()
 
-        # dataloaders must be instantiated here because they need access to the training data which may not be present
-        # when doing inference
-        self.dataloader_train, self.dataloader_val = self.get_dataloaders()
-
         # copy plans and dataset.json so that they can be used for restoring everything we need for inference
         save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
         save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
@@ -1284,15 +1287,19 @@ def run_training(self):
 
             self.on_train_epoch_start()
             train_outputs = []
+            st = time()
             for batch_id in range(self.num_iterations_per_epoch):
                 train_outputs.append(self.train_step(next(self.dataloader_train)))
+            print('train time', time() - st)
             self.on_train_epoch_end(train_outputs)
 
             with torch.no_grad():
                 self.on_validation_epoch_start()
                 val_outputs = []
+                st = time()
                 for batch_id in range(self.num_val_iterations_per_epoch):
                     val_outputs.append(self.validation_step(next(self.dataloader_val)))
+                print('val time', time() - st)
                 self.on_validation_epoch_end(val_outputs)
 
             self.on_epoch_end()