Skip to content

Commit

Permalink
feat: StageCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Dec 26, 2024
1 parent 57e10a5 commit 3bb6340
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 57 deletions.
197 changes: 197 additions & 0 deletions dmlcloud/core/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import sys
from datetime import datetime, timedelta
from typing import Callable, Optional

import torch
from progress_table import ProgressTable

from ..util.logging import DevNullIO
from . import logging as dml_logging
from .distributed import is_root


__all__ = [
'TimedeltaFormatter',
'StageCallback',
'TimreCallback',
'TableCallback',
'ReduceMetricsCallback',
]


class TimedeltaFormatter:
"""
A formatter that converts a number of seconds to a human-readable string.
"""

def __init__(self, microseconds=False):
self.microseconds = microseconds

def __call__(self, value: torch.Tensor) -> str:
delta = timedelta(seconds=value.item())
if not self.microseconds:
delta -= timedelta(microseconds=delta.microseconds)
return str(delta)


class StageCallback:
"""
A callback that can be registered to a stage to receive updates on the training progress.
"""

def pre_stage(self, stage: 'Stage'):
"""
Executed before the stage starts.
"""
pass

def post_stage(self, stage: 'Stage'):
"""
Executed after the stage finishes.
"""
pass

def pre_epoch(self, stage: 'Stage'):
"""
Executed before each epoch.
"""
pass

def post_epoch(self, stage: 'Stage'):
"""
Executed after each epoch.
"""
pass


class TimerCallback(StageCallback):
"""
A callback that logs the time taken for each epoch.
"""

def __init__(self):
self.start_time = None
self.end_time = None
self.epoch_start_time = None
self.epoch_end_time = None

def pre_stage(self, stage: 'Stage'):
self.start_time = datetime.now()

def post_stage(self, stage: 'Stage'):
self.end_time = datetime.now()

def pre_epoch(self, stage: 'Stage'):
self.epoch_start_time = datetime.now()

def post_epoch(self, stage: 'Stage'):
self.epoch_end_time = datetime.now()

stage.log('misc/epoch', stage.current_epoch, prefixed=False)
stage.log('misc/epoch_time', (stage.epoch_end_time - self.epoch_start_time).total_seconds(), prefixed=False)
stage.log('misc/total_time', (stage.epoch_end_time - self.start_time).total_seconds(), prefixed=False)

eta = (
(stage.epoch_end_time - self.start_time)
/ (stage.current_epoch + 1)
* (stage.max_epochs - stage.current_epoch - 1)
)
stage.log('misc/eta', eta.total_seconds(), prefixed=False)

if len(stage.pipe.stages) > 1:
dml_logging.info(f'Finished stage in {stage.end_time - stage.start_time}')


class TableCallback(StageCallback):
"""
A callback that updates a table with the latest metrics from a stage.
"""

def __init__(self):
self._table = None
self.tracked_metrics = {}
self.formatters = {}

@property
def table(self):
if self._table is None:
self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO())
self.track_metric('Epoch', width=5)
self.track_metric('Took', 'misc/epoch_time', formatter=TimedeltaFormatter(), width=7)
self.track_metric('ETA', 'misc/eta', formatter=TimedeltaFormatter(), width=7)
return self._table

@table.setter
def table(self, value):
self._table = value

def track_metric(
self,
name: str,
metric: Optional[str] = None,
formatter: Optional[Callable] = None,
width: Optional[int] = None,
color: Optional[str] = None,
alignment: Optional[str] = None,
):
"""
Track a metric in the table.
If no metric name is provided, only a column is created and the caller must update the value manually.
If a formatter is provided, the metric value will be passed through the formatter before being displayed.
For a detailed description of width, color, and alignment, see `ProgressTable.add_column`.
Args:
name (str): The name of the column.
metric (str, optional): The name of the metric to track. Defaults to None.
formatter (Callable, optional): A function that takes the metric value and returns a string. Defaults to None.
width (int, optional): The width of the column. Defaults to None.
color (str, optional): The color of the column. Defaults to None.
alignment (str, optional): The alignment of the column. Defaults to None.
"""
if formatter and not metric:
raise ValueError('Cannot provide a formatter without a metric name')

self.table.add_column(name, width=width, color=color, alignment=alignment)

if metric:
self.tracked_metrics[name] = metric
self.formatters[name] = formatter

def pre_stage(self, stage: 'Stage'):
_ = self.table # Ensure the table has been created at this point

def post_stage(self, stage: 'Stage'):
self.table.close()

def pre_epoch(self, stage: 'Stage'):
if 'Epoch' in self.table.column_names:
self.table['Epoch'] = stage.current_epoch

def post_epoch(self, stage: 'Stage'):
metrics = stage.history.last()

for column_name, metric_name in self.tracked_metrics.items():
if column_name not in self.table.column_names:
continue

value = metrics[metric_name]
formatter = self.formatters[column_name]
if formatter is not None:
value = formatter(value)

self.table.update(column_name, value)

self.table.next_row()


class ReduceMetricsCallback(StageCallback):
"""
A callback that reduces the metrics at the end of each epoch and appends them to the history.
"""

def post_epoch(self, stage: 'Stage'):
metrics = stage.tracker.reduce()
stage.history.append_metrics(**metrics)
stage.history.next_step()
2 changes: 1 addition & 1 deletion dmlcloud/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,4 @@ def reduce(self):
def clear(self):
for metric in self.metrics.values():
metric.reset()
self.metrics.clear()
self.metrics.clear()
127 changes: 74 additions & 53 deletions dmlcloud/core/stage.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import sys
from datetime import datetime, timedelta
from typing import Any, Optional
from typing import Any, Callable, List, Optional

from progress_table import ProgressTable

from ..util.logging import DevNullIO
from . import logging as dml_logging
from .distributed import is_root
from .callbacks import ReduceMetricsCallback, StageCallback, TableCallback, TimerCallback
from .metrics import Tracker, TrainingHistory

__all__ = [
Expand All @@ -27,22 +22,24 @@ def __init__(self, name: str = None, epochs: int = 1):
self.name = name or self.__class__.__name__
self.max_epochs = epochs

self.callbacks: List[StageCallback] = []

self.pipe = None # set by the pipeline

self.history = TrainingHistory()
self.tracker = Tracker()

self.start_time = None
self.stop_time = None
self.epoch_start_time = None
self.epoch_stop_time = None
self._timer = TimerCallback()
self.add_callback(self._timer)

self.add_callback(ReduceMetricsCallback())

self._table_callback = TableCallback()
self.add_callback(self._table_callback)

self.metric_prefix = None
self.barrier_timeout = None

self.table = None
self.columns = {}

@property
def device(self):
return self.pipe.device
Expand All @@ -55,6 +52,37 @@ def config(self):
def current_epoch(self):
return self.history.num_steps

@property
def start_time(self):
return self._timer.start_time

@property
def end_time(self):
return self._timer.end_time

@property
def epoch_start_time(self):
return self._timer.epoch_start_time

@property
def epoch_end_time(self):
return self._timer.epoch_end_time

@property
def table(self):
return self._table_callback.table

def add_callback(self, callback: StageCallback):
"""
Adds a callback to this stage.
Callbacks are executed in the order they are added and after the stage-specific hooks.
Args:
callback (StageCallback): The callback to add.
"""
self.callbacks.append(callback)

def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True):
if prefixed:
name = f'{self.metric_prefix}/{name}'
Expand All @@ -69,12 +97,32 @@ def add_column(
self,
name: str,
metric: Optional[str] = None,
formatter: Optional[Callable] = None,
width: Optional[int] = None,
color: Optional[str] = None,
alignment: Optional[str] = None,
):
self.columns[name] = metric
self.table.add_column(name, width=width, color=color, alignment=alignment)
"""
Adds a column to the table.
If metric is provided, the column will be updated with the latest value of the metric.
Otherwise,the caller must update the value manually using `table.update`.
If a formatter is provided, the metric value will be passed through the formatter before being displayed.
For a detailed description of width, color, and alignment, see `ProgressTable.add_column`.
Args:
name (str): The name of the column.
metric (str, optional): The name of the metric to track. Defaults to None.
formatter (Callable, optional): A function that takes the metric value and returns a string. Defaults to None.
width (int, optional): The width of the column. Defaults to None.
color (str, optional): The color of the column. Defaults to None.
alignment (str, optional): The alignment of the column. Defaults to None.
"""
self._table_callback.track_metric(
name, metric=metric, formatter=formatter, width=width, color=color, alignment=alignment
)

def pre_stage(self):
"""
Expand Down Expand Up @@ -120,58 +168,31 @@ def run(self):
self._post_stage()

def _pre_stage(self):
self.start_time = datetime.now()
if len(self.pipe.stages) > 1:
dml_logging.info(f'\n========== STAGE: {self.name} ==========')

self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO())
self.add_column('Epoch', None, color='bright', width=5)
self.add_column('Took', None, width=7)
self.add_column('ETA', None, width=7)

self.pre_stage()

for callback in self.callbacks:
callback.pre_stage(self)

dml_logging.flush_logger()

self.pipe.barrier(self.barrier_timeout)

def _post_stage(self):
self.table.close()
self.post_stage()
for callback in self.callbacks:
callback.post_stage(self)

self.pipe.barrier(self.barrier_timeout)
self.stop_time = datetime.now()
if len(self.pipe.stages) > 1:
dml_logging.info(f'Finished stage in {self.stop_time - self.start_time}')

def _pre_epoch(self):
self.epoch_start_time = datetime.now()
self.table['Epoch'] = self.current_epoch
self.pre_epoch()
for callback in self.callbacks:
callback.pre_epoch(self)

def _post_epoch(self):
self.epoch_stop_time = datetime.now()
self._reduce_metrics()
self.post_epoch()
self._update_table()

def _reduce_metrics(self):
# self.log('misc/epoch', self.current_epoch, prefixed=False)
# self.log('misc/epoch_time', (self.epoch_stop_time - self.epoch_start_time).total_seconds())
metrics = self.tracker.reduce()
self.history.append_metrics(**metrics)
self.history.next_step()

def _update_table(self):
time = datetime.now() - self.epoch_start_time
self.table.update('Took', str(time - timedelta(microseconds=time.microseconds)))

per_epoch = (datetime.now() - self.start_time) / self.current_epoch
eta = per_epoch * (self.max_epochs - self.current_epoch)
self.table.update('ETA', str(eta - timedelta(microseconds=eta.microseconds)))

last_metrics = self.history.last()
for name, metric in self.columns.items():
if metric is not None:
self.table.update(name, last_metrics[metric])

self.table.next_row()
for callback in self.callbacks:
callback.post_epoch(self)
Loading

0 comments on commit 3bb6340

Please sign in to comment.