diff --git a/dmlcloud/stage.py b/dmlcloud/stage.py index ff4a7f8..6ace04d 100644 --- a/dmlcloud/stage.py +++ b/dmlcloud/stage.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Union import torch +import torch.distributed as dist from progress_table import ProgressTable from .metrics import MetricTracker, Reduction @@ -127,6 +128,7 @@ def run(self): Runs this stage. Either until max_epochs are reached, or until stop_stage() is called. """ self._pre_stage() + dist.barrier() while self.max_epochs is None or self.current_epoch <= self.max_epochs: self._pre_epoch() self.run_epoch()