Skip to content

Commit

Permalink
feat: callback priorities
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Jan 5, 2025
1 parent 0117c38 commit 31948bb
Show file tree
Hide file tree
Showing 5 changed files with 358 additions and 77 deletions.
68 changes: 64 additions & 4 deletions dmlcloud/core/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import csv
import os
import sys
from datetime import datetime, timedelta
from enum import IntEnum
from pathlib import Path
from typing import Callable, Optional, TYPE_CHECKING, Union

import torch
from omegaconf import OmegaConf
from progress_table import ProgressTable

from ..util.logging import DevNullIO, general_diagnostics, IORedirector
from ..util.logging import DevNullIO, experiment_header, general_diagnostics, IORedirector
from ..util.wandb import wandb_is_initialized, wandb_set_startup_timeout
from . import logging as dml_logging
from .distributed import all_gather_object, is_root
Expand All @@ -20,6 +22,8 @@

__all__ = [
'TimedeltaFormatter',
'CallbackList',
'CbPriority',
'Callback',
'TimerCallback',
'TableCallback',
Expand All @@ -45,6 +49,54 @@ def __call__(self, value: torch.Tensor) -> str:
return str(delta)


class CallbackList:
    """
    A priority-ordered collection of callbacks.

    Every callback is stored together with an integer priority. Iterating over the
    list yields the callbacks in ascending priority order (lower values first).
    Callbacks that share the same priority are yielded in the order in which they
    were appended.
    """

    def __init__(self):
        # (priority, callback) pairs in insertion order; ordering is applied lazily in __iter__.
        self.callbacks = []

    def append(self, callback: 'Callback', priority: int = 0):
        """
        Append a callback to the list with the given priority.

        Args:
            callback (Callback): The callback to append.
            priority (int, optional): The priority of the callback. Defaults to 0.
        """
        self.callbacks.append((priority, callback))

    def __iter__(self):
        """
        Yield the stored callbacks sorted by ascending priority.
        """
        # Sort on the priority only: callbacks themselves are never compared, and because
        # sorted() is stable, entries with equal priority keep their insertion order.
        for _, callback in sorted(self.callbacks, key=lambda entry: entry[0]):
            yield callback

    def __len__(self):
        """
        Return the number of stored callbacks.
        """
        return len(self.callbacks)

    def __add__(self, other: 'CallbackList'):
        """
        Return a new CallbackList holding the entries of this list followed by those of ``other``.

        Neither operand is modified. On equal priority, callbacks of the left operand
        are iterated before callbacks of the right operand.
        """
        if not isinstance(other, CallbackList):
            # Follow the binary-operator protocol so that Python raises a regular TypeError
            # for unsupported operands instead of an AttributeError on `other.callbacks`.
            return NotImplemented

        result = CallbackList()
        result.callbacks = self.callbacks + other.callbacks
        return result


class CbPriority(IntEnum):
    """
    Default priorities for callbacks used by the pipeline and stage classes.

    Callbacks are ordered by ascending priority, so members with lower values
    are invoked before members with higher values.
    """

    # Negative priorities: invoked before the methods of the pipeline/stage objects.
    WANDB = -200
    CHECKPOINT = -190
    STAGE_TIMER = -180
    DIAGNOSTICS = -170
    METRIC_REDUCTION = -160

    # Priority of the pipeline/stage object methods themselves, e.g. pre_run().
    OBJECT_METHODS = 0

    # Positive priorities: invoked after the methods of the pipeline/stage objects.
    CSV = 110
    TABLE = 120


class Callback:
"""
A callback that can be registered to a stage or the whole pipeline to receive updates on the training progress.
Expand Down Expand Up @@ -131,9 +183,6 @@ def post_epoch(self, stage: 'Stage'):
eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1)
stage.log('misc/eta', eta.total_seconds(), prefixed=False)

if len(stage.pipe.stages) > 1:
dml_logging.info(f'Finished stage in {stage.end_time - stage.start_time}')


class TableCallback(Callback):
"""
Expand Down Expand Up @@ -255,6 +304,10 @@ def pre_run(self, pipe: 'Pipeline'):
self.io_redirector = IORedirector(pipe.checkpoint_dir.log_file)
self.io_redirector.install()

with open(pipe.checkpoint_dir.path / "environment.txt", 'w') as f:
for k, v in os.environ.items():
f.write(f"{k}={v}\n")

def cleanup(self, pipe, exc_type, exc_value, traceback):
if self.io_redirector is not None:
self.io_redirector.uninstall()
Expand Down Expand Up @@ -361,6 +414,9 @@ class DiagnosticsCallback(Callback):
"""

def pre_run(self, pipe):
header = '\n' + experiment_header(pipe.name, pipe.checkpoint_dir, pipe.start_time)
dml_logging.info(header)

diagnostics = general_diagnostics()

diagnostics += '\n* DEVICES:\n'
Expand All @@ -372,6 +428,10 @@ def pre_run(self, pipe):

dml_logging.info(diagnostics)

def post_stage(self, stage):
    """
    Log how long the finished stage took.

    Nothing is logged when the pipeline consists of at most one stage.
    """
    if len(stage.pipe.stages) <= 1:
        return

    dml_logging.info(f'Finished stage in {stage.end_time - stage.start_time}')

def post_run(self, pipe):
dml_logging.info(f'Finished training in {pipe.stop_time - pipe.start_time} ({pipe.stop_time})')
if pipe.checkpointing_enabled:
Expand Down
86 changes: 41 additions & 45 deletions dmlcloud/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@
import torch.distributed as dist
from omegaconf import OmegaConf

from ..util.logging import experiment_header
from . import logging as dml_logging
from .callbacks import Callback, CheckpointCallback, CsvCallback, DiagnosticsCallback, WandbCallback
from .callbacks import (
Callback,
CallbackList,
CbPriority,
CheckpointCallback,
CsvCallback,
DiagnosticsCallback,
WandbCallback,
)
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
from .distributed import broadcast_object, init, is_root, local_rank
from .stage import Stage
Expand All @@ -20,28 +27,6 @@
]


class _ForwardCallback(Callback):
    """
    Relays the stage-level hooks it receives to every callback registered on the stage's pipeline.
    """

    def pre_stage(self, stage):
        """Call ``pre_stage`` on each pipeline callback."""
        for pipeline_cb in stage.pipe.callbacks:
            pipeline_cb.pre_stage(stage)

    def post_stage(self, stage):
        """Call ``post_stage`` on each pipeline callback."""
        for pipeline_cb in stage.pipe.callbacks:
            pipeline_cb.post_stage(stage)

    def pre_epoch(self, stage):
        """Call ``pre_epoch`` on each pipeline callback."""
        for pipeline_cb in stage.pipe.callbacks:
            pipeline_cb.pre_epoch(stage)

    def post_epoch(self, stage):
        """Call ``post_epoch`` on each pipeline callback."""
        for pipeline_cb in stage.pipe.callbacks:
            pipeline_cb.post_epoch(stage)


class _RunGuard:
"""
Context manager that ensures that the pipeline is properly cleaned up in case of an exception or interruption.
Expand Down Expand Up @@ -74,6 +59,19 @@ def __exit__(self, exc_type, exc_value, traceback):
return suppress_exception


class _ForwardCallback(Callback):
    """
    Bridges the callback mechanism to the Pipeline's own pre_run() and post_run() methods.

    Stage-specific callbacks are managed by the Stage object.
    """

    def pre_run(self, pipe):
        """Delegate to ``pipe.pre_run()``."""
        pipe.pre_run()

    def post_run(self, pipe):
        """Delegate to ``pipe.post_run()``."""
        pipe.post_run()


class Pipeline:
"""
A training pipeline that consists of multiple stages.
Expand Down Expand Up @@ -112,7 +110,10 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
self.wandb = False

self.stages = []
self.callbacks = []
self.callbacks = CallbackList()

self.add_callback(DiagnosticsCallback(), CbPriority.DIAGNOSTICS)
self.add_callback(_ForwardCallback(), CbPriority.OBJECT_METHODS) # methods have priority 0

if dist.is_gloo_available():
self.gloo_group = dist.new_group(backend='gloo')
Expand All @@ -123,14 +124,21 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
def checkpointing_enabled(self):
return self.checkpoint_dir is not None

def add_callback(self, callback: Callback):
def add_callback(self, callback: Callback, priority: int = 1):
"""
Adds a callback to the pipeline.
Adds a callback to this pipeline.
Callbacks added to the pipeline and not to individual stages are executed for all stages in the pipeline.
Callbacks are executed based on their priority, with lower values being executed first.
Callbacks with the same priority are executed in the order they were added.
The callback will be invoked for each stage in the pipeline and are executed in the order they are added.
Callbacks added to individual stages will be executed before the pipeline callbacks.
Methods of the stage and pipeline objects, e.g. pre_run(), have priority 0.
Args:
callback (StageCallback): The callback to add.
priority (int, optional): The priority of the callback. Defaults to 1.
"""
self.callbacks.append(callback)
self.callbacks.append(callback, priority)

def append(self, stage: Stage):
if not isinstance(stage, Stage):
Expand Down Expand Up @@ -163,8 +171,8 @@ def enable_checkpointing(
self.checkpoint_dir = CheckpointDir(path)

if is_root():
self.add_callback(CheckpointCallback(self.checkpoint_dir.path))
self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True))
self.add_callback(CheckpointCallback(self.checkpoint_dir.path), CbPriority.CHECKPOINT)
self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True), CbPriority.CSV)

def enable_wandb(
self,
Expand All @@ -182,7 +190,7 @@ def enable_wandb(

if is_root():
project = project or self.name
self.add_callback(WandbCallback(project, entity, group, tags, startup_timeout, **kwargs))
self.add_callback(WandbCallback(project, entity, group, tags, startup_timeout, **kwargs), CbPriority.WANDB)

self.wandb = True

Expand All @@ -200,11 +208,6 @@ def run(self):
if len(self.stages) == 0:
raise ValueError('No stages defined. Use append() to add stages to the pipeline.')

for stage in self.stages:
stage.add_callback(_ForwardCallback()) # forward callbacks to pipeline callbacks

self.add_callback(DiagnosticsCallback())

# make sure everything is set up before starting the run
# important to prevent checkpoint dir creation before all processes searched for it
self.barrier(timeout=10 * 60)
Expand Down Expand Up @@ -242,14 +245,9 @@ def device(self):
def _pre_run(self):
self.start_time = datetime.now()

header = '\n' + experiment_header(self.name, self.checkpoint_dir, self.start_time)
dml_logging.info(header)

if self.resumed:
self._resume_run()

self.pre_run()

for callback in self.callbacks:
callback.pre_run(self)

Expand All @@ -260,7 +258,5 @@ def _resume_run(self):
def _post_run(self):
self.stop_time = datetime.now()

self.post_run()

for callback in self.callbacks:
callback.post_run(self)
Loading

0 comments on commit 31948bb

Please sign in to comment.