fix: duplicate checkpoints and wandb runs on multigpu training
sehoffmann committed Apr 2, 2024
1 parent a25dfb4 commit b472055
Showing 1 changed file with 14 additions and 5 deletions.
--- a/dmlcloud/pipeline.py
+++ b/dmlcloud/pipeline.py
@@ -12,7 +12,7 @@
 from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
 from .metrics import MetricTracker, Reduction
 from .stage import Stage
-from .util.distributed import local_rank
+from .util.distributed import local_rank, root_only
 from .util.logging import add_log_handlers, experiment_header, general_diagnostics, IORedirector
@@ -110,7 +110,7 @@ def append_stage(self, stage: Stage, max_epochs: Optional[int] = None, name: Opt
     def enable_checkpointing(
         self,
         root: str,
-        resume: bool = True,
+        resume: bool = False,
     ):
         if self.checkpointing_enabled:
             raise ValueError('Checkpointing already enabled')
@@ -122,11 +122,16 @@ def enable_checkpointing(
         elif resume and find_slurm_checkpoint(root):
             path = find_slurm_checkpoint(root)
             self.resumed = True
-        if path is None:
-            path = generate_checkpoint_path(root=root, name=self.name, creation_time=self.start_time)
+
+        if path is None: # no need for a barrier here, dir creation happens in _pre_run()
+            obj_list = [generate_checkpoint_path(root=root, name=self.name, creation_time=self.start_time)]
+            dist.broadcast_object_list(obj_list)
+            path = obj_list[0]
             self.resumed = False

         self.checkpoint_dir = CheckpointDir(path)

+    @root_only
     def enable_wandb(
         self,
         project: str | None = None,
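
Why the broadcast: generate_checkpoint_path derives a new path on every call (it takes a creation_time), so under multi-GPU training each rank ended up creating its own checkpoint directory. Broadcasting the list makes rank 0's path authoritative on all ranks. A minimal, self-contained sketch of that pattern, assuming a torchrun launch with at least two processes; make_unique_path is an illustrative helper, not dmlcloud API:

import os
from datetime import datetime

import torch.distributed as dist


def make_unique_path(root: str) -> str:
    # Each rank computes its own timestamp -- the source of the old duplicates.
    return os.path.join(root, datetime.now().strftime('run_%Y%m%d_%H%M%S_%f'))


def main():
    dist.init_process_group('gloo')

    # Every rank builds a candidate path, but broadcast_object_list overwrites
    # the entries on non-source ranks with rank 0's values (src defaults to 0).
    obj_list = [make_unique_path('/tmp/checkpoints')]
    dist.broadcast_object_list(obj_list)
    path = obj_list[0]  # identical on all ranks from here on

    print(f'rank {dist.get_rank()}: {path}')
    dist.destroy_process_group()


if __name__ == '__main__':
    main()
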
@@ -140,13 +145,14 @@

         self.wandb = True

+        @root_only
         def initializer():
             wandb_set_startup_timeout(startup_timeout)
             wandb.init(
                 config=OmegaConf.to_container(self.config, resolve=True),
                 name=self.name,
                 entity=entity,
-                project=project,
+                project=project if project else self.name,
                 group=group,
                 tags=tags,
                 **kwargs,
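
Why root_only on the initializer: previously every rank called wandb.init, so a job with N GPUs produced N wandb runs. The root_only decorator is imported from dmlcloud.util.distributed and its implementation is not part of this diff; a minimal sketch of the general pattern, assuming the process group is already initialized, could look like this:

import functools

import torch.distributed as dist


def root_only(fn):
    """Execute fn on the root rank only; all other ranks skip the call."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if dist.get_rank() == 0:
            return fn(*args, **kwargs)
        return None

    return wrapper


@root_only
def announce(msg: str) -> None:
    print(msg)  # runs once on rank 0 instead of once per GPU
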
@@ -215,12 +221,14 @@ def _pre_run(self):
         else:
             self.device = torch.device('cpu')

+        dist.barrier() # important to prevent checkpoint dir creation before all processes searched for it
         if self.checkpointing_enabled:
             self._init_checkpointing()

         if self.wandb:
             self._wandb_initalizer()

+        dist.barrier() # make sure everything is set up before starting the run
         self.start_time = datetime.now()

         add_log_handlers(self.logger)
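
The two barriers order the setup across ranks: the first keeps the now root-only _init_checkpointing from creating the checkpoint directory while other ranks may still be probing the filesystem for an existing checkpoint, and the second keeps non-root ranks from starting the run before root has finished directory creation and wandb setup. A sketch of this root-creates, everyone-waits pattern; setup_shared_dir is an illustrative name, not dmlcloud API:

import os

import torch.distributed as dist


def setup_shared_dir(path: str) -> None:
    dist.barrier()  # wait until every rank has finished inspecting old state
    if dist.get_rank() == 0:
        os.makedirs(path, exist_ok=True)  # only root touches the filesystem
    dist.barrier()  # non-root ranks proceed only once the directory exists
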
@@ -237,6 +245,7 @@ def _pre_run(self):

         self.pre_run()

+    @root_only
     def _init_checkpointing(self):
         if not self.checkpoint_dir.is_valid:
             self.checkpoint_dir.create()
