From f0454e5e27c3ef7a68a5abbe8072310cba639973 Mon Sep 17 00:00:00 2001
From: Sebastian Hoffmann
Date: Thu, 26 Dec 2024 19:01:23 +0100
Subject: [PATCH] feat: Save metrics to csv, closes #15

---
 dmlcloud/core/callbacks.py | 34 +++++++++++++++++++++++++++++++++-
 dmlcloud/core/metrics.py   |  1 -
 dmlcloud/core/pipeline.py  | 11 ++++++++---
 dmlcloud/core/stage.py     |  4 ++--
 examples/barebone_mnist.py |  1 +
 5 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py
index a170f4b..158e8b3 100644
--- a/dmlcloud/core/callbacks.py
+++ b/dmlcloud/core/callbacks.py
@@ -1,6 +1,8 @@
+import csv
 import sys
 from datetime import datetime, timedelta
-from typing import Callable, Optional
+from pathlib import Path
+from typing import Callable, Optional, Union
 
 import torch
 from progress_table import ProgressTable
@@ -16,6 +18,7 @@
     'TimerCallback',
     'TableCallback',
     'ReduceMetricsCallback',
+    'CsvCallback',
 ]
 
 
@@ -195,3 +198,32 @@ def post_epoch(self, stage: 'Stage'):
         metrics = stage.tracker.reduce()
         stage.history.append_metrics(**metrics)
         stage.history.next_step()
+
+
+class CsvCallback(StageCallback):
+    """
+    Saves metrics to a CSV file at the end of each epoch.
+    """
+
+    def __init__(self, path: Union[str, Path]):
+        self.path = Path(path)
+
+    def pre_stage(self, stage: 'Stage'):
+        # If for some reason we can't write to the file or it exists already, it's better to fail early
+        with open(self.path, 'x'):
+            pass
+
+    def post_epoch(self, stage: 'Stage'):
+        with open(self.path, 'a') as f:
+            writer = csv.writer(f)
+
+            metrics = stage.history.last()
+
+            # Write the header if the file is empty
+            if f.tell() == 0:
+                writer.writerow(['Epoch'] + list(metrics))
+
+            row = [stage.current_epoch - 1]  # epoch is already incremented
+            for value in metrics.values():
+                row.append(value.item())
+            writer.writerow(row)
diff --git a/dmlcloud/core/metrics.py b/dmlcloud/core/metrics.py
index e0377ea..31b5971 100644
--- a/dmlcloud/core/metrics.py
+++ b/dmlcloud/core/metrics.py
@@ -167,7 +167,6 @@ def __init__(self):
         super().__init__()
 
         self.metrics = torch.nn.ModuleDict()
-        self.external_metrics = torch.nn.ModuleDict()
 
     def add_metric(self, name: str, metric: torchmetrics.Metric):
         if name in self.metrics:
diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py
index cc42fdb..38c70d2 100644
--- a/dmlcloud/core/pipeline.py
+++ b/dmlcloud/core/pipeline.py
@@ -13,6 +13,7 @@
 from dmlcloud.util.wandb import wandb, wandb_is_initialized, wandb_set_startup_timeout
 from ..util.logging import experiment_header, general_diagnostics, IORedirector
 from . import logging as dml_logging
+from .callbacks import CsvCallback
 from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
 from .distributed import all_gather_object, broadcast_object, init, local_rank, root_only
 from .stage import Stage
@@ -32,6 +33,10 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
         else:
             self.config = config
 
+        # Auto-init distributed if not already initialized
+        if not dist.is_initialized():
+            init()
+
         self.name = name
 
         self.checkpoint_dir = None
@@ -179,9 +184,6 @@ def _pre_run(self):
         if len(self.stages) == 0:
             raise ValueError('No stages defined. Use append_stage() to add stages to the pipeline.')
 
-        if not dist.is_initialized():
-            init()
-
         if dist.is_gloo_available():
             self.gloo_group = dist.new_group(backend='gloo')
         else:
@@ -226,6 +228,9 @@ def _init_checkpointing(self):
         self.io_redirector = IORedirector(self.checkpoint_dir.log_file)
         self.io_redirector.install()
 
+        for stage in self.stages:
+            stage.add_callback(CsvCallback(self.checkpoint_dir.path / f'metrics_{stage.name}.csv'))
+
     def _resume_run(self):
         dml_logging.info(f'Resuming training from checkpoint: {self.checkpoint_dir}')
         self.resume_run()
diff --git a/dmlcloud/core/stage.py b/dmlcloud/core/stage.py
index c61abe3..32ef480 100644
--- a/dmlcloud/core/stage.py
+++ b/dmlcloud/core/stage.py
@@ -1,7 +1,7 @@
 from typing import Any, Callable, List, Optional
 
 from . import logging as dml_logging
-from .callbacks import ReduceMetricsCallback, StageCallback, TableCallback, TimerCallback
+from .callbacks import CsvCallback, ReduceMetricsCallback, StageCallback, TableCallback, TimerCallback
 from .metrics import Tracker, TrainingHistory
 
 __all__ = [
@@ -146,7 +146,7 @@ def pre_epoch(self):
 
     def post_epoch(self):
         """
-        Executed after each epoch and after the metrics have been reduced.
+        Executed after each epoch.
         """
         pass
 
diff --git a/examples/barebone_mnist.py b/examples/barebone_mnist.py
index 9b5a667..bbb19b8 100644
--- a/examples/barebone_mnist.py
+++ b/examples/barebone_mnist.py
@@ -101,6 +101,7 @@ def _val_epoch(self):
 
 def main():
     pipe = dml.Pipeline()
+    pipe.enable_checkpointing('checkpoints')
     pipe.append(MNISTStage(epochs=3))
     pipe.run()
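
Note for reviewers, not part of the patch: a minimal usage sketch of the new callback. It reuses MNISTStage and the "import dmlcloud as dml" alias from examples/barebone_mnist.py and assumes CsvCallback is importable from dmlcloud.core.callbacks, matching the module path in the diff. With enable_checkpointing(), _init_checkpointing() already attaches one CsvCallback per stage that writes metrics_<stage.name>.csv into the checkpoint directory; the explicit add_callback() call below only illustrates writing an additional copy to a custom, illustrative path.

# Usage sketch only; MNISTStage and the dml alias come from examples/barebone_mnist.py.
from dmlcloud.core.callbacks import CsvCallback


def main():
    pipe = dml.Pipeline()

    # With checkpointing enabled, the pipeline attaches a CsvCallback per stage
    # that writes metrics_<stage.name>.csv inside the checkpoint directory.
    pipe.enable_checkpointing('checkpoints')

    stage = MNISTStage(epochs=3)

    # Optional: also write the metrics to a custom path. CsvCallback.pre_stage()
    # opens the file in 'x' mode, so an already existing file makes the run fail early.
    stage.add_callback(CsvCallback('mnist_metrics.csv'))

    pipe.append(stage)
    pipe.run()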