feat: manual epoch management, closes #12
sehoffmann committed Jan 7, 2025
1 parent f75ce65 commit fa6f426
Showing 6 changed files with 179 additions and 43 deletions.
47 changes: 24 additions & 23 deletions dmlcloud/core/callbacks.py
@@ -188,9 +188,10 @@ def post_epoch(self, stage: 'Stage'):
stage.log('misc/epoch_time', (stage.epoch_end_time - self.epoch_start_time).total_seconds(), prefixed=False)
stage.log('misc/total_time', (stage.epoch_end_time - self.start_time).total_seconds(), prefixed=False)

- average_epoch_time = (stage.epoch_end_time - self.start_time) / (stage.current_epoch + 1)
- eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1)
- stage.log('misc/eta', eta.total_seconds(), prefixed=False)
+ if stage._run_epoch_overridden:
+     average_epoch_time = (stage.epoch_end_time - self.start_time) / (stage.current_epoch + 1)
+     eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1)
+     stage.log('misc/eta', eta.total_seconds(), prefixed=False)


class TableCallback(Callback):
@@ -203,21 +204,21 @@ def __init__(self):
self.tracked_metrics = {}
self.formatters = {}

- @property
- def table(self):
+ def get_table(self, stage: 'Stage'):
if self._table is None:
-     self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO())
-     self.track_metric('Epoch', width=5)
-     self.track_metric('Took', 'misc/epoch_time', formatter=TimedeltaFormatter(), width=7)
-     self.track_metric('ETA', 'misc/eta', formatter=TimedeltaFormatter(), width=7)
+     self._table = ProgressTable(file=sys.stdout if is_root() else DevNullIO())
+     self.track_metric(stage, 'Epoch', width=5)
+     self.track_metric(stage, 'Took', 'misc/epoch_time', formatter=TimedeltaFormatter(), width=7)
+     if stage._run_epoch_overridden:
+         self.track_metric(stage, 'ETA', 'misc/eta', formatter=TimedeltaFormatter(), width=7)
return self._table

- @table.setter
- def table(self, value):
+ def set_table(self, value):
self._table = value

def track_metric(
self,
+ stage: 'Stage',
name: str,
metric: Optional[str] = None,
formatter: Optional[Callable] = None,
@@ -244,37 +245,37 @@ def track_metric(
if formatter and not metric:
raise ValueError('Cannot provide a formatter without a metric name')

- self.table.add_column(name, width=width, color=color, alignment=alignment)
+ self.get_table(stage).add_column(name, width=width, color=color, alignment=alignment)

if metric:
self.tracked_metrics[name] = metric
self.formatters[name] = formatter

def pre_stage(self, stage: 'Stage'):
- _ = self.table  # Ensure the table has been created at this point
+ self.get_table(stage)  # Ensure the table has been created at this point

def post_stage(self, stage: 'Stage'):
- self.table.close()
+ self.get_table(stage).close()

def pre_epoch(self, stage: 'Stage'):
- if 'Epoch' in self.table.column_names:
-     self.table['Epoch'] = stage.current_epoch
+ if 'Epoch' in self.get_table(stage).column_names:
+     self.get_table(stage)['Epoch'] = stage.current_epoch

def post_epoch(self, stage: 'Stage'):
metrics = stage.history.last()

for column_name, metric_name in self.tracked_metrics.items():
- if column_name not in self.table.column_names:
+ if column_name not in self.get_table(stage).column_names:
continue

value = metrics[metric_name]
formatter = self.formatters[column_name]
if formatter is not None:
value = formatter(value)

- self.table.update(column_name, value)
+ self.get_table(stage).update(column_name, value)

- self.table.next_row()
+ self.get_table(stage).next_row()


class ReduceMetricsCallback(Callback):
@@ -283,7 +284,7 @@ class ReduceMetricsCallback(Callback):
"""

def post_epoch(self, stage: 'Stage'):
- metrics = stage.tracker.reduce()
+ metrics = stage.metrics.reduce()
stage.history.append_metrics(**metrics)
stage.history.next_step()

@@ -388,15 +389,15 @@ class WandbCallback(Callback):
A callback that logs metrics to Weights & Biases.
"""

- def __init__(self, entity, project, group, tags, startup_timeout, **kwargs):
+ def __init__(self, project, entity, group, tags, startup_timeout, **kwargs):
try:
import wandb
except ImportError:
raise ImportError('wandb is required for the WandbCallback')

self.wandb = wandb
- self.entity = entity
self.project = project
+ self.entity = entity
self.group = group
self.tags = tags
self.startup_timeout = startup_timeout
@@ -407,8 +408,8 @@ def pre_run(self, pipe: 'Pipeline'):
self.wandb.init(
config=OmegaConf.to_container(pipe.config, resolve=True),
name=pipe.name,
- entity=self.entity,
project=self.project,
+ entity=self.entity,
group=self.group,
tags=self.tags,
**self.kwargs,
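Note on the ETA change in this file: the extrapolation is now gated on stage._run_epoch_overridden because it only makes sense when the stage loops run_epoch() up to a known max_epochs; a manually managed stage (see the new Stage.run() further down) has no fixed epoch count to extrapolate from. As a standalone illustration of the arithmetic, with made-up numbers and plain Python rather than dmlcloud code:

from datetime import timedelta

# Assumed example values, not taken from the commit:
elapsed = timedelta(minutes=30)      # stage.epoch_end_time - self.start_time
current_epoch, max_epochs = 2, 10    # zero-based epoch counter

average_epoch_time = elapsed / (current_epoch + 1)           # 10 minutes per epoch so far
eta = average_epoch_time * (max_epochs - current_epoch - 1)  # 7 epochs remaining
print(eta.total_seconds())  # 4200.0, i.e. 70 minutes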
6 changes: 6 additions & 0 deletions dmlcloud/core/metrics.py
@@ -204,3 +204,9 @@ def clear(self):
for metric in self.metrics.values():
metric.reset()
self.metrics.clear()

def __getitem__(self, name: str):
return self.metrics[name]

def __setitem__(self, name: str, metric: torchmetrics.Metric):
self.add_metric(name, metric)
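The two dunder methods above give the tracker dict-style access. A minimal sketch of how that reads from inside a stage, assuming (as in the stage.py changes below) that the tracker is exposed as self.metrics; the class and metric names here are made up for illustration:

import dmlcloud as dml
import torchmetrics


class SketchStage(dml.Stage):  # hypothetical stage, for illustration only
    def pre_stage(self):
        # __setitem__ forwards to add_metric(name, metric)
        self.metrics['train/accuracy'] = torchmetrics.Accuracy(task='multiclass', num_classes=10)

    def run_epoch(self):
        # ... training loop elided ...
        # __getitem__ returns the registered torchmetrics.Metric, so running
        # values can be read back, as examples/custom_epochs.py does below:
        acc_so_far = self.metrics['train/accuracy'].compute()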
13 changes: 10 additions & 3 deletions dmlcloud/core/pipeline.py
@@ -197,8 +197,15 @@ def enable_wandb(
import wandb # import now to avoid potential long import times later on # noqa

if is_root():
- project = project or self.name
- self.add_callback(WandbCallback(project, entity, group, tags, startup_timeout, **kwargs), CbPriority.WANDB)
+ callback = WandbCallback(
+     project=project,
+     entity=entity,
+     group=group,
+     tags=tags,
+     startup_timeout=startup_timeout,
+     **kwargs,
+ )
+ self.add_callback(callback, CbPriority.WANDB)

self.wandb = True

@@ -224,7 +231,7 @@ def run(self):
self._pre_run()
for stage in self.stages:
self.current_stage = stage
- stage.run()
+ stage._run()
self._post_run()

def pre_run(self):
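With the refactor above, enable_wandb now forwards keyword arguments to WandbCallback, which avoids the positional project/entity mix-up the old one-liner was prone to. A hedged usage sketch; the keyword names mirror the parameters visible in the diff, but the defaults, values, and the NoopStage class are assumptions, not part of the commit:

import dmlcloud as dml


class NoopStage(dml.Stage):  # hypothetical minimal stage, for illustration only
    def run_epoch(self):
        pass


pipe = dml.Pipeline(name='my-experiment')
pipe.append(NoopStage(epochs=3))
pipe.enable_wandb(project='my-project', tags=['baseline'])  # forwarded to WandbCallback as keywords
pipe.run()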
68 changes: 56 additions & 12 deletions dmlcloud/core/stage.py
@@ -37,7 +37,7 @@ class Stage:
- post_epoch()
"""

- def __init__(self, name: str = None, epochs: int = 1):
+ def __init__(self, name: str = None, epochs: int | None = 1):
self.name = name or self.__class__.__name__
self.max_epochs = epochs

@@ -46,7 +46,7 @@ def __init__(self, name: str = None, epochs: int = 1):
self.pipe = None # set by the pipeline

self.history = TrainingHistory()
- self.tracker = Tracker()
+ self.metrics = Tracker()

self.metric_prefix = None
self.barrier_timeout = None
@@ -88,7 +88,15 @@ def epoch_end_time(self):

@property
def table(self):
- return self._table_callback.table
+ return self._table_callback.get_table(self)

@property
def _run_overridden(self):
return type(self).run != Stage.run

@property
def _run_epoch_overridden(self):
return type(self).run_epoch != Stage.run_epoch

def add_callback(self, callback: 'Callback', priority: int = 1):
"""
@@ -108,11 +116,11 @@ def add_callback(self, callback: 'Callback', priority: int = 1):
def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True):
if prefixed and self.metric_prefix:
name = f'{self.metric_prefix}/{name}'
- self.tracker.log(name, value, reduction)
+ self.metrics.log(name, value, reduction)

def add_metric(self, name, metric):
metric = metric.to(self.device)
- self.tracker.add_metric(name, metric)
+ self.metrics.add_metric(name, metric)
return metric

def add_column(
@@ -143,7 +151,7 @@ def add_column(
alignment (str, optional): The alignment of the column. Defaults to None.
"""
self._table_callback.track_metric(
- name, metric=metric, formatter=formatter, width=width, color=color, alignment=alignment
+ self, name, metric=metric, formatter=formatter, width=width, color=color, alignment=alignment
)

def pre_stage(self):
@@ -172,22 +180,58 @@ def post_epoch(self):
"""
pass

def run(self):
"""
Override this method to implement the main logic of the stage and do manual epoch management.
Either this method or :meth:`run_epoch` must be implemented by subclasses.
Unlike :meth:`run_epoch`, this method is called only once per stage, and the implementation is responsible for
managing the epochs and calling :meth:`next_epoch` when appropriate.
"""
raise NotImplementedError()

def next_epoch(self):
"""
Advances the stage to the next epoch.
This method must only be called by the implementation of :meth:`run` when the stage finishes an epoch.
"""
if self._run_epoch_overridden:
raise ValueError('next_epoch() must not be called when run_epoch() is implemented.')

self._post_epoch()
self._pre_epoch()

def run_epoch(self):
"""
- Train the model for one epoch. Must be implemented by subclasses.
+ Override this method to implement the main logic of the stage for a single epoch.
+ Either this method or :meth:`run` must be implemented by subclasses.
+ Unlike :meth:`run`, this method is called automatically by the stage and does not need to manage the epochs.
"""
raise NotImplementedError()

- def run(self):
+ def _run(self):
"""
Runs this stage. Either until max_epochs are reached, or until stop_stage() is called.
"""
- self._pre_stage()
- while self.max_epochs is None or self.current_epoch < self.max_epochs:
-     self._pre_epoch()
-     self.run_epoch()
-     self._post_epoch()
- self._post_stage()
+ if self._run_overridden and self._run_epoch_overridden:
+     raise ValueError('Only one of run() or run_epoch() must be implemented.')
+ elif not self._run_overridden and not self._run_epoch_overridden:
+     raise ValueError('Either run() or run_epoch() must be implemented.')
+ elif self._run_epoch_overridden:
+     self._pre_stage()
+     while self.max_epochs is None or self.current_epoch < self.max_epochs:
+         self._pre_epoch()
+         self.run_epoch()
+         self._post_epoch()
+     self._post_stage()
+ else:
+     self._pre_stage()
+     self._pre_epoch()
+     self.run()
+     self._post_epoch()
+     self._post_stage()

def _pre_stage(self):
if len(self.pipe.stages) > 1:
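Taken together, the stage.py changes give a Stage two mutually exclusive entry points: override run_epoch() and let _run() drive the epoch loop, or override run() and advance epochs manually via next_epoch(). Overriding both, or neither, raises a ValueError. A minimal hedged sketch of the two styles; the class names and loop bodies are made up, not taken from the commit:

import dmlcloud as dml


class EpochDrivenStage(dml.Stage):
    # Automatic epoch management: _run() calls run_epoch() once per epoch
    # until max_epochs is reached.
    def run_epoch(self):
        pass  # one epoch of work


class ManuallyManagedStage(dml.Stage):
    # Manual epoch management: run() is called exactly once; the implementation
    # decides when an epoch ends by calling next_epoch(), which fires the
    # post_epoch/pre_epoch hooks and callbacks.
    def run(self):
        for step in range(1, 10_001):
            ...  # one training step
            if step % 1_000 == 0 and step < 10_000:
                self.next_epoch()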
82 changes: 82 additions & 0 deletions examples/custom_epochs.py
@@ -0,0 +1,82 @@
import dmlcloud as dml
import torch
import torchmetrics
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class CustomEpochStage(dml.Stage):

def pre_stage(self):
with dml.root_first():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(root='data', train=True, download=dml.is_root(), transform=transform)

self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
self.train_loader = DataLoader(train_dataset, batch_size=32, sampler=self.train_sampler)

model = nn.Sequential(
nn.Conv2d(1, 16, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(16, 16, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(784, 10),
)
self.model = dml.wrap_ddp(model, self.device)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=dml.scale_lr(1e-3))

self.loss = nn.CrossEntropyLoss()

# Finally, we add columns to the table to track progress and the loss
self.add_column('# Steps', 'misc/steps')
self.add_column('# Samples', 'misc/total_samples')

self.add_column('Loss', 'train/loss', color='green')

def run(self):
MAX_STEPS = 5000
LOG_PERIOD = 250

num_steps = 0
total_samples = 0
while num_steps < MAX_STEPS:
self.train_sampler.set_epoch(self.current_epoch)

for img, target in self.train_loader:
img, target = img.to(self.device), target.to(self.device)

self.optimizer.zero_grad()
output = self.model(img)
loss = self.loss(output, target)
loss.backward()
self.optimizer.step()

self.log('train/loss', loss)
self.log('misc/samples', len(img), reduction='sum')

num_steps += 1
if num_steps % LOG_PERIOD == 0:
total_samples += self.metrics['misc/samples'].compute()
self.log('misc/total_samples', total_samples)
self.log('misc/steps', num_steps)
if num_steps < MAX_STEPS:
self.next_epoch()
self.train_sampler.set_epoch(self.current_epoch)
else:
break


def main():
pipe = dml.Pipeline(name='custom-epochs')
pipe.append(CustomEpochStage())
pipe.enable_checkpointing('checkpoints')
pipe.enable_wandb()
pipe.run()


if __name__ == '__main__':
main()
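A quick sanity check of the constants used in this example, per worker process; the snippet below is just arithmetic on the values defined above, not additional behaviour:

MAX_STEPS = 5000
LOG_PERIOD = 250
BATCH_SIZE = 32

logging_points = MAX_STEPS // LOG_PERIOD        # 20 logging points in total
epoch_boundaries = logging_points - 1           # next_epoch() is skipped on the final one
samples_between_logs = LOG_PERIOD * BATCH_SIZE  # 8000 samples per worker between logging points
print(logging_points, epoch_boundaries, samples_between_logs)  # 20 19 8000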
6 changes: 1 addition & 5 deletions examples/mnist.py
@@ -1,7 +1,3 @@
- import sys
-
- sys.path.insert(0, './')
-
import dmlcloud as dml
import torch
import torchmetrics
@@ -100,7 +96,7 @@ def _val_epoch(self):


def main():
- pipe = dml.Pipeline()
+ pipe = dml.Pipeline(name='MNIST')
pipe.append(MNISTStage(epochs=3))
pipe.enable_checkpointing('checkpoints')
pipe.enable_wandb()
