Skip to content

Commit

Permalink
chore: cfg -> config and expose in Stage
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Mar 18, 2024
1 parent 037db39 commit 720d808
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
18 changes: 9 additions & 9 deletions dmlcloud/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@


class TrainingPipeline:
def __init__(self, cfg: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None):
if cfg is None:
self.cfg = OmegaConf.create()
elif not isinstance(cfg, OmegaConf):
self.cfg = OmegaConf.create(cfg)
def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None):
if config is None:
self.config = OmegaConf.create()
elif not isinstance(config, OmegaConf):
self.config = OmegaConf.create(config)
else:
self.cfg = cfg
self.config = config

self.name = name

Expand Down Expand Up @@ -138,7 +138,7 @@ def enable_wandb(
def initializer():
wandb_set_startup_timeout(startup_timeout)
wandb.init(
config=OmegaConf.to_container(self.cfg, resolve=True),
config=OmegaConf.to_container(self.config, resolve=True),
name=self.name,
entity=entity,
project=project,
Expand Down Expand Up @@ -226,15 +226,15 @@ def _pre_run(self):
self._resume_run()

diagnostics = general_diagnostics()
diagnostics += '\n* CONFIG:\n' + OmegaConf.to_yaml(self.cfg)
diagnostics += '\n* CONFIG:\n' + OmegaConf.to_yaml(self.config)
self.logger.info(diagnostics)

self.pre_run()

def _init_checkpointing(self):
if not self.checkpoint_dir.is_valid:
self.checkpoint_dir.create()
self.checkpoint_dir.save_config(self.cfg)
self.checkpoint_dir.save_config(self.config)
self.io_redirector = IORedirector(self.checkpoint_dir.log_file)
self.io_redirector.install()

Expand Down
4 changes: 4 additions & 0 deletions dmlcloud/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def logger(self):
def device(self):
return self.pipeline.device

@property
def config(self):
return self.pipeline.config

def track_reduce(
self,
name: str,
Expand Down

0 comments on commit 720d808

Please sign in to comment.