Skip to content

Commit

Permalink
feat: TensorboardCallback, closes #30
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Jan 5, 2025
1 parent 70050e0 commit aad274d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
23 changes: 23 additions & 0 deletions dmlcloud/core/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
from omegaconf import OmegaConf
from progress_table import ProgressTable
from torch.utils.tensorboard import SummaryWriter

from ..git import git_diff
from ..util.logging import DevNullIO, experiment_header, general_diagnostics, IORedirector
Expand All @@ -32,6 +33,7 @@
'CheckpointCallback',
'CsvCallback',
'WandbCallback',
'TensorboardCallback',
]


Expand Down Expand Up @@ -96,6 +98,7 @@ class CbPriority(IntEnum):
OBJECT_METHODS = 0

CSV = 110
TENSORBOARD = 110
TABLE = 120


Expand Down Expand Up @@ -416,6 +419,26 @@ def cleanup(self, pipe, exc_type, exc_value, traceback):
self.wandb.finish(exit_code=0 if exc_type is None else 1)


class TensorboardCallback(Callback):
    """
    A callback that logs metrics to Tensorboard.

    The underlying ``SummaryWriter`` is created lazily in :meth:`pre_run` and
    released in :meth:`cleanup`. After every epoch, each entry of the stage's
    most recent history record is written as a scalar, using the current epoch
    as the global step.
    """

    def __init__(self, log_dir: Union[str, Path]):
        """
        Args:
            log_dir: Directory handed to ``SummaryWriter`` as its ``log_dir``.
        """
        self.log_dir = Path(log_dir)
        # The writer is opened in pre_run(). Keeping an explicit None here lets
        # cleanup() distinguish "never opened" from "opened" instead of raising
        # AttributeError when the run aborts before pre_run() was reached.
        self.writer = None

    def pre_run(self, pipe):
        """Open the ``SummaryWriter`` for ``self.log_dir``."""
        self.writer = SummaryWriter(log_dir=self.log_dir)

    def post_epoch(self, stage: 'Stage'):
        """
        Write the latest metrics of ``stage`` to Tensorboard.

        Args:
            stage: The stage that just finished an epoch. Its ``history.last()``
                must return a mapping whose values provide ``.item()``.
        """
        metrics = stage.history.last()
        for key, value in metrics.items():
            self.writer.add_scalar(key, value.item(), stage.current_epoch)

    def cleanup(self, pipe, exc_type, exc_value, traceback):
        """
        Close the writer if one was opened.

        Safe to call when :meth:`pre_run` never ran, and safe to call more than
        once; in both cases there is nothing to close and the call is a no-op.
        """
        if self.writer is None:
            return
        self.writer.close()
        # Drop the reference so a repeated cleanup() does not touch a closed writer.
        self.writer = None

class DiagnosticsCallback(Callback):
"""
A callback that logs diagnostics information at the beginning of training.
Expand Down
2 changes: 2 additions & 0 deletions dmlcloud/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CsvCallback,
DiagnosticsCallback,
GitDiffCallback,
TensorboardCallback,
WandbCallback,
)
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
Expand Down Expand Up @@ -175,6 +176,7 @@ def enable_checkpointing(
if is_root():
self.add_callback(CheckpointCallback(self.checkpoint_dir.path), CbPriority.CHECKPOINT)
self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True), CbPriority.CSV)
self.add_callback(TensorboardCallback(self.checkpoint_dir.path), CbPriority.TENSORBOARD)

def enable_wandb(
self,
Expand Down

0 comments on commit aad274d

Please sign in to comment.