fix: wandb support
sehoffmann committed Jan 3, 2025
1 parent 60faff3 commit 80f5e26
Showing 5 changed files with 68 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -26,7 +26,7 @@ pip install git+https://github.com/sehoffmann/dmlcloud.git
```

## Minimal Example
See [examples/barebone_mnist.py](https://github.com/sehoffmann/dmlcloud/blob/develop/examples/barebone_mnist.py) for a minimal and barebone example on how to distributely train MNIST.
See [examples/mnist.py](https://github.com/sehoffmann/dmlcloud/blob/develop/examples/mnist.py) for a minimal and barebone example on how to distributely train MNIST.
To run it on a single node with 4 GPUs, use
```
dmlrun -n 4 python examples/barebone_mnist.py
27 changes: 24 additions & 3 deletions dmlcloud/core/callbacks.py
@@ -5,9 +5,11 @@
from typing import Callable, Optional, TYPE_CHECKING, Union

import torch
from omegaconf import OmegaConf
from progress_table import ProgressTable

from ..util.logging import DevNullIO, IORedirector
from ..util.wandb import wandb_is_initialized, wandb_set_startup_timeout
from . import logging as dml_logging
from .distributed import is_root

@@ -318,17 +320,36 @@ class WandbCallback(Callback):
A callback that logs metrics to Weights & Biases.
"""

def __init__(self):
def __init__(self, entity, project, group, tags, startup_timeout, **kwargs):
try:
import wandb
except ImportError:
raise ImportError('wandb is required for the WandbCallback')

self.wandb = wandb
self.entity = entity
self.project = project
self.group = group
self.tags = tags
self.startup_timeout = startup_timeout
self.kwargs = kwargs

def pre_stage(self, stage: 'Stage'):
self.wandb.init(project='dmlcloud', config=stage.config)
def pre_run(self, pipe: 'Pipeline'):
wandb_set_startup_timeout(self.startup_timeout)
self.wandb.init(
config=OmegaConf.to_container(pipe.config, resolve=True),
name=pipe.name,
entity=self.entity,
project=self.project,
group=self.group,
tags=self.tags,
**self.kwargs,
)

def post_epoch(self, stage: 'Stage'):
metrics = stage.history.last()
self.wandb.log(metrics)

def cleanup(self, pipe, exc_type, exc_value, traceback):
if wandb_is_initialized():
self.wandb.finish(exit_code=0 if exc_type is None else 1)
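For orientation, a minimal sketch of how the reworked callback could be attached by hand instead of going through `Pipeline.enable_wandb()`; the entity, project, and tag values are made-up placeholders, and the `dml` import alias is assumed from the mnist example below.

```python
# Hedged sketch, not part of this commit: wiring the new WandbCallback manually.
# 'my-team', 'mnist-demo', and the tags are illustrative placeholders.
import dmlcloud as dml
from dmlcloud.core.callbacks import WandbCallback

pipe = dml.Pipeline(config={'lr': 1e-3}, name='mnist-demo')
pipe.add_callback(
    WandbCallback(
        entity='my-team',      # W&B entity (placeholder)
        project='mnist-demo',  # W&B project (placeholder)
        group=None,            # optional run group
        tags=['baseline'],     # optional run tags
        startup_timeout=360,   # same default that enable_wandb() passes along
    )
)
```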
70 changes: 32 additions & 38 deletions dmlcloud/core/pipeline.py
@@ -7,12 +7,11 @@
import torch.distributed as dist
from omegaconf import OmegaConf

from dmlcloud.util.wandb import wandb, wandb_is_initialized, wandb_set_startup_timeout
from ..util.logging import experiment_header, general_diagnostics
from . import logging as dml_logging
from .callbacks import Callback, CheckpointCallback, CsvCallback
from .callbacks import Callback, CheckpointCallback, CsvCallback, WandbCallback
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
from .distributed import all_gather_object, broadcast_object, init, is_root, local_rank, root_only
from .distributed import all_gather_object, broadcast_object, init, is_root, local_rank
from .stage import Stage


@@ -76,18 +75,32 @@ def __exit__(self, exc_type, exc_value, traceback):


class Pipeline:
"""
A training pipeline that consists of multiple stages.
This is the main entry point for training with dmlcloud. The pipeline manages the training process and
orchestrates the execution of multiple stages. It also provides a way to add callbacks that are executed at
different points during the training process.
Use the `append()` method to add stages to the pipeline and `add_callback()` to add callbacks.
Checkpointing can be enabled with `enable_checkpointing()` and Weights & Biases integration with `enable_wandb()`.
Once the pipeline is set up, call `run()` to start the training process.
"""

def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None):
# Auto-init torch.distributed if not already initialized
if not dist.is_initialized():
init()

if config is None:
self.config = OmegaConf.create()
elif not isinstance(config, OmegaConf):
self.config = OmegaConf.create(config)
else:
self.config = config

# Auto-init distributed if not already initialized
if not dist.is_initialized():
init()

self.name = name

self.checkpoint_dir = None
@@ -160,22 +173,15 @@ def enable_wandb(
startup_timeout: int = 360,
**kwargs,
):
import wandb # import now to avoid potential long import times later on

@root_only
def initializer():
wandb_set_startup_timeout(startup_timeout)
wandb.init(
config=OmegaConf.to_container(self.config, resolve=True),
name=self.name,
entity=entity,
project=project if project else self.name,
group=group,
tags=tags,
**kwargs,
)
if self.wandb:
raise ValueError('Wandb already enabled')

import wandb # import now to avoid potential long import times later on # noqa

if is_root():
project = project or self.name
self.add_callback(WandbCallback(entity, project, group, tags, startup_timeout, **kwargs))

self._wandb_initalizer = initializer
self.wandb = True

def barrier(self, timeout=None):
@@ -231,14 +237,10 @@ def _pre_run(self):
for stage in self.stages:
stage.add_callback(_ForwardCallback()) # forward callbacks to pipeline callbacks

self.barrier(
timeout=10 * 60
) # important to prevent checkpoint dir creation before all processes searched for it

if self.wandb:
self._wandb_initalizer()
# make sure everything is set up before starting the run
# important to prevent checkpoint dir creation before all processes searched for it
self.barrier(timeout=10 * 60)

self.barrier(timeout=10 * 60) # make sure everything is set up before starting the run
self.start_time = datetime.now()

header = '\n' + experiment_header(self.name, self.checkpoint_dir, self.start_time)
@@ -272,16 +274,8 @@ def _post_run(self):
dml_logging.info(f'Finished training in {self.stop_time - self.start_time} ({self.stop_time})')
if self.checkpointing_enabled:
dml_logging.info(f'Outputs have been saved to {self.checkpoint_dir}')

self.post_run()

for callback in self.callbacks:
callback.post_run(self)

def _cleanup(self, exc_type, exc_value, traceback):
"""
Called by _RunGuard to ensure that the pipeline is properly cleaned up
"""
if self.wandb and wandb_is_initialized():
wandb.finish(exit_code=0 if exc_type is None else 1)

return False
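Putting the new docstring and the reworked `enable_wandb()` together, the intended flow looks roughly like the sketch below. It assumes `enable_wandb()` keeps its entity/project/group/tags keyword arguments; the concrete W&B values are placeholders and `MNISTStage` is the stage defined in `examples/mnist.py` (see the diff further down).

```python
# Hedged sketch of the Pipeline workflow after this commit; W&B values are
# placeholders and the MNISTStage import path is assumed for illustration.
import dmlcloud as dml
from examples.mnist import MNISTStage  # assumed path, mirrors the example below

pipe = dml.Pipeline(config={'epochs': 3}, name='mnist')
pipe.enable_checkpointing('checkpoints')
pipe.enable_wandb(entity='my-team', project='mnist', tags=['dev'])  # registers a WandbCallback on the root rank only
pipe.append(MNISTStage(epochs=3))
pipe.run()  # callbacks fire at pre_run, post_epoch, and cleanup
```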
10 changes: 10 additions & 0 deletions examples/README.me
@@ -0,0 +1,10 @@
# Examples

This directory contains multiple examples that demonstrate the usage of dmlcloud and its features.
The `mnist` example is a good starting point for beginners. It demonstrates how to train a simple neural network
on the MNIST dataset using dmlcloud.

| Example | Description |
| --- | --- |
| [mnist.py](mnist.py) | Minimal example that demonstrates how to train a simple neural network on the MNIST dataset using dmlcloud. |
| - | - |
1 change: 1 addition & 0 deletions examples/barebone_mnist.py → examples/mnist.py
@@ -102,6 +102,7 @@ def _val_epoch(self):
def main():
pipe = dml.Pipeline()
pipe.enable_checkpointing('checkpoints')
pipe.enable_wandb()
pipe.append(MNISTStage(epochs=3))
pipe.run()

