diff --git a/README.md b/README.md
index 0bd38e1..21b655f 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ pip install git+https://github.com/sehoffmann/dmlcloud.git
 ```
 
 ## Minimal Example
-See [examples/barebone_mnist.py](https://github.com/sehoffmann/dmlcloud/blob/develop/examples/barebone_mnist.py) for a minimal and barebone example on how to distributely train MNIST.
+See [examples/mnist.py](https://github.com/sehoffmann/dmlcloud/blob/develop/examples/mnist.py) for a minimal, barebone example of distributed MNIST training.
 To run it on a single node with 4 GPUs, use
 ```
-dmlrun -n 4 python examples/barebone_mnist.py
+dmlrun -n 4 python examples/mnist.py
diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py
index 04c9d7b..0651d44 100644
--- a/dmlcloud/core/callbacks.py
+++ b/dmlcloud/core/callbacks.py
@@ -5,9 +5,11 @@
 from typing import Callable, Optional, TYPE_CHECKING, Union
 
 import torch
+from omegaconf import OmegaConf
 from progress_table import ProgressTable
 
 from ..util.logging import DevNullIO, IORedirector
+from ..util.wandb import wandb_is_initialized, wandb_set_startup_timeout
 from . import logging as dml_logging
 from .distributed import is_root
 
@@ -318,17 +320,36 @@ class WandbCallback(Callback):
     A callback that logs metrics to Weights & Biases.
     """
 
-    def __init__(self):
+    def __init__(self, entity, project, group, tags, startup_timeout, **kwargs):
         try:
             import wandb
         except ImportError:
             raise ImportError('wandb is required for the WandbCallback')
 
         self.wandb = wandb
+        self.entity = entity
+        self.project = project
+        self.group = group
+        self.tags = tags
+        self.startup_timeout = startup_timeout
+        self.kwargs = kwargs
 
-    def pre_stage(self, stage: 'Stage'):
-        self.wandb.init(project='dmlcloud', config=stage.config)
+    def pre_run(self, pipe: 'Pipeline'):
+        wandb_set_startup_timeout(self.startup_timeout)
+        self.wandb.init(
+            config=OmegaConf.to_container(pipe.config, resolve=True),
+            name=pipe.name,
+            entity=self.entity,
+            project=self.project,
+            group=self.group,
+            tags=self.tags,
+            **self.kwargs,
+        )
 
     def post_epoch(self, stage: 'Stage'):
         metrics = stage.history.last()
         self.wandb.log(metrics)
+
+    def cleanup(self, pipe, exc_type, exc_value, traceback):
+        if wandb_is_initialized():
+            self.wandb.finish(exit_code=0 if exc_type is None else 1)
diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py
index e1e2cba..52f6557 100644
--- a/dmlcloud/core/pipeline.py
+++ b/dmlcloud/core/pipeline.py
@@ -7,12 +7,11 @@
 import torch.distributed as dist
 from omegaconf import OmegaConf
 
-from dmlcloud.util.wandb import wandb, wandb_is_initialized, wandb_set_startup_timeout
 from ..util.logging import experiment_header, general_diagnostics
 from . import logging as dml_logging
-from .callbacks import Callback, CheckpointCallback, CsvCallback
+from .callbacks import Callback, CheckpointCallback, CsvCallback, WandbCallback
 from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
-from .distributed import all_gather_object, broadcast_object, init, is_root, local_rank, root_only
+from .distributed import all_gather_object, broadcast_object, init, is_root, local_rank
 from .stage import Stage
 
 
@@ -76,7 +75,25 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 
 class Pipeline:
+    """
+    A training pipeline that consists of multiple stages.
+
+    This is the main entry point for training with dmlcloud. The pipeline manages the training process and
+    orchestrates the execution of multiple stages. It also provides a way to add callbacks that are executed at
+    different points during the training process.
+
+    Use the `append()` method to add stages to the pipeline and `add_callback()` to add callbacks.
+
+    Checkpointing can be enabled with `enable_checkpointing()` and Weights & Biases integration with `enable_wandb()`.
+
+    Once the pipeline is set up, call `run()` to start the training process.
+    """
+
     def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None):
+        # Auto-init torch.distributed if not already initialized
+        if not dist.is_initialized():
+            init()
+
         if config is None:
             self.config = OmegaConf.create()
         elif not isinstance(config, OmegaConf):
@@ -84,10 +101,6 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
         else:
             self.config = config
 
-        # Auto-init distributed if not already initialized
-        if not dist.is_initialized():
-            init()
-
         self.name = name
 
         self.checkpoint_dir = None
@@ -160,22 +173,15 @@ def enable_wandb(
         startup_timeout: int = 360,
         **kwargs,
     ):
-        import wandb  # import now to avoid potential long import times later on
-
-        @root_only
-        def initializer():
-            wandb_set_startup_timeout(startup_timeout)
-            wandb.init(
-                config=OmegaConf.to_container(self.config, resolve=True),
-                name=self.name,
-                entity=entity,
-                project=project if project else self.name,
-                group=group,
-                tags=tags,
-                **kwargs,
-            )
+        if self.wandb:
+            raise ValueError('Wandb already enabled')
+
+        import wandb  # import now to avoid potential long import times later on # noqa
+
+        if is_root():
+            project = project or self.name
+            self.add_callback(WandbCallback(entity=entity, project=project, group=group, tags=tags, startup_timeout=startup_timeout, **kwargs))
 
-        self._wandb_initalizer = initializer
         self.wandb = True
 
     def barrier(self, timeout=None):
@@ -231,14 +237,10 @@ def _pre_run(self):
         for stage in self.stages:
             stage.add_callback(_ForwardCallback())  # forward callbacks to pipeline callbacks
 
-        self.barrier(
-            timeout=10 * 60
-        )  # important to prevent checkpoint dir creation before all processes searched for it
-
-        if self.wandb:
-            self._wandb_initalizer()
+        # make sure everything is set up before starting the run
+        # important to prevent checkpoint dir creation before all processes searched for it
+        self.barrier(timeout=10 * 60)
 
-        self.barrier(timeout=10 * 60)  # make sure everything is set up before starting the run
         self.start_time = datetime.now()
 
         header = '\n' + experiment_header(self.name, self.checkpoint_dir, self.start_time)
@@ -272,16 +274,8 @@ def _post_run(self):
         dml_logging.info(f'Finished training in {self.stop_time - self.start_time} ({self.stop_time})')
         if self.checkpointing_enabled:
             dml_logging.info(f'Outputs have been saved to {self.checkpoint_dir}')
+        self.post_run()
 
         for callback in self.callbacks:
             callback.post_run(self)
-
-    def _cleanup(self, exc_type, exc_value, traceback):
-        """
-        Called by _RunGuard to ensure that the pipeline is properly cleaned up
-        """
-        if self.wandb and wandb_is_initialized():
-            wandb.finish(exit_code=0 if exc_type is None else 1)
-
-        return False
 
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000..ac11d92
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,9 @@
+# Examples
+
+This directory contains multiple examples that demonstrate the usage of dmlcloud and its features.
+The `mnist` example is a good starting point for beginners. It demonstrates how to train a simple neural network
+on the MNIST dataset using dmlcloud.
+
+| Example | Description |
+| --- | --- |
+| [mnist.py](mnist.py) | Minimal example that demonstrates how to train a simple neural network on the MNIST dataset using dmlcloud. |
\ No newline at end of file
diff --git a/examples/barebone_mnist.py b/examples/mnist.py
similarity index 99%
rename from examples/barebone_mnist.py
rename to examples/mnist.py
index bbb19b8..d9c0ba1 100644
--- a/examples/barebone_mnist.py
+++ b/examples/mnist.py
@@ -102,6 +102,7 @@ def _val_epoch(self):
 def main():
     pipe = dml.Pipeline()
     pipe.enable_checkpointing('checkpoints')
+    pipe.enable_wandb()
     pipe.append(MNISTStage(epochs=3))
     pipe.run()
 
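
For reviewers, a minimal sketch of how the reworked wandb integration is driven after this change. It mirrors `main()` from examples/mnist.py; the config dict and the `entity`/`tags` arguments are illustrative placeholders, and `MNISTStage` is the stage class defined in that example (not repeated here):

```python
import dmlcloud as dml


def main():
    # Pipeline.__init__ now auto-initializes torch.distributed before touching the config
    pipe = dml.Pipeline(config={'lr': 1e-3}, name='mnist')  # config dict is illustrative
    pipe.enable_checkpointing('checkpoints')

    # enable_wandb() no longer builds a deferred initializer; on the root rank it
    # registers a WandbCallback whose hooks drive the wandb lifecycle:
    #   pre_run()    -> wandb.init() with the resolved pipeline config
    #   post_epoch() -> wandb.log(stage.history.last())
    #   cleanup()    -> wandb.finish() with an exit code based on the exception state
    # project is omitted here, so it falls back to the pipeline name.
    pipe.enable_wandb(entity='my-team', tags=['example'])  # placeholder entity/tags

    pipe.append(MNISTStage(epochs=3))  # MNISTStage as defined in examples/mnist.py
    pipe.run()


if __name__ == '__main__':
    main()
```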