Skip to content

Commit

Permalink
feat: Save metrics to csv, closes #15
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Dec 26, 2024
1 parent 3bb6340 commit f0454e5
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 7 deletions.
34 changes: 33 additions & 1 deletion dmlcloud/core/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import csv
import sys
from datetime import datetime, timedelta
from typing import Callable, Optional
from pathlib import Path
from typing import Callable, Optional, Union

import torch
from progress_table import ProgressTable
Expand All @@ -16,6 +18,7 @@
'TimreCallback',
'TableCallback',
'ReduceMetricsCallback',
'CsvCallback',
]


Expand Down Expand Up @@ -195,3 +198,32 @@ def post_epoch(self, stage: 'Stage'):
metrics = stage.tracker.reduce()
stage.history.append_metrics(**metrics)
stage.history.next_step()


class CsvCallback(StageCallback):
"""
Saves metrics to a CSV file at the end of each epoch.
"""

def __init__(self, path: Union[str, Path]):
self.path = Path(path)

def pre_stage(self, stage: 'Stage'):
# If for some reason we can't write to the file or it exists already, its better to fail early
with open(self.path, 'x'):
pass

def post_epoch(self, stage: 'Stage'):
with open(self.path, 'a') as f:
writer = csv.writer(f)

metrics = stage.history.last()

# Write the header if the file is empty
if f.tell() == 0:
writer.writerow(['Epoch'] + list(metrics))

row = [stage.current_epoch - 1] # epoch is already incremented
for value in metrics.values():
row.append(value.item())
writer.writerow(row)
1 change: 0 additions & 1 deletion dmlcloud/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def __init__(self):
super().__init__()

self.metrics = torch.nn.ModuleDict()
self.external_metrics = torch.nn.ModuleDict()

def add_metric(self, name: str, metric: torchmetrics.Metric):
if name in self.metrics:
Expand Down
11 changes: 8 additions & 3 deletions dmlcloud/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dmlcloud.util.wandb import wandb, wandb_is_initialized, wandb_set_startup_timeout
from ..util.logging import experiment_header, general_diagnostics, IORedirector
from . import logging as dml_logging
from .callbacks import CsvCallback
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
from .distributed import all_gather_object, broadcast_object, init, local_rank, root_only
from .stage import Stage
Expand All @@ -32,6 +33,10 @@ def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Option
else:
self.config = config

# Auto-init distributed if not already initialized
if not dist.is_initialized():
init()

self.name = name

self.checkpoint_dir = None
Expand Down Expand Up @@ -179,9 +184,6 @@ def _pre_run(self):
if len(self.stages) == 0:
raise ValueError('No stages defined. Use append_stage() to add stages to the pipeline.')

if not dist.is_initialized():
init()

if dist.is_gloo_available():
self.gloo_group = dist.new_group(backend='gloo')
else:
Expand Down Expand Up @@ -226,6 +228,9 @@ def _init_checkpointing(self):
self.io_redirector = IORedirector(self.checkpoint_dir.log_file)
self.io_redirector.install()

for stage in self.stages:
stage.add_callback(CsvCallback(self.checkpoint_dir.path / f'metrics_{stage.name}.csv'))

def _resume_run(self):
dml_logging.info(f'Resuming training from checkpoint: {self.checkpoint_dir}')
self.resume_run()
Expand Down
4 changes: 2 additions & 2 deletions dmlcloud/core/stage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable, List, Optional

from . import logging as dml_logging
from .callbacks import ReduceMetricsCallback, StageCallback, TableCallback, TimerCallback
from .callbacks import CsvCallback, ReduceMetricsCallback, StageCallback, TableCallback, TimerCallback
from .metrics import Tracker, TrainingHistory

__all__ = [
Expand Down Expand Up @@ -146,7 +146,7 @@ def pre_epoch(self):

def post_epoch(self):
"""
Executed after each epoch and after the metrics have been reduced.
Executed after each epoch.
"""
pass

Expand Down
1 change: 1 addition & 0 deletions examples/barebone_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def _val_epoch(self):

def main():
pipe = dml.Pipeline()
pipe.enable_checkpointing('checkpoints')
pipe.append(MNISTStage(epochs=3))
pipe.run()

Expand Down

0 comments on commit f0454e5

Please sign in to comment.