feat: create checkpoint dir via callback
sehoffmann committed Jan 3, 2025
1 parent 369ca41 commit 60faff3
Showing 2 changed files with 65 additions and 33 deletions.
42 changes: 37 additions & 5 deletions dmlcloud/core/callbacks.py
@@ -7,13 +7,13 @@
import torch
from progress_table import ProgressTable

from ..util.logging import DevNullIO
from ..util.logging import DevNullIO, IORedirector
from . import logging as dml_logging
from .distributed import is_root

if TYPE_CHECKING:
from .stage import Stage
from .pipeline import Pipeline
from .stage import Stage


__all__ = [
@@ -22,7 +22,9 @@
'TimerCallback',
'TableCallback',
'ReduceMetricsCallback',
'CheckpointCallback',
'CsvCallback',
'WandbCallback',
]


@@ -58,9 +60,9 @@ def post_run(self, pipe: 'Pipeline'):
"""
pass

def cleanup(self, pipe: 'Pipeline', exc_type, exc_value, traceback):
"""
Executed after the pipeline finishes, even if an error occurred.
E.g. to close file handles.
Args:
@@ -226,6 +228,36 @@ def post_epoch(self, stage: 'Stage'):
stage.history.next_step()


class CheckpointCallback(Callback):
"""
Creates the checkpoint directory and optionally sets up I/O redirection.
"""

def __init__(self, root_path: Union[str, Path], redirect_io: bool = True):
"""
Initialize the callback with the given root path.
Args:
root_path (Union[str, Path]): The root path where the checkpoint directory will be created.
redirect_io (bool, optional): Whether to redirect I/O to a file. Defaults to True.
"""
self.root_path = Path(root_path)
self.redirect_io = redirect_io
self.io_redirector = None

def pre_run(self, pipe: 'Pipeline'):
if not pipe.checkpoint_dir.is_valid:
pipe.checkpoint_dir.create()
pipe.checkpoint_dir.save_config(pipe.config)

self.io_redirector = IORedirector(pipe.checkpoint_dir.log_file)
self.io_redirector.install()

def cleanup(self, pipe, exc_type, exc_value, traceback):
if self.io_redirector is not None:
self.io_redirector.uninstall()


class CsvCallback(Callback):
"""
Saves metrics to a CSV file at the end of each epoch.
@@ -299,4 +331,4 @@ def pre_stage(self, stage: 'Stage'):

def post_epoch(self, stage: 'Stage'):
metrics = stage.history.last()
self.wandb.log(metrics)
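For readers unfamiliar with the hook lifecycle this commit relies on: pre_run fires once before the pipeline starts (the new CheckpointCallback uses it to create the checkpoint directory and install I/O redirection), and cleanup fires after the run even when an exception occurred. Below is a minimal, self-contained Python sketch of that lifecycle; ToyPipeline and DirCallback are made-up stand-ins for illustration, not dmlcloud's actual Pipeline or CheckpointCallback.

import sys
from pathlib import Path


class Callback:
    """Hook interface mirroring the pre_run/cleanup methods used in this commit."""

    def pre_run(self, pipe):
        pass

    def cleanup(self, pipe, exc_type, exc_value, traceback):
        pass


class DirCallback(Callback):
    """Hypothetical callback: creates an output directory before the run starts."""

    def __init__(self, root_path):
        self.root_path = Path(root_path)

    def pre_run(self, pipe):
        self.root_path.mkdir(parents=True, exist_ok=True)

    def cleanup(self, pipe, exc_type, exc_value, traceback):
        # Runs even if the pipeline raised: a good place to close handles or
        # undo global state such as an I/O redirection.
        print(f'cleanup called, exception type: {exc_type}')


class ToyPipeline:
    """Made-up minimal harness that drives the hooks in the same order as the diff."""

    def __init__(self):
        self.callbacks = []

    def add_callback(self, callback):
        self.callbacks.append(callback)

    def run(self):
        for callback in self.callbacks:
            callback.pre_run(self)
        exc_info = (None, None, None)
        try:
            print('training...')
        except BaseException:
            exc_info = sys.exc_info()
            raise
        finally:
            # cleanup always runs, mirroring the _RunGuard change in pipeline.py below
            for callback in self.callbacks:
                callback.cleanup(self, *exc_info)


pipe = ToyPipeline()
pipe.add_callback(DirCallback('checkpoints/run1'))
pipe.run()

Running the sketch prints 'training...' followed by the cleanup message, and creates checkpoints/run1/.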
56 changes: 28 additions & 28 deletions dmlcloud/core/pipeline.py
@@ -1,26 +1,26 @@
import warnings
from datetime import datetime, timedelta
from functools import cached_property
from typing import Dict, List, Optional, Sequence, Union
from typing import Dict, List, Optional, Union

import torch
import torch.distributed as dist
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset

from dmlcloud.util.wandb import wandb, wandb_is_initialized, wandb_set_startup_timeout
from ..util.logging import experiment_header, general_diagnostics, IORedirector
from ..util.logging import experiment_header, general_diagnostics
from . import logging as dml_logging
from .callbacks import CsvCallback, Callback
from .callbacks import Callback, CheckpointCallback, CsvCallback
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
from .distributed import all_gather_object, broadcast_object, init, local_rank, root_only
from .distributed import all_gather_object, broadcast_object, init, is_root, local_rank, root_only
from .stage import Stage


__all__ = [
'Pipeline',
]


class _ForwardCallback(Callback):
"""
A callback class that forwards the callback methods to all callbacks in the pipeline.
@@ -29,15 +29,15 @@ class _ForwardCallback(Callback):
def pre_stage(self, stage):
for callback in stage.pipe.callbacks:
callback.pre_stage(stage)

def post_stage(self, stage):
for callback in stage.pipe.callbacks:
callback.post_stage(stage)

def pre_epoch(self, stage):
for callback in stage.pipe.callbacks:
callback.pre_epoch(stage)

def post_epoch(self, stage):
for callback in stage.pipe.callbacks:
callback.post_epoch(stage)
@@ -47,13 +47,23 @@ class _RunGuard:
"""
Context manager that ensures that the pipeline is properly cleaned up in case of an exception or interruption.
"""

def __init__(self, pipe):
self.pipe = pipe

def __enter__(self):
pass

def __exit__(self, exc_type, exc_value, traceback):
suppress_exception = False
if exc_type is KeyboardInterrupt:
dml_logging.info('------- Training interrupted by user -------')
suppress_exception = True
elif exc_type is not None:
dml_logging.error(
'------- Training failed with an exception -------', exc_info=(exc_type, exc_value, traceback)
)

callbacks = []
if self.pipe.current_stage is not None:
callbacks += self.pipe.current_stage.callbacks
@@ -62,6 +72,8 @@ def __exit__(self, exc_type, exc_value, traceback):
for callback in callbacks:
callback.cleanup(self.pipe, exc_type, exc_value, traceback)

return suppress_exception


class Pipeline:
def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None):
@@ -135,6 +147,10 @@ def enable_checkpointing(

self.checkpoint_dir = CheckpointDir(path)

if is_root():
self.add_callback(CheckpointCallback(self.checkpoint_dir.path))
self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True))

def enable_wandb(
self,
project: str | None = None,
@@ -218,8 +234,6 @@ def _pre_run(self):
self.barrier(
timeout=10 * 60
) # important to prevent checkpoint dir creation before all processes searched for it
if self.checkpointing_enabled:
self._init_checkpointing()

if self.wandb:
self._wandb_initalizer()
@@ -246,15 +260,8 @@ def _pre_run(self):

self.pre_run()

@root_only
def _init_checkpointing(self):
if not self.checkpoint_dir.is_valid:
self.checkpoint_dir.create()
self.checkpoint_dir.save_config(self.config)
self.io_redirector = IORedirector(self.checkpoint_dir.log_file)
self.io_redirector.install()

self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True))
for callback in self.callbacks:
callback.pre_run(self)

def _resume_run(self):
dml_logging.info(f'Resuming training from checkpoint: {self.checkpoint_dir}')
@@ -267,21 +274,14 @@ def _post_run(self):
dml_logging.info(f'Outputs have been saved to {self.checkpoint_dir}')
self.post_run()

for callback in self.callbacks:
callback.post_run(self)

def _cleanup(self, exc_type, exc_value, traceback):
"""
Called by _RunGuard to ensure that the pipeline is properly cleaned up
"""
if exc_type is KeyboardInterrupt:
dml_logging.info('------- Training interrupted by user -------')
elif exc_type is not None:
dml_logging.error(
'------- Training failed with an exception -------', exc_info=(exc_type, exc_value, traceback)
)

if self.wandb and wandb_is_initialized():
wandb.finish(exit_code=0 if exc_type is None else 1)

if self.io_redirector is not None:
self.io_redirector.uninstall()

return False
