Skip to content

Commit

Permalink
fix: missing last epoch in wandb, closes #44
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Jan 22, 2025
1 parent e5e0d9b commit e583b26
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
35 changes: 26 additions & 9 deletions dmlcloud/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
'ReduceMetricsCallback',
'CheckpointCallback',
'CsvCallback',
'WandbCallback',
'WandbInitCallback',
'WandbLoggerCallback',
'TensorboardCallback',
'CudaCallback',
]
Expand Down Expand Up @@ -91,7 +92,7 @@ class CbPriority(IntEnum):
Default priorities for callbacks used by the pipeline and stage classes.
"""

WANDB = -200
WANDB_INIT = -200
CHECKPOINT = -190
STAGE_TIMER = -180
DIAGNOSTICS = -170
Expand All @@ -101,6 +102,7 @@ class CbPriority(IntEnum):

OBJECT_METHODS = 0

WANDB_LOGGER = 110
CSV = 110
TENSORBOARD = 110
TABLE = 120
Expand Down Expand Up @@ -390,16 +392,17 @@ def post_epoch(self, stage: 'Stage'):
writer.writerow(row)


class WandbCallback(Callback):
class WandbInitCallback(Callback):
"""
A callback that logs metrics to Weights & Biases.
A callback that initializes Weights & Biases and closes it at the end.
This is separated from the WandbLoggerCallback to ensure it is called right at the beginning of training.
"""

def __init__(self, project, entity, group, tags, startup_timeout, **kwargs):
try:
import wandb
except ImportError:
raise ImportError('wandb is required for the WandbCallback')
raise ImportError('wandb is required for the WandbInitCallback')

self.wandb = wandb
self.project = project
Expand All @@ -421,15 +424,29 @@ def pre_run(self, pipe: 'Pipeline'):
**self.kwargs,
)

def post_epoch(self, stage: 'Stage'):
metrics = stage.history.last()
self.wandb.log(metrics)

def cleanup(self, pipe, exc_type, exc_value, traceback):
if wandb_is_initialized():
self.wandb.finish(exit_code=0 if exc_type is None else 1)


class WandbLoggerCallback(Callback):
    """
    A callback that logs metrics to Weights & Biases.

    After every epoch, the most recent entry of the stage's history is passed to
    ``wandb.log`` using the stage's current epoch as the step.

    Raises:
        ImportError: If the ``wandb`` package cannot be imported.
    """

    def __init__(self):
        try:
            import wandb
        except ImportError as err:
            # Chain the original error so the underlying cause stays visible in the traceback.
            raise ImportError('wandb is required for the WandbLoggerCallback') from err

        # Keep a reference to the module imported above for use in post_epoch.
        self.wandb = wandb

    def post_epoch(self, stage: 'Stage'):
        """
        Log the latest metrics of ``stage`` to Weights & Biases.

        Args:
            stage: The stage whose ``history.last()`` entry and ``current_epoch`` are logged.
        """
        metrics = stage.history.last()
        # NOTE(review): wandb expects ``step`` to never decrease within a run -- confirm that
        # ``current_epoch`` does not restart between stages, otherwise entries may be dropped.
        self.wandb.log(metrics, commit=True, step=stage.current_epoch)


class TensorboardCallback(Callback):
"""
A callback that logs metrics to Tensorboard.
Expand Down
8 changes: 5 additions & 3 deletions dmlcloud/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
DiagnosticsCallback,
GitDiffCallback,
TensorboardCallback,
WandbCallback,
WandbInitCallback,
WandbLoggerCallback,
)
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
from .distributed import broadcast_object, init, is_root, local_rank
Expand Down Expand Up @@ -197,15 +198,16 @@ def enable_wandb(
import wandb # import now to avoid potential long import times later on # noqa

if is_root():
callback = WandbCallback(
init_callback = WandbInitCallback(
project=project,
entity=entity,
group=group,
tags=tags,
startup_timeout=startup_timeout,
**kwargs,
)
self.add_callback(callback, CbPriority.WANDB)
self.add_callback(init_callback, CbPriority.WANDB_INIT)
self.add_callback(WandbLoggerCallback(), CbPriority.WANDB_LOGGER)

self.wandb = True

Expand Down

0 comments on commit e583b26

Please sign in to comment.