Skip to content

Commit

Permalink
feat: simpler table interface
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Dec 17, 2024
1 parent c0036b9 commit 8d4027b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 61 deletions.
85 changes: 32 additions & 53 deletions dmlcloud/core/stage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys
import time
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union

import torch
Expand Down Expand Up @@ -39,9 +39,11 @@ def __init__(self):
self._stop_requested = False

self.metric_prefix = None
self.table = None
self.barrier_timeout = None

self.table = None
self.columns = {}

@property
def tracker(self) -> MetricTracker:
    """
    The metric tracker of the pipeline this stage is attached to.
    """
    pipeline = self.pipeline
    return pipeline.tracker
Expand Down Expand Up @@ -73,6 +75,17 @@ def track(self, name: str, value, step: Optional[int] = None, prefixed: bool = T
name = f'{self.metric_prefix}/{name}'
self.pipeline.track(name, value, step)

def add_column(
    self,
    name: str,
    metric: Optional[str] = None,
    width: Optional[int] = None,
    color: Optional[str] = None,
    alignment: Optional[str] = None,
):
    """
    Adds a column to the progress table of this stage.

    Args:
        name: Display name of the column. Also the key under which the column is registered in ``self.columns``.
        metric: Name of the tracked metric whose latest value is written into this column on every table update.
            If None, the column is not filled automatically.
        width: Forwarded to the table's ``add_column`` method.
        color: Forwarded to the table's ``add_column`` method.
        alignment: Forwarded to the table's ``add_column`` method.

    Raises:
        RuntimeError: If the progress table has not been created yet.
    """
    if self.table is None:
        raise RuntimeError(
            'add_column() requires the progress table, which does not exist yet. '
            'Call it from pre_stage() or later.'
        )

    # Register with the table first: if the table rejects the column, it must not end up in self.columns,
    # otherwise later table updates would try to fill a column that does not exist.
    self.table.add_column(name, width=width, color=color, alignment=alignment)
    self.columns[name] = metric

def stop_stage(self):
    """
    Requests this stage to stop by raising its internal ``_stop_requested`` flag.
    """
    self._stop_requested = True

Expand Down Expand Up @@ -108,25 +121,6 @@ def run_epoch(self):
"""
raise NotImplementedError()

def table_columns(self) -> List[Union[str, Dict[str, Any]]]:
    """
    Returns the columns shown in the progress table. Override this method to customize them.

    Every entry of the returned list is either a string or a dict:
      - A string is used as both the display name and the metric name.
      - A dict must provide a 'name' key (the display name) and a 'metric' key (the metric name).
        Any additional keys are forwarded to the ProgressTable.add_column method.

    Columns whose 'metric' is None are not tied to a metric; the user is responsible for updating them manually.
    """
    epoch_column = {'name': 'Epoch', 'metric': 'misc/epoch'}
    time_column = {'name': 'Time/Epoch', 'metric': None}

    # Without an epoch limit there is no ETA column.
    if self.max_epochs is None:
        return [epoch_column, time_column]

    eta_column = {'name': 'ETA', 'metric': None}
    return [epoch_column, time_column, eta_column]

def run(self):
"""
Runs this stage. Either until max_epochs are reached, or until stop_stage() is called.
Expand All @@ -142,11 +136,14 @@ def run(self):

def _pre_stage(self):
self.start_time = datetime.now()
self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO())
self._setup_table()
if len(self.pipeline.stages) > 1:
dml_logging.info(f'\n========== STAGE: {self.name} ==========')

self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO())
self.add_column('Epoch', 'misc/epoch', color='bright', width=5)
self.add_column('Took', None, width=7)
self.add_column('ETA', None, width=7)

self.pre_stage()

dml_logging.flush_logger()
Expand Down Expand Up @@ -183,39 +180,21 @@ def _reduce_metrics(self):
self.tracker.next_epoch()
pass

def _setup_table(self):
    """
    Adds one table column for every entry returned by ``self._metrics()``.

    The 'name' entry is used as the display name of the column. All remaining entries except 'metric' are
    forwarded as keyword arguments to the table's ``add_column`` method.
    """
    for column_dct in self._metrics():
        # Build the keyword arguments from a filtered copy instead of popping keys off the dict itself:
        # _metrics() passes column dicts through without copying them, so mutating them here would strip
        # 'name' and 'metric' from any dict that is handed out again on a later call.
        options = {key: value for key, value in column_dct.items() if key not in ('name', 'metric')}
        self.table.add_column(column_dct['name'], **options)

def _update_table(self):
    """
    Fills the current row of the progress table and then advances the table to the next row.

    Writes the current epoch, the average time per epoch since ``self.start_time`` and, for stages with an
    epoch limit, the estimated remaining time. Afterwards every column returned by ``self._metrics()`` that
    is tied to a metric receives the latest tracked value of that metric.
    """
    # NOTE(review): divides by self.current_epoch; assumes it is >= 1 when this runs. Confirm against the caller.
    per_epoch = (datetime.now() - self.start_time) / self.current_epoch

    self.table.update('Epoch', self.current_epoch)
    self.table.update('Time/Epoch', per_epoch)

    # max_epochs may be None. In that case there is no remaining epoch count to extrapolate from, and
    # subtracting from None would raise a TypeError.
    if self.max_epochs is not None:
        self.table.update('ETA', per_epoch * (self.max_epochs - self.current_epoch))

    for column_dct in self._metrics():
        display_name = column_dct['name']
        metric_name = column_dct['metric']
        # Columns without a metric are updated manually and are skipped here.
        if metric_name is not None:
            self.table.update(display_name, self.tracker[metric_name][-1])

    self.table.next_row()

def _metrics(self):
    """
    Normalizes the entries of ``self.table_columns()`` into a list of column dicts.

    String entries are expanded into a dict that uses the string as both 'name' and 'metric'.
    Dict entries are passed through unchanged (not copied) after checking that they contain both keys.

    Raises:
        ValueError: If a dict entry lacks the 'name' or 'metric' key, or if an entry is neither a string nor a dict.
    """
    normalized = []
    for entry in self.table_columns():
        if isinstance(entry, str):
            normalized.append({'name': entry, 'metric': entry})
            continue

        if not isinstance(entry, dict):
            raise ValueError(f'Invalid column: {entry}. Must be a string or a dict.')

        if 'name' not in entry:
            raise ValueError('Column dict must contain a "name" key')
        if 'metric' not in entry:
            raise ValueError('Column dict must contain a "metric" key')
        normalized.append(entry)

    return normalized
# Time elapsed since self.epoch_start_time, displayed with the microseconds cut off.
# The local is deliberately not called `time`, which would shadow the imported `time` module.
epoch_duration = datetime.now() - self.epoch_start_time
self.table.update('Took', str(epoch_duration - timedelta(microseconds=epoch_duration.microseconds)))

# max_epochs may be None. In that case there is no remaining epoch count to extrapolate from, and
# subtracting from None would raise a TypeError, so the ETA cell is left untouched.
if self.max_epochs is not None:
    # NOTE(review): divides by self.current_epoch; assumes it is >= 1 when this runs. Confirm against the caller.
    per_epoch = (datetime.now() - self.start_time) / self.current_epoch
    eta = per_epoch * (self.max_epochs - self.current_epoch)
    self.table.update('ETA', str(eta - timedelta(microseconds=eta.microseconds)))

# Columns registered without a metric are updated manually and are skipped here.
for name, metric in self.columns.items():
    if metric is not None:
        self.table.update(name, self.tracker[metric][-1])

self.table.next_row()


class TrainValStage(Stage):
Expand Down
13 changes: 5 additions & 8 deletions examples/barebone_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def pre_stage(self):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
self.loss = nn.CrossEntropyLoss()

self.add_column('[Train] Loss', 'train/loss', color='green')
self.add_column('[Train] Acc.', 'train/accuracy', color='green')
self.add_column('[Val] Loss', 'val/loss', color='blue')
self.add_column('[Val] Acc.', 'val/accuracy', color='blue')

def run_epoch(self):
self._train_epoch()
self._val_epoch()
Expand Down Expand Up @@ -72,14 +77,6 @@ def _log_metrics(self, img, target, output, loss):
self.track_reduce('loss', loss)
self.track_reduce('accuracy', (output.argmax(1) == target).float().mean())

def table_columns(self):
    """
    Extends the default columns with the train/val loss and accuracy columns.

    The four extra columns are placed directly after the first default column.
    """
    columns = super().table_columns()
    extra_columns = [
        {'name': '[Train] Loss', 'metric': 'train/loss'},
        {'name': '[Val] Loss', 'metric': 'val/loss'},
        {'name': '[Train] Acc.', 'metric': 'train/accuracy'},
        {'name': '[Val] Acc.', 'metric': 'val/accuracy'},
    ]
    # Splice all four in at index 1 at once; same result as inserting them one by one at indices 1 to 4.
    columns[1:1] = extra_columns
    return columns


def main():
pipeline = dml.TrainingPipeline()
Expand Down

0 comments on commit 8d4027b

Please sign in to comment.