From aad274dd28c8cbcb2397878edb7ce93d39d6b063 Mon Sep 17 00:00:00 2001
From: Sebastian Hoffmann
Date: Sun, 5 Jan 2025 17:23:35 +0100
Subject: [PATCH] feat: TensorboardCallback, closes #30

---
 dmlcloud/core/callbacks.py | 26 ++++++++++++++++++++++++++
 dmlcloud/core/pipeline.py  |  2 ++
 2 files changed, 28 insertions(+)

diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py
index 6b4f61a..6585290 100644
--- a/dmlcloud/core/callbacks.py
+++ b/dmlcloud/core/callbacks.py
@@ -9,6 +9,7 @@
 import torch
 from omegaconf import OmegaConf
 from progress_table import ProgressTable
+from torch.utils.tensorboard import SummaryWriter
 
 from ..git import git_diff
 from ..util.logging import DevNullIO, experiment_header, general_diagnostics, IORedirector
@@ -32,6 +33,7 @@
     'CheckpointCallback',
     'CsvCallback',
     'WandbCallback',
+    'TensorboardCallback',
 ]
 
 
@@ -96,6 +98,7 @@ class CbPriority(IntEnum):
     OBJECT_METHODS = 0
 
     CSV = 110
+    TENSORBOARD = 115
     TABLE = 120
 
 
@@ -416,6 +419,29 @@ def cleanup(self, pipe, exc_type, exc_value, traceback):
         self.wandb.finish(exit_code=0 if exc_type is None else 1)
 
 
+class TensorboardCallback(Callback):
+    """
+    A callback that logs the per-epoch metrics of each stage to Tensorboard.
+    """
+
+    def __init__(self, log_dir: Union[str, Path]):
+        self.log_dir = Path(log_dir)
+        self.writer = None  # created lazily in pre_run
+
+    def pre_run(self, pipe):
+        self.writer = SummaryWriter(log_dir=self.log_dir)
+
+    def post_epoch(self, stage: 'Stage'):
+        metrics = stage.history.last()
+        for key, value in metrics.items():
+            self.writer.add_scalar(key, value.item(), stage.current_epoch)
+
+    def cleanup(self, pipe, exc_type, exc_value, traceback):
+        # pre_run may never have executed if the pipeline failed during setup.
+        if self.writer is not None:
+            self.writer.close()
+
+
 class DiagnosticsCallback(Callback):
     """
     A callback that logs diagnostics information at the beginning of training.
diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py
index a36024f..98a9158 100644
--- a/dmlcloud/core/pipeline.py
+++ b/dmlcloud/core/pipeline.py
@@ -16,6 +16,7 @@
     CsvCallback,
     DiagnosticsCallback,
     GitDiffCallback,
+    TensorboardCallback,
     WandbCallback,
 )
 from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
@@ -175,6 +176,7 @@ def enable_checkpointing(
         if is_root():
             self.add_callback(CheckpointCallback(self.checkpoint_dir.path), CbPriority.CHECKPOINT)
             self.add_callback(CsvCallback(self.checkpoint_dir.path, append_stage_name=True), CbPriority.CSV)
+            self.add_callback(TensorboardCallback(self.checkpoint_dir.path), CbPriority.TENSORBOARD)
 
     def enable_wandb(
         self,