From 3bb63407d44e5d7bcf4efdf64321140c3585e0a5 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Thu, 26 Dec 2024 18:28:21 +0100 Subject: [PATCH] feat: StageCallback --- dmlcloud/core/callbacks.py | 197 +++++++++++++++++++++++++++++++++++++ dmlcloud/core/metrics.py | 2 +- dmlcloud/core/stage.py | 127 ++++++++++++++---------- examples/barebone_mnist.py | 6 +- 4 files changed, 275 insertions(+), 57 deletions(-) create mode 100644 dmlcloud/core/callbacks.py diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py new file mode 100644 index 0000000..a170f4b --- /dev/null +++ b/dmlcloud/core/callbacks.py @@ -0,0 +1,197 @@ +import sys +from datetime import datetime, timedelta +from typing import Callable, Optional + +import torch +from progress_table import ProgressTable + +from ..util.logging import DevNullIO +from . import logging as dml_logging +from .distributed import is_root + + +__all__ = [ + 'TimedeltaFormatter', + 'StageCallback', + 'TimreCallback', + 'TableCallback', + 'ReduceMetricsCallback', +] + + +class TimedeltaFormatter: + """ + A formatter that converts a number of seconds to a human-readable string. + """ + + def __init__(self, microseconds=False): + self.microseconds = microseconds + + def __call__(self, value: torch.Tensor) -> str: + delta = timedelta(seconds=value.item()) + if not self.microseconds: + delta -= timedelta(microseconds=delta.microseconds) + return str(delta) + + +class StageCallback: + """ + A callback that can be registered to a stage to receive updates on the training progress. + """ + + def pre_stage(self, stage: 'Stage'): + """ + Executed before the stage starts. + """ + pass + + def post_stage(self, stage: 'Stage'): + """ + Executed after the stage finishes. + """ + pass + + def pre_epoch(self, stage: 'Stage'): + """ + Executed before each epoch. + """ + pass + + def post_epoch(self, stage: 'Stage'): + """ + Executed after each epoch. + """ + pass + + +class TimerCallback(StageCallback): + """ + A callback that logs the time taken for each epoch. + """ + + def __init__(self): + self.start_time = None + self.end_time = None + self.epoch_start_time = None + self.epoch_end_time = None + + def pre_stage(self, stage: 'Stage'): + self.start_time = datetime.now() + + def post_stage(self, stage: 'Stage'): + self.end_time = datetime.now() + + def pre_epoch(self, stage: 'Stage'): + self.epoch_start_time = datetime.now() + + def post_epoch(self, stage: 'Stage'): + self.epoch_end_time = datetime.now() + + stage.log('misc/epoch', stage.current_epoch, prefixed=False) + stage.log('misc/epoch_time', (stage.epoch_end_time - self.epoch_start_time).total_seconds(), prefixed=False) + stage.log('misc/total_time', (stage.epoch_end_time - self.start_time).total_seconds(), prefixed=False) + + eta = ( + (stage.epoch_end_time - self.start_time) + / (stage.current_epoch + 1) + * (stage.max_epochs - stage.current_epoch - 1) + ) + stage.log('misc/eta', eta.total_seconds(), prefixed=False) + + if len(stage.pipe.stages) > 1: + dml_logging.info(f'Finished stage in {stage.end_time - stage.start_time}') + + +class TableCallback(StageCallback): + """ + A callback that updates a table with the latest metrics from a stage. + """ + + def __init__(self): + self._table = None + self.tracked_metrics = {} + self.formatters = {} + + @property + def table(self): + if self._table is None: + self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO()) + self.track_metric('Epoch', width=5) + self.track_metric('Took', 'misc/epoch_time', formatter=TimedeltaFormatter(), width=7) + self.track_metric('ETA', 'misc/eta', formatter=TimedeltaFormatter(), width=7) + return self._table + + @table.setter + def table(self, value): + self._table = value + + def track_metric( + self, + name: str, + metric: Optional[str] = None, + formatter: Optional[Callable] = None, + width: Optional[int] = None, + color: Optional[str] = None, + alignment: Optional[str] = None, + ): + """ + Track a metric in the table. + + If no metric name is provided, only a column is created and the caller must update the value manually. + If a formatter is provided, the metric value will be passed through the formatter before being displayed. + + For a detailed description of width, color, and alignment, see `ProgressTable.add_column`. + + Args: + name (str): The name of the column. + metric (str, optional): The name of the metric to track. Defaults to None. + formatter (Callable, optional): A function that takes the metric value and returns a string. Defaults to None. + width (int, optional): The width of the column. Defaults to None. + color (str, optional): The color of the column. Defaults to None. + alignment (str, optional): The alignment of the column. Defaults to None. + """ + if formatter and not metric: + raise ValueError('Cannot provide a formatter without a metric name') + + self.table.add_column(name, width=width, color=color, alignment=alignment) + + if metric: + self.tracked_metrics[name] = metric + self.formatters[name] = formatter + + def pre_stage(self, stage: 'Stage'): + _ = self.table # Ensure the table has been created at this point + + def post_stage(self, stage: 'Stage'): + self.table.close() + + def pre_epoch(self, stage: 'Stage'): + if 'Epoch' in self.table.column_names: + self.table['Epoch'] = stage.current_epoch + + def post_epoch(self, stage: 'Stage'): + metrics = stage.history.last() + + for column_name, metric_name in self.tracked_metrics.items(): + if column_name not in self.table.column_names: + continue + + value = metrics[metric_name] + formatter = self.formatters[column_name] + if formatter is not None: + value = formatter(value) + + self.table.update(column_name, value) + + self.table.next_row() + + +class ReduceMetricsCallback(StageCallback): + """ + A callback that reduces the metrics at the end of each epoch and appends them to the history. + """ + + def post_epoch(self, stage: 'Stage'): + metrics = stage.tracker.reduce() + stage.history.append_metrics(**metrics) + stage.history.next_step() diff --git a/dmlcloud/core/metrics.py b/dmlcloud/core/metrics.py index bd965b2..e0377ea 100644 --- a/dmlcloud/core/metrics.py +++ b/dmlcloud/core/metrics.py @@ -205,4 +205,4 @@ def reduce(self): def clear(self): for metric in self.metrics.values(): metric.reset() - self.metrics.clear() \ No newline at end of file + self.metrics.clear() diff --git a/dmlcloud/core/stage.py b/dmlcloud/core/stage.py index 97502f4..c61abe3 100644 --- a/dmlcloud/core/stage.py +++ b/dmlcloud/core/stage.py @@ -1,12 +1,7 @@ -import sys -from datetime import datetime, timedelta -from typing import Any, Optional +from typing import Any, Callable, List, Optional -from progress_table import ProgressTable - -from ..util.logging import DevNullIO from . import logging as dml_logging -from .distributed import is_root +from .callbacks import ReduceMetricsCallback, StageCallback, TableCallback, TimerCallback from .metrics import Tracker, TrainingHistory __all__ = [ @@ -27,22 +22,24 @@ def __init__(self, name: str = None, epochs: int = 1): self.name = name or self.__class__.__name__ self.max_epochs = epochs + self.callbacks: List[StageCallback] = [] + self.pipe = None # set by the pipeline self.history = TrainingHistory() self.tracker = Tracker() - self.start_time = None - self.stop_time = None - self.epoch_start_time = None - self.epoch_stop_time = None + self._timer = TimerCallback() + self.add_callback(self._timer) + + self.add_callback(ReduceMetricsCallback()) + + self._table_callback = TableCallback() + self.add_callback(self._table_callback) self.metric_prefix = None self.barrier_timeout = None - self.table = None - self.columns = {} - @property def device(self): return self.pipe.device @@ -55,6 +52,37 @@ def config(self): def current_epoch(self): return self.history.num_steps + @property + def start_time(self): + return self._timer.start_time + + @property + def end_time(self): + return self._timer.end_time + + @property + def epoch_start_time(self): + return self._timer.epoch_start_time + + @property + def epoch_end_time(self): + return self._timer.epoch_end_time + + @property + def table(self): + return self._table_callback.table + + def add_callback(self, callback: StageCallback): + """ + Adds a callback to this stage. + + Callbacks are executed in the order they are added and after the stage-specific hooks. + + Args: + callback (StageCallback): The callback to add. + """ + self.callbacks.append(callback) + def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True): if prefixed: name = f'{self.metric_prefix}/{name}' @@ -69,12 +97,32 @@ def add_column( self, name: str, metric: Optional[str] = None, + formatter: Optional[Callable] = None, width: Optional[int] = None, color: Optional[str] = None, alignment: Optional[str] = None, ): - self.columns[name] = metric - self.table.add_column(name, width=width, color=color, alignment=alignment) + """ + Adds a column to the table. + + If metric is provided, the column will be updated with the latest value of the metric. + Otherwise,the caller must update the value manually using `table.update`. + + If a formatter is provided, the metric value will be passed through the formatter before being displayed. + + For a detailed description of width, color, and alignment, see `ProgressTable.add_column`. + + Args: + name (str): The name of the column. + metric (str, optional): The name of the metric to track. Defaults to None. + formatter (Callable, optional): A function that takes the metric value and returns a string. Defaults to None. + width (int, optional): The width of the column. Defaults to None. + color (str, optional): The color of the column. Defaults to None. + alignment (str, optional): The alignment of the column. Defaults to None. + """ + self._table_callback.track_metric( + name, metric=metric, formatter=formatter, width=width, color=color, alignment=alignment + ) def pre_stage(self): """ @@ -120,58 +168,31 @@ def run(self): self._post_stage() def _pre_stage(self): - self.start_time = datetime.now() if len(self.pipe.stages) > 1: dml_logging.info(f'\n========== STAGE: {self.name} ==========') - self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO()) - self.add_column('Epoch', None, color='bright', width=5) - self.add_column('Took', None, width=7) - self.add_column('ETA', None, width=7) - self.pre_stage() + for callback in self.callbacks: + callback.pre_stage(self) + dml_logging.flush_logger() self.pipe.barrier(self.barrier_timeout) def _post_stage(self): - self.table.close() self.post_stage() + for callback in self.callbacks: + callback.post_stage(self) + self.pipe.barrier(self.barrier_timeout) - self.stop_time = datetime.now() - if len(self.pipe.stages) > 1: - dml_logging.info(f'Finished stage in {self.stop_time - self.start_time}') def _pre_epoch(self): - self.epoch_start_time = datetime.now() - self.table['Epoch'] = self.current_epoch self.pre_epoch() + for callback in self.callbacks: + callback.pre_epoch(self) def _post_epoch(self): - self.epoch_stop_time = datetime.now() - self._reduce_metrics() self.post_epoch() - self._update_table() - - def _reduce_metrics(self): - # self.log('misc/epoch', self.current_epoch, prefixed=False) - # self.log('misc/epoch_time', (self.epoch_stop_time - self.epoch_start_time).total_seconds()) - metrics = self.tracker.reduce() - self.history.append_metrics(**metrics) - self.history.next_step() - - def _update_table(self): - time = datetime.now() - self.epoch_start_time - self.table.update('Took', str(time - timedelta(microseconds=time.microseconds))) - - per_epoch = (datetime.now() - self.start_time) / self.current_epoch - eta = per_epoch * (self.max_epochs - self.current_epoch) - self.table.update('ETA', str(eta - timedelta(microseconds=eta.microseconds))) - - last_metrics = self.history.last() - for name, metric in self.columns.items(): - if metric is not None: - self.table.update(name, last_metrics[metric]) - - self.table.next_row() + for callback in self.callbacks: + callback.post_epoch(self) diff --git a/examples/barebone_mnist.py b/examples/barebone_mnist.py index 17eb4f7..9b5a667 100644 --- a/examples/barebone_mnist.py +++ b/examples/barebone_mnist.py @@ -53,9 +53,9 @@ def pre_stage(self): # Finally, we add columns to the table to track the loss and accuracy self.add_column('[Train] Loss', 'train/loss', color='green') - self.add_column('[Train] Acc.', 'train/accuracy', color='green') - self.add_column('[Val] Loss', 'val/loss', color='blue') - self.add_column('[Val] Acc.', 'val/accuracy', color='blue') + self.add_column('[Train] Acc.', 'train/accuracy', formatter=lambda acc: f'{100*acc:.2f}%', color='green') + self.add_column('[Val] Loss', 'val/loss', color='cyan') + self.add_column('[Val] Acc.', 'val/accuracy', formatter=lambda acc: f'{100*acc:.2f}%', color='cyan') self.train_acc = self.add_metric('train/accuracy', torchmetrics.Accuracy('multiclass', num_classes=10)) self.val_acc = self.add_metric('val/accuracy', torchmetrics.Accuracy('multiclass', num_classes=10))