Skip to content

Commit

Permalink
feat: wait for all workers before starting stage
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Apr 2, 2024
1 parent a9d9615 commit 04a8a4f
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions dmlcloud/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from progress_table import ProgressTable

from .metrics import MetricTracker, Reduction
Expand Down Expand Up @@ -127,6 +128,7 @@ def run(self):
Runs this stage. Either until max_epochs are reached, or until stop_stage() is called.
"""
self._pre_stage()
dist.barrier()
while self.max_epochs is None or self.current_epoch <= self.max_epochs:
self._pre_epoch()
self.run_epoch()
Expand Down

0 comments on commit 04a8a4f

Please sign in to comment.