From 720d8088e3af5afb4051ee8a39693f87b1edc7f4 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Mon, 18 Mar 2024 16:09:24 +0100 Subject: [PATCH] chore: cfg -> config and expose in Stage --- dmlcloud/pipeline.py | 18 +++++++++--------- dmlcloud/stage.py | 4 ++++ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/dmlcloud/pipeline.py b/dmlcloud/pipeline.py index 34b2ebd..9a0d0e2 100644 --- a/dmlcloud/pipeline.py +++ b/dmlcloud/pipeline.py @@ -17,13 +17,13 @@ class TrainingPipeline: - def __init__(self, cfg: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None): - if cfg is None: - self.cfg = OmegaConf.create() - elif not isinstance(cfg, OmegaConf): - self.cfg = OmegaConf.create(cfg) + def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None): + if config is None: + self.config = OmegaConf.create() + elif not isinstance(config, OmegaConf): + self.config = OmegaConf.create(config) else: - self.cfg = cfg + self.config = config self.name = name @@ -138,7 +138,7 @@ def enable_wandb( def initializer(): wandb_set_startup_timeout(startup_timeout) wandb.init( - config=OmegaConf.to_container(self.cfg, resolve=True), + config=OmegaConf.to_container(self.config, resolve=True), name=self.name, entity=entity, project=project, @@ -226,7 +226,7 @@ def _pre_run(self): self._resume_run() diagnostics = general_diagnostics() - diagnostics += '\n* CONFIG:\n' + OmegaConf.to_yaml(self.cfg) + diagnostics += '\n* CONFIG:\n' + OmegaConf.to_yaml(self.config) self.logger.info(diagnostics) self.pre_run() @@ -234,7 +234,7 @@ def _pre_run(self): def _init_checkpointing(self): if not self.checkpoint_dir.is_valid: self.checkpoint_dir.create() - self.checkpoint_dir.save_config(self.cfg) + self.checkpoint_dir.save_config(self.config) self.io_redirector = IORedirector(self.checkpoint_dir.log_file) self.io_redirector.install() diff --git a/dmlcloud/stage.py b/dmlcloud/stage.py index 3fa9d4f..8690122 100644 --- a/dmlcloud/stage.py +++ b/dmlcloud/stage.py @@ -45,6 +45,10 @@ def logger(self): def device(self): return self.pipeline.device + @property + def config(self): + return self.pipeline.config + def track_reduce( self, name: str,