From 57f89b984a8dfd7ce299fd5bc624c6cd37caffd3 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Tue, 7 Jan 2025 19:07:54 +0100 Subject: [PATCH] feat: manual epoch managment, closes #12 --- dmlcloud/core/callbacks.py | 47 +++++++++++----------- dmlcloud/core/metrics.py | 6 +++ dmlcloud/core/pipeline.py | 13 ++++-- dmlcloud/core/stage.py | 68 ++++++++++++++++++++++++++------ examples/README.md | 2 +- examples/custom_epochs.py | 81 ++++++++++++++++++++++++++++++++++++++ examples/mnist.py | 6 +-- 7 files changed, 179 insertions(+), 44 deletions(-) create mode 100644 examples/custom_epochs.py diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py index 89b5a04..cc3680f 100644 --- a/dmlcloud/core/callbacks.py +++ b/dmlcloud/core/callbacks.py @@ -188,9 +188,10 @@ def post_epoch(self, stage: 'Stage'): stage.log('misc/epoch_time', (stage.epoch_end_time - self.epoch_start_time).total_seconds(), prefixed=False) stage.log('misc/total_time', (stage.epoch_end_time - self.start_time).total_seconds(), prefixed=False) - average_epoch_time = (stage.epoch_end_time - self.start_time) / (stage.current_epoch + 1) - eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1) - stage.log('misc/eta', eta.total_seconds(), prefixed=False) + if stage._run_epoch_overridden: + average_epoch_time = (stage.epoch_end_time - self.start_time) / (stage.current_epoch + 1) + eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1) + stage.log('misc/eta', eta.total_seconds(), prefixed=False) class TableCallback(Callback): @@ -203,21 +204,21 @@ def __init__(self): self.tracked_metrics = {} self.formatters = {} - @property - def table(self): + def get_table(self, stage: 'Stage'): if self._table is None: - self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO()) - self.track_metric('Epoch', width=5) - self.track_metric('Took', 'misc/epoch_time', formatter=TimedeltaFormatter(), width=7) - self.track_metric('ETA', 'misc/eta', formatter=TimedeltaFormatter(), width=7) + self._table = ProgressTable(file=sys.stdout if is_root() else DevNullIO()) + self.track_metric(stage, 'Epoch', width=5) + self.track_metric(stage, 'Took', 'misc/epoch_time', formatter=TimedeltaFormatter(), width=7) + if stage._run_epoch_overridden: + self.track_metric(stage, 'ETA', 'misc/eta', formatter=TimedeltaFormatter(), width=7) return self._table - @table.setter - def table(self, value): + def set_table(self, value): self._table = value def track_metric( self, + stage: 'Stage', name: str, metric: Optional[str] = None, formatter: Optional[Callable] = None, @@ -244,27 +245,27 @@ def track_metric( if formatter and not metric: raise ValueError('Cannot provide a formatter without a metric name') - self.table.add_column(name, width=width, color=color, alignment=alignment) + self.get_table(stage).add_column(name, width=width, color=color, alignment=alignment) if metric: self.tracked_metrics[name] = metric self.formatters[name] = formatter def pre_stage(self, stage: 'Stage'): - _ = self.table # Ensure the table has been created at this point + self.get_table(stage) # Ensure the table has been created at this point def post_stage(self, stage: 'Stage'): - self.table.close() + self.get_table(stage).close() def pre_epoch(self, stage: 'Stage'): - if 'Epoch' in self.table.column_names: - self.table['Epoch'] = stage.current_epoch + if 'Epoch' in self.get_table(stage).column_names: + self.get_table(stage)['Epoch'] = stage.current_epoch def post_epoch(self, stage: 'Stage'): metrics = stage.history.last() for column_name, metric_name in self.tracked_metrics.items(): - if column_name not in self.table.column_names: + if column_name not in self.get_table(stage).column_names: continue value = metrics[metric_name] @@ -272,9 +273,9 @@ def post_epoch(self, stage: 'Stage'): if formatter is not None: value = formatter(value) - self.table.update(column_name, value) + self.get_table(stage).update(column_name, value) - self.table.next_row() + self.get_table(stage).next_row() class ReduceMetricsCallback(Callback): @@ -283,7 +284,7 @@ class ReduceMetricsCallback(Callback): """ def post_epoch(self, stage: 'Stage'): - metrics = stage.tracker.reduce() + metrics = stage.metrics.reduce() stage.history.append_metrics(**metrics) stage.history.next_step() @@ -388,15 +389,15 @@ class WandbCallback(Callback): A callback that logs metrics to Weights & Biases. """ - def __init__(self, entity, project, group, tags, startup_timeout, **kwargs): + def __init__(self, project, entity, group, tags, startup_timeout, **kwargs): try: import wandb except ImportError: raise ImportError('wandb is required for the WandbCallback') self.wandb = wandb - self.entity = entity self.project = project + self.entity = entity self.group = group self.tags = tags self.startup_timeout = startup_timeout @@ -407,8 +408,8 @@ def pre_run(self, pipe: 'Pipeline'): self.wandb.init( config=OmegaConf.to_container(pipe.config, resolve=True), name=pipe.name, - entity=self.entity, project=self.project, + entity=self.entity, group=self.group, tags=self.tags, **self.kwargs, diff --git a/dmlcloud/core/metrics.py b/dmlcloud/core/metrics.py index 0ee2244..29c4923 100644 --- a/dmlcloud/core/metrics.py +++ b/dmlcloud/core/metrics.py @@ -204,3 +204,9 @@ def clear(self): for metric in self.metrics.values(): metric.reset() self.metrics.clear() + + def __getitem__(self, name: str): + return self.metrics[name] + + def __setitem__(self, name: str, metric: torchmetrics.Metric): + self.add_metric(name, metric) diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py index c5db070..99c7b40 100644 --- a/dmlcloud/core/pipeline.py +++ b/dmlcloud/core/pipeline.py @@ -197,8 +197,15 @@ def enable_wandb( import wandb # import now to avoid potential long import times later on # noqa if is_root(): - project = project or self.name - self.add_callback(WandbCallback(project, entity, group, tags, startup_timeout, **kwargs), CbPriority.WANDB) + callback = WandbCallback( + project=project, + entity=entity, + group=group, + tags=tags, + startup_timeout=startup_timeout, + **kwargs, + ) + self.add_callback(callback, CbPriority.WANDB) self.wandb = True @@ -224,7 +231,7 @@ def run(self): self._pre_run() for stage in self.stages: self.current_stage = stage - stage.run() + stage._run() self._post_run() def pre_run(self): diff --git a/dmlcloud/core/stage.py b/dmlcloud/core/stage.py index 96015aa..43b6ed0 100644 --- a/dmlcloud/core/stage.py +++ b/dmlcloud/core/stage.py @@ -37,7 +37,7 @@ class Stage: - post_epoch() """ - def __init__(self, name: str = None, epochs: int = 1): + def __init__(self, name: str = None, epochs: int | None = 1): self.name = name or self.__class__.__name__ self.max_epochs = epochs @@ -46,7 +46,7 @@ def __init__(self, name: str = None, epochs: int = 1): self.pipe = None # set by the pipeline self.history = TrainingHistory() - self.tracker = Tracker() + self.metrics = Tracker() self.metric_prefix = None self.barrier_timeout = None @@ -88,7 +88,15 @@ def epoch_end_time(self): @property def table(self): - return self._table_callback.table + return self._table_callback.get_table(self) + + @property + def _run_overridden(self): + return type(self).run != Stage.run + + @property + def _run_epoch_overridden(self): + return type(self).run_epoch != Stage.run_epoch def add_callback(self, callback: 'Callback', priority: int = 1): """ @@ -108,11 +116,11 @@ def add_callback(self, callback: 'Callback', priority: int = 1): def log(self, name: str, value: Any, reduction: str = 'mean', prefixed: bool = True): if prefixed and self.metric_prefix: name = f'{self.metric_prefix}/{name}' - self.tracker.log(name, value, reduction) + self.metrics.log(name, value, reduction) def add_metric(self, name, metric): metric = metric.to(self.device) - self.tracker.add_metric(name, metric) + self.metrics.add_metric(name, metric) return metric def add_column( @@ -143,7 +151,7 @@ def add_column( alignment (str, optional): The alignment of the column. Defaults to None. """ self._table_callback.track_metric( - name, metric=metric, formatter=formatter, width=width, color=color, alignment=alignment + self, name, metric=metric, formatter=formatter, width=width, color=color, alignment=alignment ) def pre_stage(self): @@ -172,22 +180,58 @@ def post_epoch(self): """ pass + def run(): + """ + Override this method to implement the main logic of the stage and do manual epoch management. + + Either this method or :meth:`run_epoch` must be implemented by subclasses. + Unlike :meth:`run_epoch`, this method is called only once per stage, and the implementation is responsible for + managing the epochs and calling :meth:`next_epoch` when appropriate. + """ + raise NotImplementedError() + + def next_epoch(self): + """ + Advances the stage to the next epoch. + + This method must only be called by the implementation of :meth:`run` when the stage finishes an epoch. + """ + if self._run_epoch_overridden: + raise ValueError('next_epoch() must not be called when run_epoch() is implemented.') + + self._post_epoch() + self._pre_epoch() + def run_epoch(self): """ - Train the model for one epoch. Must be implemented by subclasses. + Override this method to implement the main logic of the stage for a single epoch. + + Either this method or :meth:`run` must be implemented by subclasses. + Unlike :meth:`run`, this method is called automatically by the stage and does not need to manage the epochs. """ raise NotImplementedError() - def run(self): + def _run(self): """ Runs this stage. Either until max_epochs are reached, or until stop_stage() is called. """ - self._pre_stage() - while self.max_epochs is None or self.current_epoch < self.max_epochs: + if self._run_overridden and self._run_epoch_overridden: + raise ValueError('Only one of run() or run_epoch() must be implemented.') + elif not self._run_overridden and not self._run_epoch_overridden: + raise ValueError('Either run() or run_epoch() must be implemented.') + elif self._run_epoch_overridden: + self._pre_stage() + while self.max_epochs is None or self.current_epoch < self.max_epochs: + self._pre_epoch() + self.run_epoch() + self._post_epoch() + self._post_stage() + else: + self._pre_stage() self._pre_epoch() - self.run_epoch() + self.run() self._post_epoch() - self._post_stage() + self._post_stage() def _pre_stage(self): if len(self.pipe.stages) > 1: diff --git a/examples/README.md b/examples/README.md index d2501a0..6128546 100644 --- a/examples/README.md +++ b/examples/README.md @@ -7,4 +7,4 @@ on the MNIST dataset using dmlcloud. | Example | Description | | --- | --- | | [mnist.py](mnist.py) | Minimal example that demonstrates how to train a simple neural network on the MNIST dataset using dmlcloud. | -| - | - | +| [custom_epochs.py](custom_epochs.py) | Demonstrates how to fully control when "epochs" start and end, e.g. for reinforcement learning or LLM training. | diff --git a/examples/custom_epochs.py b/examples/custom_epochs.py new file mode 100644 index 0000000..52d8dd2 --- /dev/null +++ b/examples/custom_epochs.py @@ -0,0 +1,81 @@ +import dmlcloud as dml +import torch +import torchmetrics +from torch import nn +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + + +class CustomEpochStage(dml.Stage): + def pre_stage(self): + with dml.root_first(): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + train_dataset = datasets.MNIST(root='data', train=True, download=dml.is_root(), transform=transform) + + self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + self.train_loader = DataLoader(train_dataset, batch_size=32, sampler=self.train_sampler) + + model = nn.Sequential( + nn.Conv2d(1, 16, 3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 16, 3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Flatten(), + nn.Linear(784, 10), + ) + self.model = dml.wrap_ddp(model, self.device) + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=dml.scale_lr(1e-3)) + + self.loss = nn.CrossEntropyLoss() + + # Finally, we add columns to the table to track the loss and accuracy + self.add_column('# Steps', 'misc/steps') + self.add_column('# Samples', 'misc/total_samples') + + self.add_column('Loss', 'train/loss', color='green') + + def run(self): + MAX_STEPS = 5000 + LOG_PERIOD = 250 + + num_steps = 0 + total_samples = 0 + while num_steps < MAX_STEPS: + self.train_sampler.set_epoch(self.current_epoch) + + for img, target in self.train_loader: + img, target = img.to(self.device), target.to(self.device) + + self.optimizer.zero_grad() + output = self.model(img) + loss = self.loss(output, target) + loss.backward() + self.optimizer.step() + + self.log('train/loss', loss) + self.log('misc/samples', len(img), reduction='sum') + + num_steps += 1 + if num_steps % LOG_PERIOD == 0: + total_samples += self.metrics['misc/samples'].compute() + self.log('misc/total_samples', total_samples) + self.log('misc/steps', num_steps) + if num_steps < MAX_STEPS: + self.next_epoch() + self.train_sampler.set_epoch(self.current_epoch) + else: + break + + +def main(): + pipe = dml.Pipeline(name='custom-epochs') + pipe.append(CustomEpochStage()) + pipe.enable_checkpointing('checkpoints') + pipe.enable_wandb() + pipe.run() + + +if __name__ == '__main__': + main() diff --git a/examples/mnist.py b/examples/mnist.py index 866a131..bc13083 100644 --- a/examples/mnist.py +++ b/examples/mnist.py @@ -1,7 +1,3 @@ -import sys - -sys.path.insert(0, './') - import dmlcloud as dml import torch import torchmetrics @@ -100,7 +96,7 @@ def _val_epoch(self): def main(): - pipe = dml.Pipeline() + pipe = dml.Pipeline(name='MNIST') pipe.append(MNISTStage(epochs=3)) pipe.enable_checkpointing('checkpoints') pipe.enable_wandb()