diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py
index 3deb623..7c981aa 100644
--- a/dmlcloud/core/callbacks.py
+++ b/dmlcloud/core/callbacks.py
@@ -1,6 +1,8 @@
 import csv
+import os
 import sys
 from datetime import datetime, timedelta
+from enum import IntEnum
 from pathlib import Path
 from typing import Callable, Optional, TYPE_CHECKING, Union

@@ -8,7 +10,7 @@
 from omegaconf import OmegaConf
 from progress_table import ProgressTable

-from ..util.logging import DevNullIO, general_diagnostics, IORedirector
+from ..util.logging import DevNullIO, experiment_header, general_diagnostics, IORedirector
 from ..util.wandb import wandb_is_initialized, wandb_set_startup_timeout
 from . import logging as dml_logging
 from .distributed import all_gather_object, is_root
@@ -20,6 +22,8 @@
 __all__ = [
     'TimedeltaFormatter',
+    'CallbackList',
+    'CbPriority',
     'Callback',
     'TimerCallback',
     'TableCallback',
@@ -45,6 +49,54 @@ def __call__(self, value: torch.Tensor) -> str:
         return str(delta)


+class CallbackList:
+    """
+    A priority queue of callbacks.
+    """
+
+    def __init__(self):
+        self.callbacks = []
+
+    def append(self, callback: 'Callback', priority: int = 0):
+        """
+        Append a callback to the list with the given priority.
+
+        Args:
+            callback (Callback): The callback to append.
+            priority (int, optional): The priority of the callback. Defaults to 0.
+        """
+        self.callbacks.append((priority, callback))
+
+    def __iter__(self):
+        for _, callback in sorted(self.callbacks, key=lambda x: x[0]):
+            yield callback
+
+    def __len__(self):
+        return len(self.callbacks)
+
+    def __add__(self, other: 'CallbackList'):
+        result = CallbackList()
+        result.callbacks = self.callbacks + other.callbacks
+        return result
+
+
+class CbPriority(IntEnum):
+    """
+    Default priorities for callbacks used by the pipeline and stage classes.
+    """
+
+    WANDB = -200
+    CHECKPOINT = -190
+    STAGE_TIMER = -180
+    DIAGNOSTICS = -170
+    METRIC_REDUCTION = -160
+
+    OBJECT_METHODS = 0
+
+    CSV = 110
+    TABLE = 120
+
+
 class Callback:
     """
     A callback that can be registered to a stage or the whole pipeline to receive updates on the training progress.
@@ -131,9 +183,6 @@ def post_epoch(self, stage: 'Stage'):
         eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1)
         stage.log('misc/eta', eta.total_seconds(), prefixed=False)

-        if len(stage.pipe.stages) > 1:
-            dml_logging.info(f'Finished stage in {stage.end_time - stage.start_time}')
-

 class TableCallback(Callback):
     """
@@ -255,6 +304,10 @@ def pre_run(self, pipe: 'Pipeline'):
         self.io_redirector = IORedirector(pipe.checkpoint_dir.log_file)
         self.io_redirector.install()

+        with open(pipe.checkpoint_dir.path / "environment.txt", 'w') as f:
+            for k, v in os.environ.items():
+                f.write(f"{k}={v}\n")
+
     def cleanup(self, pipe, exc_type, exc_value, traceback):
         if self.io_redirector is not None:
             self.io_redirector.uninstall()
@@ -361,6 +414,9 @@ class DiagnosticsCallback(Callback):
     """

     def pre_run(self, pipe):
+        header = '\n' + experiment_header(pipe.name, pipe.checkpoint_dir, pipe.start_time)
+        dml_logging.info(header)
+
         diagnostics = general_diagnostics()

         diagnostics += '\n* DEVICES:\n'
@@ -372,6 +428,10 @@ def pre_run(self, pipe):

         dml_logging.info(diagnostics)

+    def post_stage(self, stage):
+        if len(stage.pipe.stages) > 1:
+            dml_logging.info(f'Finished stage in {stage.end_time - stage.start_time}')
+
     def post_run(self, pipe):
         dml_logging.info(f'Finished training in {pipe.stop_time - pipe.start_time} ({pipe.stop_time})')
         if pipe.checkpointing_enabled:
diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py
index e2ee322..d97b87f 100644
--- a/dmlcloud/core/pipeline.py
+++ b/dmlcloud/core/pipeline.py
@@ -7,9 +7,16 @@
 import torch.distributed as dist
 from omegaconf import OmegaConf

-from ..util.logging import experiment_header
 from . import logging as dml_logging
-from .callbacks import Callback, CheckpointCallback, CsvCallback, DiagnosticsCallback, WandbCallback
+from .callbacks import (
+    Callback,
+    CallbackList,
+    CbPriority,
+    CheckpointCallback,
+    CsvCallback,
+    DiagnosticsCallback,
+    WandbCallback,
+)
 from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
 from .distributed import broadcast_object, init, is_root, local_rank
 from .stage import Stage
@@ -20,28 +27,6 @@
 ]


-class _ForwardCallback(Callback):
-    """
-    A callback class that forwards the callback methods to all callbacks in the pipeline.
-    """
-
-    def pre_stage(self, stage):
-        for callback in stage.pipe.callbacks:
-            callback.pre_stage(stage)
-
-    def post_stage(self, stage):
-        for callback in stage.pipe.callbacks:
-            callback.post_stage(stage)
-
-    def pre_epoch(self, stage):
-        for callback in stage.pipe.callbacks:
-            callback.pre_epoch(stage)
-
-    def post_epoch(self, stage):
-        for callback in stage.pipe.callbacks:
-            callback.post_epoch(stage)
-
-
 class _RunGuard:
     """
     Context manager that ensures that the pipeline is properly cleaned up in case of an exception or interruption.
@@ -74,6 +59,19 @@ def __exit__(self, exc_type, exc_value, traceback):
         return suppress_exception


+class _ForwardCallback(Callback):
+    """
+    Invokes the pre_run and post_run methods of the Pipeline.
+    Stage-specific callbacks are managed by the Stage object.
+    """
+
+    def pre_run(self, pipe):
+        pipe.pre_run()
+
+    def post_run(self, pipe):
+        pipe.post_run()
+
+
 class Pipeline:
     """
     A training pipeline that consists of multiple stages.
@@ -112,7 +110,10 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
         self.wandb = False

         self.stages = []
-        self.callbacks = []
+        self.callbacks = CallbackList()
+
+        self.add_callback(DiagnosticsCallback(), CbPriority.DIAGNOSTICS)
+        self.add_callback(_ForwardCallback(), CbPriority.OBJECT_METHODS)  # methods have priority 0

         if dist.is_gloo_available():
             self.gloo_group = dist.new_group(backend='gloo')
@@ -123,14 +124,21 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
     def checkpointing_enabled(self):
         return self.checkpoint_dir is not None

-    def add_callback(self, callback: Callback):
+    def add_callback(self, callback: Callback, priority: int = 1):
         """
-        Adds a callback to the pipeline.
+        Adds a callback to this pipeline.
+
+        Callbacks added to the pipeline and not to individual stages are executed for all stages in the pipeline.
+        Callbacks are executed based on their priority, with lower values being executed first.
+        Callbacks with the same priority are executed in the order they were added.

-        The callback will be invoked for each stage in the pipeline and are executed in the order they are added.
-        Callbacks added to individual stages will be executed before the pipeline callbacks.
+        Methods of the stage and pipeline objects, e.g. pre_run(), have priority 0.
+
+        Args:
+            callback (Callback): The callback to add.
+            priority (int, optional): The priority of the callback. Defaults to 1.
         """
-        self.callbacks.append(callback)
+        self.callbacks.append(callback, priority)

     def append(self, stage: Stage):
         if not isinstance(stage, Stage):
@@ -163,8 +171,8 @@ def enable_checkpointing(
         self.checkpoint_dir = CheckpointDir(path)

         if is_root():
-            self.add_callback(CheckpointCallback(self.checkpoint_dir.path))
-            self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True))
+            self.add_callback(CheckpointCallback(self.checkpoint_dir.path), CbPriority.CHECKPOINT)
+            self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True), CbPriority.CSV)

     def enable_wandb(
         self,
@@ -182,7 +190,7 @@ def enable_wandb(

         if is_root():
             project = project or self.name
-            self.add_callback(WandbCallback(project, entity, group, tags, startup_timeout, **kwargs))
+            self.add_callback(WandbCallback(project, entity, group, tags, startup_timeout, **kwargs), CbPriority.WANDB)

         self.wandb = True

@@ -200,11 +208,6 @@ def run(self):
         if len(self.stages) == 0:
             raise ValueError('No stages defined. Use append() to add stages to the pipeline.')

-        for stage in self.stages:
-            stage.add_callback(_ForwardCallback())  # forward callbacks to pipeline callbacks
-
-        self.add_callback(DiagnosticsCallback())
-
         # make sure everything is set up before starting the run
         # important to prevent checkpoint dir creation before all processes searched for it
         self.barrier(timeout=10 * 60)
@@ -242,14 +245,9 @@ def device(self):
     def _pre_run(self):
         self.start_time = datetime.now()

-        header = '\n' + experiment_header(self.name, self.checkpoint_dir, self.start_time)
-        dml_logging.info(header)
-
         if self.resumed:
             self._resume_run()

-        self.pre_run()
-
         for callback in self.callbacks:
             callback.pre_run(self)

@@ -260,7 +258,5 @@ def _resume_run(self):
     def _post_run(self):
         self.stop_time = datetime.now()

-        self.post_run()
-
         for callback in self.callbacks:
             callback.post_run(self)
diff --git a/dmlcloud/core/stage.py b/dmlcloud/core/stage.py
index 1496887..96015aa 100644
--- a/dmlcloud/core/stage.py
+++ b/dmlcloud/core/stage.py
@@ -1,17 +1,33 @@
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable

 from . import logging as dml_logging
-from .callbacks import ReduceMetricsCallback, TableCallback, TimerCallback
+from .callbacks import Callback, CallbackList, CbPriority, ReduceMetricsCallback, TableCallback, TimerCallback
 from .metrics import Tracker, TrainingHistory

-if TYPE_CHECKING:
-    from .callbacks import Callback

 __all__ = [
     'Stage',
 ]


+class _ForwardCallback(Callback):
+    """
+    Invokes the pre_stage, post_stage, pre_epoch, and post_epoch methods of the Stage.
+    """
+
+    def pre_stage(self, stage):
+        stage.pre_stage()
+
+    def post_stage(self, stage):
+        stage.post_stage()
+
+    def pre_epoch(self, stage):
+        stage.pre_epoch()
+
+    def post_epoch(self, stage):
+        stage.post_epoch()
+
+
 class Stage:
     """
     Hook Points:
@@ -25,24 +41,23 @@ def __init__(self, name: str = None, epochs: int = 1):
         self.name = name or self.__class__.__name__
         self.max_epochs = epochs

-        self.callbacks: list[Callback] = []
+        self.callbacks = CallbackList()
         self.pipe = None  # set by the pipeline

         self.history = TrainingHistory()
         self.tracker = Tracker()

-        self._timer = TimerCallback()
-        self.add_callback(self._timer)
-
-        self.add_callback(ReduceMetricsCallback())
-
-        self._table_callback = TableCallback()
-        self.add_callback(self._table_callback)
-
         self.metric_prefix = None
         self.barrier_timeout = None

+        self._timer = TimerCallback()
+        self._table_callback = TableCallback()
+        self.add_callback(self._timer, CbPriority.STAGE_TIMER)
+        self.add_callback(ReduceMetricsCallback(), CbPriority.METRIC_REDUCTION)
+        self.add_callback(self._table_callback, CbPriority.TABLE)
+        self.add_callback(_ForwardCallback(), CbPriority.OBJECT_METHODS)  # methods have priority 0
+
     @property
     def device(self):
         return self.pipe.device
@@ -75,16 +90,20 @@ def epoch_end_time(self):
     def table(self):
         return self._table_callback.table

-    def add_callback(self, callback: 'Callback'):
+    def add_callback(self, callback: 'Callback', priority: int = 1):
         """
         Adds a callback to this stage.

-        Callbacks are executed in the order they are added and after the stage-specific hooks.
+        Callbacks are executed based on their priority, with lower values being executed first.
+        Callbacks with the same priority are executed in the order they were added.
+
+        The pre_stage, post_stage, pre_epoch, and post_epoch methods are treated as callbacks with priority 0.

         Args:
             callback (StageCallback): The callback to add.
+            priority (int, optional): The priority of the callback. Defaults to 1.
""" - self.callbacks.append(callback) + self.callbacks.append(callback, priority) def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True): if prefixed and self.metric_prefix: @@ -174,28 +193,25 @@ def _pre_stage(self): if len(self.pipe.stages) > 1: dml_logging.info(f'\n========== STAGE: {self.name} ==========') - self.pre_stage() - - for callback in self.callbacks: + callbacks = self.callbacks + self.pipe.callbacks + for callback in callbacks: callback.pre_stage(self) dml_logging.flush_logger() - self.pipe.barrier(self.barrier_timeout) def _post_stage(self): - self.post_stage() - for callback in self.callbacks: + callbacks = self.callbacks + self.pipe.callbacks + for callback in callbacks: callback.post_stage(self) - self.pipe.barrier(self.barrier_timeout) def _pre_epoch(self): - self.pre_epoch() - for callback in self.callbacks: + callbacks = self.callbacks + self.pipe.callbacks + for callback in callbacks: callback.pre_epoch(self) def _post_epoch(self): - self.post_epoch() - for callback in self.callbacks: + callbacks = self.callbacks + self.pipe.callbacks + for callback in callbacks: callback.post_epoch(self) diff --git a/examples/mnist.py b/examples/mnist.py index d9c0ba1..866a131 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -101,9 +101,9 @@ def _val_epoch(self): def main(): pipe = dml.Pipeline() + pipe.append(MNISTStage(epochs=3)) pipe.enable_checkpointing('checkpoints') pipe.enable_wandb() - pipe.append(MNISTStage(epochs=3)) pipe.run() diff --git a/test/test_callback.py b/test/test_callback.py new file mode 100644 index 0000000..783fb8e --- /dev/null +++ b/test/test_callback.py @@ -0,0 +1,209 @@ +import sys +import time + +import dmlcloud as dml +from dmlcloud.core.callbacks import CallbackList +import pytest + + +class DummyCallback(dml.Callback): + + def __init__(self, idx): + super().__init__() + self.idx = idx + + self.t_pre_run = [] + self.t_post_run = [] + self.t_pre_stage = [] + self.t_post_stage = [] + self.t_cleanup = [] + self.t_pre_epoch = [] + self.t_post_epoch = [] + + def pre_run(self, pipe): + self.t_pre_run.append(time.time()) + + def post_run(self, pipe): + self.t_post_run.append(time.time()) + + def pre_stage(self, stage): + self.t_pre_stage.append(time.time()) + + def post_stage(self, stage): + self.t_post_stage.append(time.time()) + + def cleanup(self, pipe, exc_type, exc_value, traceback): + self.t_cleanup.append(time.time()) + + def pre_epoch(self, stage): + self.t_pre_epoch.append(time.time()) + + def post_epoch(self, stage): + self.t_post_epoch.append(time.time()) + + +class DummyStage(dml.Stage): + + def __init__(self, name, epochs): + super().__init__(name, epochs) + self.t_pre_stage = [] + self.t_post_stage = [] + self.t_pre_epoch = [] + self.t_post_epoch = [] + + def pre_stage(self): + self.t_pre_stage.append(time.time()) + + def post_stage(self): + self.t_post_stage.append(time.time()) + + def pre_epoch(self): + self.t_pre_epoch.append(time.time()) + + def post_epoch(self): + self.t_post_epoch.append(time.time()) + + def run_epoch(self): + pass + + + +class TestCallbackList: + def test_priorities(self): + cb_list = CallbackList() + cb_list.append(DummyCallback(0), 100) + cb_list.append(DummyCallback(1), 50) + cb_list.append(DummyCallback(2), 200) + cb_list.append(DummyCallback(3), -100) + cb_list.append(DummyCallback(4), 100) + + indices = [cb.idx for cb in cb_list] + assert indices == [3, 1, 0, 4, 2] + + def test_combining(self): + cb_list1 = CallbackList() + cb_list1.append(DummyCallback(0), 100) + 
cb_list1.append(DummyCallback(1), 50) + + cb_list2 = CallbackList() + cb_list2.append(DummyCallback(2), 200) + cb_list2.append(DummyCallback(3), -100) + cb_list2.append(DummyCallback(4), 50) + + combined1 = cb_list1 + cb_list2 + indices = [cb.idx for cb in combined1] + assert indices == [3, 1, 4, 0, 2] + + # Order for same-priority depends on the order of the operands + combined2 = cb_list2 + cb_list1 + indices = [cb.idx for cb in combined2] + assert indices == [3, 4, 1, 0, 2] + + + def test_len(self): + cb_list = CallbackList() + assert len(cb_list) == 0 + + cb_list.append(DummyCallback(0), 100) + assert len(cb_list) == 1 + + cb_list.append(DummyCallback(1), 50) + assert len(cb_list) == 2 + + cb_list.append(DummyCallback(2), 200) + assert len(cb_list) == 3 + + +class TestCallback: + + def test_stage_methods(self, torch_distributed): + pipe = dml.Pipeline() + stage1 = DummyStage('stage1', 2) + pipe.append(stage1) + pipe.run() + + assert len(stage1.t_pre_stage) == 1 + assert len(stage1.t_post_stage) == 1 + assert len(stage1.t_pre_epoch) == 2 + assert len(stage1.t_post_epoch) == 2 + + assert stage1.t_pre_stage[0] < stage1.t_pre_epoch[0] + assert stage1.t_pre_epoch[0] < stage1.t_post_epoch[0] + assert stage1.t_post_epoch[0] < stage1.t_pre_epoch[1] + assert stage1.t_pre_epoch[1] < stage1.t_post_epoch[1] + assert stage1.t_post_epoch[1] < stage1.t_post_stage[0] + + def test_stage_callback(self, torch_distributed): + pipe = dml.Pipeline() + stage1 = DummyStage('stage1', 1) + stage2 = DummyStage('stage2', 1) + cb = DummyCallback(0) + + pipe.append(stage1) + pipe.append(stage2) + + stage1.add_callback(cb) + + pipe.run() + + assert len(cb.t_pre_stage) == 1 + assert len(cb.t_post_stage) == 1 + assert len(cb.t_pre_epoch) == 1 + assert len(cb.t_post_epoch) == 1 + assert len(cb.t_pre_run) == 0 + assert len(cb.t_post_run) == 0 + + assert stage1.t_pre_stage[0] < cb.t_pre_stage[0] + assert stage1.t_post_stage[0] < cb.t_post_stage[0] + + def test_stage_callback_priority(self, torch_distributed): + pipe = dml.Pipeline() + stage1 = DummyStage('stage1', 1) + stage2 = DummyStage('stage2', 1) + cb = DummyCallback(0) + + pipe.append(stage1) + pipe.append(stage2) + + stage1.add_callback(cb, priority=-1) + + pipe.run() + + assert len(cb.t_pre_stage) == 1 + assert len(cb.t_post_stage) == 1 + assert len(cb.t_pre_epoch) == 1 + assert len(cb.t_post_epoch) == 1 + assert len(cb.t_pre_run) == 0 + assert len(cb.t_post_run) == 0 + + assert cb.t_pre_stage[0] < stage1.t_pre_stage[0] + assert cb.t_post_stage[0] < stage1.t_post_stage[0] + + + def test_pipeline_callback(self, torch_distributed): + pipe = dml.Pipeline() + stage1 = DummyStage('stage1', 1) + stage2 = DummyStage('stage2', 1) + cb = DummyCallback(0) + + pipe.append(stage1) + pipe.append(stage2) + pipe.add_callback(cb) + + pipe.run() + + assert len(cb.t_pre_run) == 1 + assert len(cb.t_post_run) == 1 + assert len(cb.t_cleanup) == 1 + assert len(cb.t_pre_stage) == 2 + assert len(cb.t_post_stage) == 2 + assert len(cb.t_pre_epoch) == 2 + assert len(cb.t_post_epoch) == 2 + + assert cb.t_pre_run[0] < cb.t_pre_stage[0] + assert cb.t_post_stage[0] < cb.t_post_run[0] + assert cb.t_post_run[0] < cb.t_cleanup[0] + + +if __name__ == '__main__': + sys.exit(pytest.main([__file__]))
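
Reviewer note (not part of the patch): the snippet below is a minimal sketch of how the priority-based callback API introduced above is intended to be used, pieced together from the docstrings and tests in this diff. PrintCallback, ToyStage, and the demo_* functions are illustrative names rather than dmlcloud symbols, and constructing or running a Pipeline additionally assumes an initialized torch.distributed environment (the tests obtain one via a torch_distributed fixture).

import dmlcloud as dml
from dmlcloud.core.callbacks import CallbackList, CbPriority


class PrintCallback(dml.Callback):
    """Toy callback that reports when its hooks fire."""

    def __init__(self, tag):
        super().__init__()
        self.tag = tag

    def pre_stage(self, stage):
        print(f'[{self.tag}] pre_stage of {stage.name}')

    def post_epoch(self, stage):
        print(f'[{self.tag}] finished an epoch of {stage.name}')


class ToyStage(dml.Stage):
    """Minimal stage, mirroring DummyStage from the tests."""

    def run_epoch(self):
        pass


def demo_ordering():
    # CallbackList yields callbacks in ascending priority; ties keep insertion order.
    cb_list = CallbackList()
    cb_list.append(PrintCallback('late'), priority=10)
    cb_list.append(PrintCallback('early'), priority=-10)
    assert [cb.tag for cb in cb_list] == ['early', 'late']


def demo_pipeline():
    # Assumes torch.distributed is already initialized; Pipeline() may create a gloo group.
    pipe = dml.Pipeline()
    stage = ToyStage('toy', epochs=2)
    pipe.append(stage)

    # Negative priorities run before the object's own pre_*/post_* hooks
    # (CbPriority.OBJECT_METHODS == 0); add_callback() defaults to priority 1,
    # i.e. after the hooks.
    pipe.add_callback(PrintCallback('pipeline-level'))  # runs for every stage
    stage.add_callback(PrintCallback('before-stage-hooks'), priority=-1)
    pipe.run()

The priority values in CbPriority reflect the same design: infrastructure callbacks such as WANDB and CHECKPOINT run before user code and the objects' own hooks, while CSV and TABLE run afterwards so they see the reduced metrics of the epoch.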