Skip to content

Commit

Permalink
feat: git diffs
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Jan 5, 2025
1 parent 2e52d18 commit 5fa449b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
22 changes: 22 additions & 0 deletions dmlcloud/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from omegaconf import OmegaConf
from progress_table import ProgressTable

from ..git import git_diff
from ..util.logging import DevNullIO, experiment_header, general_diagnostics, IORedirector
from ..util.wandb import wandb_is_initialized, wandb_set_startup_timeout
from . import logging as dml_logging
Expand Down Expand Up @@ -89,6 +90,7 @@ class CbPriority(IntEnum):
CHECKPOINT = -190
STAGE_TIMER = -180
DIAGNOSTICS = -170
GIT = -160
METRIC_REDUCTION = -160

OBJECT_METHODS = 0
Expand Down Expand Up @@ -436,3 +438,23 @@ def post_run(self, pipe):
dml_logging.info(f'Finished training in {pipe.stop_time - pipe.start_time} ({pipe.stop_time})')
if pipe.checkpointing_enabled:
dml_logging.info(f'Outputs have been saved to {pipe.checkpoint_dir}')


class GitDiffCallback(Callback):
    """
    A callback that prints a git diff and, if checkpointing is enabled, saves it to the checkpoint directory.
    """

    def pre_run(self, pipe):
        """
        Collects the git diff, optionally stores it as ``git_diff.txt`` and writes it to the log.

        Args:
            pipe: The pipeline that is about to run. Its ``checkpointing_enabled`` flag decides whether
                the diff is also written below ``pipe.checkpoint_dir.path``.
        """
        diff = git_diff()

        if pipe.checkpointing_enabled:
            self._save(pipe.checkpoint_dir.path / 'git_diff.txt', diff)

        # Every diff line is prefixed with a tab so that it is grouped visually under the header line.
        msg = '* GIT-DIFF:\n'
        msg += '\n'.join('\t' + line for line in diff.splitlines())
        dml_logging.info(msg)

    def _save(self, path, diff):
        """
        Writes the diff text to ``path``, replacing any existing file.

        Args:
            path: Destination file.
            diff: The diff text to store.
        """
        # The diff was decoded from UTF-8, so write it back as UTF-8 explicitly. Relying on the
        # platform default encoding can raise UnicodeEncodeError for non-ASCII characters.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(diff)
2 changes: 2 additions & 0 deletions dmlcloud/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CheckpointCallback,
CsvCallback,
DiagnosticsCallback,
GitDiffCallback,
WandbCallback,
)
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
Expand Down Expand Up @@ -113,6 +114,7 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
self.callbacks = CallbackList()

self.add_callback(DiagnosticsCallback(), CbPriority.DIAGNOSTICS)
self.add_callback(GitDiffCallback(), CbPriority.GIT)
self.add_callback(_ForwardCallback(), CbPriority.OBJECT_METHODS) # methods have priority 0

if dist.is_gloo_available():
Expand Down
4 changes: 4 additions & 0 deletions dmlcloud/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,9 @@ def git_hash(short=False):


def git_diff():
    """
    Returns the output of `git diff -U0 --no-color HEAD` as text, stripped of leading and trailing whitespace.

    Bytes in the output that are not valid UTF-8 are replaced with the Unicode replacement character
    instead of raising a ``UnicodeDecodeError``.
    """

    process = run_in_project(['git', 'diff', '-U0', '--no-color', 'HEAD'])
    # The diff may contain changed lines from files that are not UTF-8 encoded. Decoding leniently
    # keeps this purely diagnostic helper from failing on such content.
    return process.stdout.decode('utf-8', errors='replace').strip()

0 comments on commit 5fa449b

Please sign in to comment.