diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py index 7c981aa..1623f06 100644 --- a/dmlcloud/core/callbacks.py +++ b/dmlcloud/core/callbacks.py @@ -10,6 +10,7 @@ from omegaconf import OmegaConf from progress_table import ProgressTable +from ..git import git_diff from ..util.logging import DevNullIO, experiment_header, general_diagnostics, IORedirector from ..util.wandb import wandb_is_initialized, wandb_set_startup_timeout from . import logging as dml_logging @@ -89,6 +90,7 @@ class CbPriority(IntEnum): CHECKPOINT = -190 STAGE_TIMER = -180 DIAGNOSTICS = -170 + GIT = -160 METRIC_REDUCTION = -160 OBJECT_METHODS = 0 @@ -436,3 +438,23 @@ def post_run(self, pipe): dml_logging.info(f'Finished training in {pipe.stop_time - pipe.start_time} ({pipe.stop_time})') if pipe.checkpointing_enabled: dml_logging.info(f'Outputs have been saved to {pipe.checkpoint_dir}') + + +class GitDiffCallback(Callback): + """ + A callback that prints a git diff and if checkpointing is enabled, saves it to the checkpoint directory. + """ + + def pre_run(self, pipe): + diff = git_diff() + + if pipe.checkpointing_enabled: + self._save(pipe.checkpoint_dir.path / 'git_diff.txt', diff) + + msg = '* GIT-DIFF:\n' + msg += '\n'.join('\t' + line for line in diff.splitlines()) + dml_logging.info(msg) + + def _save(self, path, diff): + with open(path, 'w') as f: + f.write(diff) diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py index d97b87f..a36024f 100644 --- a/dmlcloud/core/pipeline.py +++ b/dmlcloud/core/pipeline.py @@ -15,6 +15,7 @@ CheckpointCallback, CsvCallback, DiagnosticsCallback, + GitDiffCallback, WandbCallback, ) from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path @@ -113,6 +114,7 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option self.callbacks = CallbackList() self.add_callback(DiagnosticsCallback(), CbPriority.DIAGNOSTICS) + self.add_callback(GitDiffCallback(), CbPriority.GIT) self.add_callback(_ForwardCallback(), CbPriority.OBJECT_METHODS) # methods have priority 0 if dist.is_gloo_available(): diff --git a/dmlcloud/git.py b/dmlcloud/git.py index 05800c7..dcb11d9 100644 --- a/dmlcloud/git.py +++ b/dmlcloud/git.py @@ -90,5 +90,9 @@ def git_hash(short=False): def git_diff(): + """ + Returns the output of `git diff -U0 --no-color HEAD` + """ + process = run_in_project(['git', 'diff', '-U0', '--no-color', 'HEAD']) return process.stdout.decode('utf-8').strip()