diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index c56a20d..ce8c948 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -16,9 +16,7 @@ jobs: - macos-latest - ubuntu-latest python-version: - #- "3.8" - - "3.9" - #- "3.10" + - "3.10" steps: - name: Checkout Repository uses: actions/checkout@v3 @@ -33,7 +31,6 @@ jobs: - name: Install Dependencies run: | pip install -r requirements.txt - pip install --no-build-isolation horovod[pytorch] - name: Install Project run: | diff --git a/.gitignore b/.gitignore index b6e4761..9267900 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,9 @@ +/data +/slurm-*.out +/wandb + +############################ Auto Generated ############################ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index baca555..9e65de4 100644 --- a/README.md +++ b/README.md @@ -1,55 +1,10 @@ -# Python Project Template +# dmlcloud +[![](https://img.shields.io/pypi/v/dmlcloud)](https://pypi.org/project/dmlcloud/) +[![](https://img.shields.io/github/actions/workflow/status/sehoffmann/dmlcloud/run_tests.yml?logo=github)](https://github.com/sehoffmann/dmlcloud/actions/workflows/run_tests.yml) +[![](https://img.shields.io/github/actions/workflow/status/sehoffmann/dmlcloud/run_linting.yml?label=lint&logo=github)](https://github.com/sehoffmann/dmlcloud/actions/workflows/run_linting.yml) -This is a quickstart project template for Python that already comes attached with the following features: +Flexibel, easy-to-use, opinionated -* Packaging and metadata support -* Formatting and linting via *pre-commit*, *black*, *usort*, and *flake8* -* Testing via *pytest* -* CI via github-actions +**dmlcloud** is a library for distributed training of deep learning models with torch. Its main aim is to do all these tiny little tedious things that everybody just copy pastes over and over again, while still giving you full control over the training loop and maximum flexibility. - -## Configuration - -To tailor this template to your needs, the following steps must be taken: - -1. Rename the *myproject* package folder to your project name -2. Change metadata and project name in *setup.cfg*. -3. Do not forget to change the version attribute to point to your new package name as well. -4. Add dependencies to *requirements.txt* -5. Adjust the *LICENSE* file to your liking. -6. Adjust this *README.md* file to your liking. - -### Formatting and linting - -Install *pre-commit* and *pytest* via -``` -pip install -r ci_requirements.txt -``` - -To format and lint the entire codebase run: -``` -pre-commit run --all-files -``` - -To perform this step automatically during each commit (and fail on errors) run: -``` -pre-commit install -``` - -### Testing -To run the tests execute: -``` -pytest -``` -in the top-level directory. -Tests can also be executed individually by running them as regular python script. This requires you to add a small main function to them, c.f. *test/test_myproject.py*. - -### Github Actions -This project defines the following workflows: -1. *run_linting.yml* will run `pre-commit run --all-files` on every push to develop and pull request -2. *run_tests.yml* will run `pytest` on Windows, Ubuntu, and MacOS on every push to develop and pull_request -3. 
*release_public.yml* and *release_test.yml* can be triggered manually to build a wheel distribution and publish it to PyPI or TestPyPI respectively - -For the publising to work, you need to add the PyPI API token as Github secrets: -* *PYPI_TOKEN* for the official PyPI index -* *TEST_PYPI_TOKEN* for the TestPyPI index +Unlike other similar frameworks, such as *lightning*, dmlcloud tries to add as little additional complexity and abstraction as possible. Instead, it is tailored towards a carefully selected set of libraries and workflows and sticks with them. diff --git a/dmlcloud/__init__.py b/dmlcloud/__init__.py index db12b47..3c99987 100644 --- a/dmlcloud/__init__.py +++ b/dmlcloud/__init__.py @@ -1,14 +1,3 @@ -from .config import ArgparseVar, BaseConfig, ConfigVar, DefaultConfig, SubConfig -from .training import BaseTrainer, ClassificationTrainer +__version__ = "0.3.0" -__version__ = "0.1.1" - -__all__ = [ - 'ArgparseVar', - 'BaseConfig', - 'BaseTrainer', - 'ClassificationTrainer', - 'ConfigVar', - 'DefaultConfig', - 'SubConfig', -] +__all__ = [] diff --git a/dmlcloud/checkpoint.py b/dmlcloud/checkpoint.py new file mode 100644 index 0000000..8db023e --- /dev/null +++ b/dmlcloud/checkpoint.py @@ -0,0 +1,123 @@ +import datetime +import logging +import secrets +from pathlib import Path +from typing import Optional + +from omegaconf import OmegaConf + +from dmlcloud.util.slurm import slurm_job_id + + +def sanitize_filename(filename: str) -> str: + return filename.replace('/', '_') + + +def generate_id() -> str: + s = secrets.token_urlsafe(5) + return s.replace('-', 'a').replace('_', 'b') + + +def generate_checkpoint_path( + root: Path | str, name: Optional[str] = None, creation_time: Optional[datetime.datetime] = None +) -> Path: + root = Path(root) + + if name is None: + name = 'run' + + if creation_time is None: + creation_time = datetime.datetime.now() + + dt = creation_time.strftime('%Y.%m.%d-%H:%M') + name = sanitize_filename(name) + return root / f'{name}-{dt}-{generate_id()}' + + +def find_slurm_checkpoint(root: Path | str) -> Optional[Path]: + root = Path(root) + + job_id = slurm_job_id() + if job_id is None: + return None + + for child in root.iterdir(): + if CheckpointDir(child).is_valid and CheckpointDir(child).slurm_job_id == job_id: + return child + + return None + + +class CheckpointDir: + def __init__(self, path: Path): + self.path = path.resolve() + self.logger = logging.getLogger('dmlcloud') + + @property + def config_file(self) -> Path: + return self.path / 'config.yaml' + + @property + def indicator_file(self) -> Path: + return self.path / '.dmlcloud' + + @property + def log_file(self) -> Path: + return self.path / 'log.txt' + + @property + def slurm_file(self) -> Path: + return self.path / '.slurm-jobid' + + @property + def exists(self) -> bool: + return self.path.exists() + + @property + def is_valid(self) -> bool: + if not self.exists or not self.path.is_dir(): + return False + + if not self.indicator_file.exists(): + return False + + return True + + @property + def slurm_job_id(self) -> Optional[str]: + if not self.slurm_file.exists(): + return None + + with open(self.slurm_file) as f: + return f.read() + + def create(self): + if self.exists: + raise ValueError(f'Checkpoint directory already exists: {self.path}') + + self.path.mkdir(parents=True, exist_ok=True) + self.indicator_file.touch() + self.log_file.touch() + if slurm_job_id() is not None: + with open(self.slurm_file, 'w') as f: + f.write(slurm_job_id()) + + def save_config(self, config: OmegaConf): + if
not self.exists: + raise ValueError(f'Checkpoint directory does not exist: {self.path}') + + with open(self.config_file, 'w') as f: + OmegaConf.save(config, f) + + def load_config(self) -> OmegaConf: + if not self.is_valid: + raise ValueError(f'Checkpoint directory is not valid: {self.path}') + + with open(self.config_file) as f: + return OmegaConf.load(f) + + def __str__(self) -> str: + return str(self.path) + + def __repr__(self) -> str: + return f'CheckpointDir({self.path})' diff --git a/dmlcloud/config/__init__.py b/dmlcloud/config/__init__.py deleted file mode 100644 index 1d46577..0000000 --- a/dmlcloud/config/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .common import ArgparseVar, ConfigVar, SubConfig -from .config import BaseConfig, DefaultConfig -from .training import TrainingConfig - - -__all__ = [ - 'ArgparseVar', - 'BaseConfig', - 'ConfigVar', - 'DefaultConfig', - 'ModelConfig', - 'SubConfig', - 'TrainingConfig', -] -assert __all__ == sorted(__all__) diff --git a/dmlcloud/config/common.py b/dmlcloud/config/common.py deleted file mode 100644 index 6d09bd9..0000000 --- a/dmlcloud/config/common.py +++ /dev/null @@ -1,80 +0,0 @@ -class ConfigVar: - def _default_parse_fn(self, config, args): - val = getattr(args, self.name, None) - if val is not None: - setattr(config, self.name, val) - - def __init__(self, add_argument_fn=None, parse_argument_fn=_default_parse_fn): - if parse_argument_fn is ConfigVar._default_parse_fn: - parse_argument_fn = self._default_parse_fn # bind self - self.add_argument_fn = add_argument_fn - self.parse_argument_fn = parse_argument_fn - - def __set_name__(self, owner, name): - self.name = name - if not hasattr(owner, '_config_vars'): - owner._config_vars = [] - owner._config_vars.append(self) - - def __get__(self, obj, objtype=None): - return obj.dct[self.name] - - def __set__(self, obj, value): - obj.dct[self.name] = value - - def __delete__(self, obj): - del obj.dct[self.name] - - def add_argument(self, config, parser): - if self.add_argument_fn: - self.add_argument_fn(config, parser) - - def parse_argument(self, config, args): - if self.parse_argument_fn: - self.parse_argument_fn(config, args) - - -class ArgparseVar(ConfigVar): - def _default_add_fn(self, config, parser): - option = f'--{self.name.replace("_", "-")}' - args = self.args or [option] - kwargs = self.kwargs.copy() - kwargs['dest'] = self.name - parser.add_argument(*args, **kwargs) - - def __init__(self, *args, add_argument_fn=_default_add_fn, parse_argument_fn=ConfigVar._default_parse_fn, **kwargs): - if add_argument_fn is ArgparseVar._default_add_fn: - add_argument_fn = self._default_add_fn # bind self - - super().__init__(add_argument_fn, parse_argument_fn) - self.args = args - self.kwargs = kwargs - if 'dest' in self.kwargs: - raise ValueError('dest cannot be specified in kwargs') - - -class SubConfig: - def __init__(self, parent, root_dct, key=None): - self.parent = parent - self.root_dct = root_dct - self.key = key - self.set_defaults() - - def set_defaults(self): - pass - - @property - def dct(self): - if self.key is None: - return self.root_dct - else: - return self.root_dct[self.key] - - def add_arguments(self, parser): - group = parser.add_argument_group(self.key) - for cfg_var in self._config_vars: - cfg_var.add_argument(self, group) - - def parse_args(self, args): - for cfg_var in self._config_vars: - cfg_var.parse_argument(self, args) diff --git a/dmlcloud/config/config.py b/dmlcloud/config/config.py deleted file mode 100644 index 0a50f74..0000000 --- 
a/dmlcloud/config/config.py +++ /dev/null @@ -1,69 +0,0 @@ -import argparse - -from .meta import MetaConfig -from .training import TrainingConfig -from .wandb import WandbConfig - - -class BaseConfig: - def __init__(self, dct=None): - self._sub_configs = [] - self.dct = {} - - self._sub_configs = [] - self._setup_sub_configs() - self.set_defaults() - - if dct: - self.dct.update(dct) - - def __getattr__(self, name): - if name == '_sub_configs': - return super().__getattribute__(name) - - for cfg in self._sub_configs: - try: - return getattr(cfg, name) - except AttributeError: - pass - raise AttributeError(f'Config has no attribute {name}') - - def __setattr__(self, name, value): - if name != '_sub_configs': - for cfg in self._sub_configs: - if hasattr(cfg, name): - setattr(cfg, name, value) - return - super().__setattr__(name, value) - - def _setup_sub_configs(self): - raise NotImplementedError() - - def set_sub_config(self, key, cls): - self.dct.setdefault(key, {}) - sub_cfg = cls(self, self.dct, key) - self._sub_configs.append(sub_cfg) - - def create_parser(self, parser=None): - if parser is None: - parser = argparse.ArgumentParser() - for sub_cfg in self._sub_configs: - sub_cfg.add_arguments(parser) - return parser - - def parse_args(self, args): - for sub_cfg in self._sub_configs: - sub_cfg.parse_args(args) - - def as_dictionary(self): - return dict(self.dct) - - def set_defaults(self): - pass - - -class DefaultConfig(BaseConfig): - def _setup_sub_configs(self): - self.set_sub_config('meta', MetaConfig) - self.set_sub_config('training', TrainingConfig) - self.set_sub_config('wandb', WandbConfig) diff --git a/dmlcloud/config/meta.py b/dmlcloud/config/meta.py deleted file mode 100644 index 90595e3..0000000 --- a/dmlcloud/config/meta.py +++ /dev/null @@ -1,24 +0,0 @@ -import sys -from pathlib import Path - -from dmlcloud.util import git_hash -from .common import ArgparseVar, ConfigVar, SubConfig - - -class MetaConfig(SubConfig): - trainer_cls = ConfigVar() - model_dir = ConfigVar() - job_id = ConfigVar() - command_line = ConfigVar() - git_hash = ConfigVar() - checkpoint_dir = ArgparseVar('-d', '--dir', type=Path, help='The directory where runs are stored') - name = ArgparseVar('-n', '--name', help='The name of the experiment') - - def set_defaults(self): - self.trainer_cls = None - self.model_dir = None - self.job_id = None - self.command_line = ' '.join(sys.argv) - self.git_hash = git_hash() - self.checkpoint_dir = Path('./checkpoints').resolve() - self.name = None diff --git a/dmlcloud/config/training.py b/dmlcloud/config/training.py deleted file mode 100644 index 61201d7..0000000 --- a/dmlcloud/config/training.py +++ /dev/null @@ -1,42 +0,0 @@ -from argparse import BooleanOptionalAction - -from .common import ArgparseVar, SubConfig - - -class TrainingConfig(SubConfig): - seed = ArgparseVar(type=int, help='The random seed') - epochs = ArgparseVar(type=int, help='The number of epochs') - batch_size = ArgparseVar(type=int, help='The batch size') - base_batch_size = ArgparseVar(type=int, help='Base batch size for all parameters (lr, beta1, beta2)') - base_lr = ArgparseVar(type=float, help='The learning rate per 32-er batch / GPU') - base_beta1 = ArgparseVar(type=float, help='The beta1 parameter (momentum) per base batch size') - base_beta2 = ArgparseVar(type=float, help='The beta2 parameter per base batch size') - scale_lr = ArgparseVar(type=bool, action=BooleanOptionalAction, help='Scale the learning rate') - scale_beta1 = ArgparseVar(type=bool, action=BooleanOptionalAction, help='Scale 
the beta1 parameter (momentum)') - scale_beta2 = ArgparseVar(type=bool, action=BooleanOptionalAction, help='Scale the beta2 parameter') - rampup_epochs = ArgparseVar(type=int, help='The number of epochs to ramp up the learning rate') - weight_decay = ArgparseVar(type=float, help='The weight decay') - clip_gradients = ArgparseVar(type=float, help='The gradient clipping threshold') - log_gradients = ArgparseVar(type=bool, action=BooleanOptionalAction, help='Log gradients during training') - mixed = ArgparseVar(type=bool, action=BooleanOptionalAction, help='Use mixed precision training') - adasum = ArgparseVar(type=bool, action=BooleanOptionalAction, help='Use adasum for distributed training') - check_nans = ArgparseVar(type=bool, action=BooleanOptionalAction, help='Check for NaNs during training') - - def set_defaults(self): - self.seed = None - self.epochs = 10 - self.batch_size = 32 - self.base_batch_size = 32 - self.base_lr = 1e-3 - self.base_beta1 = 0.9 - self.base_beta2 = 0.999 - self.scale_lr = True - self.scale_beta1 = False - self.scale_beta2 = False - self.rampup_epochs = 5 - self.weight_decay = 1e-4 - self.clip_gradients = None - self.log_gradients = True - self.mixed = False - self.adasum = False - self.check_nans = False diff --git a/dmlcloud/config/wandb.py b/dmlcloud/config/wandb.py deleted file mode 100644 index 1823d0e..0000000 --- a/dmlcloud/config/wandb.py +++ /dev/null @@ -1,12 +0,0 @@ -from .common import ArgparseVar, SubConfig - - -class WandbConfig(SubConfig): - wb_project = ArgparseVar('-p', '--project', type=str, help='The wandb project name') - wb_name = ArgparseVar('--wb-name', type=str, help='Can be used to override the wandb experiment name') - wb_tags = ArgparseVar('--tags', nargs='+', type=str, help='The wandb tags') - - def set_defaults(self): - self.wb_project = None - self.wb_name = None - self.wb_tags = [] diff --git a/dmlcloud/metrics.py b/dmlcloud/metrics.py new file mode 100644 index 0000000..09ff494 --- /dev/null +++ b/dmlcloud/metrics.py @@ -0,0 +1,290 @@ +from enum import Enum + +import torch +import torch.distributed as dist + + +class Reduction(Enum): + MEAN = 'MEAN' + SUM = 'SUM' + MIN = 'MIN' + MAX = 'MAX' + + def as_torch(self): + if self == Reduction.SUM: + return dist.ReduceOp.SUM + elif self == Reduction.MIN: + return dist.ReduceOp.MIN + elif self == Reduction.MAX: + return dist.ReduceOp.MAX + else: + raise ValueError(f'Reduction {self} is not supported by torch') + + +def reduce_tensor(tensor, reduction, dim=None): + if not isinstance(tensor, torch.Tensor): + raise ValueError('tensor must be a torch.Tensor') + + # required because dim=None is not supported by torch + if dim is None: + dim = list(range(tensor.dim())) + + if reduction is Reduction.MEAN: + return tensor.mean(dim) + elif reduction is Reduction.SUM: + return tensor.sum(dim) + elif reduction is Reduction.MIN: + return tensor.amin(dim) + elif reduction is Reduction.MAX: + return tensor.amax(dim) + else: + raise ValueError(f'Unknown reduction {reduction}') + + +class MetricReducer: + """ + Stores a list of tensors and reduces them at the end of an epoch. + The dim argument specifies the dimensions to reduce over. If None, every dimension is completely reduced. + Notice that the list of individual tensors stored in this obcect, is ALWAYS reduced, both locally and distributed. + Hence, dimension 0 refers to the first dimension of individual tensors, which is usually the batch dimension. 
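The dimension bookkeeping described in this docstring is easy to misread, so here is a small, self-contained sketch (plain torch, not part of the diff) of what a local MEAN reduction with dim=[0] amounts to:

```python
import torch

# Sketch (plain torch, not dmlcloud): a MEAN reduction over dim=[0] of each
# appended tensor, mirroring the dimension shift described in the docstring.
values = [torch.randn(4, 2) for _ in range(3)]   # three appended batches of shape (4, 2)

stacked = torch.stack(values)                    # shape (3, 4, 2): list dim + tensor dims
dims = [0] + [d + 1 for d in [0]]                # list dim plus the shifted user dim -> [0, 1]
local_mean = stacked.mean(dim=dims)              # shape (2,)

# Same result as flattening all batches first
assert torch.allclose(local_mean, torch.cat(values, dim=0).mean(dim=0))
```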
+ """ + + def __init__(self, reduction=Reduction.MEAN, dim=None, globally=True): + if reduction not in [Reduction.MEAN, Reduction.SUM, Reduction.MIN, Reduction.MAX]: + raise ValueError(f'Unknown reduction {self.reduction}') + + self.values = [] + self.reduction = reduction + self.globally = globally + if isinstance(dim, int): + self.dim = [dim] + elif dim is not None: + self.dim = list(dim) + else: + self.dim = None + + def append(self, value): + """ + Appends a value to the list of values. + If the value is a tensor, it is detached and moved to the cpu to avoid growing memory consumption. + """ + value = torch.as_tensor(value) + value = value.detach().cpu() + self.values.append(value) + + def extend(self, values): + for value in values: + self.append(value) + + def __iadd__(self, value): + self.append(value) + return self + + def __setitem__(self, idx, value): + value = torch.as_tensor(value) + value = value.detach().cpu() + self.values[idx] = value + + def __getitem__(self, idx): + return self.values[idx] + + def __delitem__(self, idx): + del self.values[idx] + + def __len__(self): + return len(self.values) + + def __iter__(self): + return iter(self.values) + + def clear(self): + self.values.clear() + + def reduce_and_append(self, value): + value = reduce_tensor(value, self.reduction, dim=self.dim) + self.values.append(value) + + def reduce_locally(self): + if isinstance(self.dim, list): + dim = [0] + [d + 1 for d in self.dim] + elif isinstance(self.dim, int): + dim = [0, self.dim + 1] + else: + dim = None + tensor = torch.stack(self.values) + tensor = reduce_tensor(tensor, reduction=self.reduction, dim=dim) + return tensor + + def reduce_globally(self, group=None, async_op=False): + tensor = self.reduce_locally() + if self.globally: + if self.reduction == Reduction.MEAN: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op) + tensor /= dist.get_world_size(group) + else: + dist.all_reduce(tensor, op=self.reduction.as_torch(), group=group, async_op=async_op) + return tensor + + def state_dict(self): + return { + 'reduction': self.reduction, + 'dim': self.dim, + 'globally': self.globally, + 'values': self.values, + } + + def load_state_dict(self, state): + self.reduction = state['reduction'] + self.dim = state['dim'] + self.globally = state['globally'] + self.values = state['values'] + + +class MetricTracker: + """ + This class keeps track of multiple metrics and their history. + + Usage: + tracker = MetricTracker() + tracker.register_metric('loss', reduction=Reduction.MEAN) + tracker.track('loss', torch.randn(10, 1)) + tracker.next_epoch() + + print(tracker['loss'].last()) + """ + + def __init__(self): + self.histories = {} + self.reducers = {} + self.epoch = 1 + + def __getitem__(self, name): + """ + Returns the history of a metric up to the current epoch. + Values for the current epoch that have been reduced already are not included. + """ + if name not in self: + raise ValueError(f'Metric {name} does not exist') + return list(self.histories[name])[: self.epoch - 1] + + def __contains__(self, name): + return name in self.histories + + def __len__(self): + return len(self.histories) + + def __iter__(self): + return iter(self.histories) + + def current_value(self, name): + """ + If the metric already has an reduced value for the current epoch, it is returned. Otherwise, None is returned. 
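reduce_globally() above realizes a global MEAN as an all-reduce SUM followed by a division by the world size. A minimal stand-alone sketch of that pattern, assuming the default process group has already been initialized:

```python
import torch
import torch.distributed as dist

# Sketch (assumes dist.init_process_group() has been called): a global MEAN
# built from an all-reduce SUM, mirroring the reduce_globally() logic above.
def global_mean(local_value: torch.Tensor, group=None) -> torch.Tensor:
    tensor = local_value.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)  # sum across all ranks
    tensor /= dist.get_world_size(group)                        # divide by #ranks -> mean
    return tensor

# Example (hypothetical): every rank contributes its local loss
# mean_loss = global_mean(torch.tensor(2.5))
```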
+ """ + if name not in self: + raise ValueError(f'Metric {name} does not exist') + if self.has_value(name): + return self.histories[name][-1] + else: + return None + + def is_reduced_metric(self, name): + """ + Returns True if the metric gets (all)reduced at the end of each epoch. + """ + if name not in self: + raise ValueError(f'Metric {name} does not exist') + return name in self.reducers + + def has_value(self, name): + """ + Returns True if the metric has a final value for the current epoch. + """ + if name not in self: + raise ValueError(f'Metric {name} does not exist') + return len(self.histories[name]) >= self.epoch + + def register_metric(self, name, reduction=None, dim=None, globally=True): + if name in self: + raise ValueError(f'Metric {name} already exists') + + if dim is not None and reduction is None: + raise ValueError('If dim is specified, reduction must be specified as well') + + self.histories[name] = [] + [None] * (self.epoch - 1) + if reduction is not None: + self.reducers[name] = MetricReducer(reduction=reduction, dim=dim, globally=globally) + + def track(self, name, value): + if isinstance(value, torch.Tensor): + value = value.detach().cpu() + + if name not in self: + raise ValueError(f'Metric {name} does not exist') + + if self.has_value(name): + raise ValueError(f'History for {name} already has a value for epoch {self.epoch}') + + history = self.histories[name] + reducer = self.reducers.get(name) + if reducer is not None: + reducer.append(value) + else: + history.append(value) + + def reduce_all(self, prefix=None, strict=True): + """ + Reduces all metrics and appends their reduced values to the history. + If prefix is specified, only metrics with the specified prefix are reduced. + If strict is True, an error is raised if a metric has already been reduced for the current epoch. + + After this method has been called, no more values for the reduced metrics can be tracked for the current epoch, + and next_epoch() must be called to be able to track new values. + """ + for name, history in self.histories.items(): + if prefix is not None and not name.startswith(prefix): + continue + + if self.has_value(name): + if strict: + raise ValueError(f'History for {name} has already been reduced for epoch {self.epoch}') + else: + continue + + if name in self.reducers: + history.append(self.reducers[name].reduce_globally()) + self.reducers[name].clear() + else: + history.append(None) + + def next_epoch(self): + """ + Reduces all metrics (if not already reduced) and advances the epoch counter. 
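A short usage sketch of the tracker added in this file, extending the docstring example; globally=False is used here only so the snippet runs without an initialized process group:

```python
import torch
from dmlcloud.metrics import MetricTracker, Reduction  # module added in this diff

# Sketch: per-epoch bookkeeping as described above.
tracker = MetricTracker()
tracker.register_metric('train/loss', reduction=Reduction.MEAN, globally=False)

for epoch in range(2):
    for step in range(5):
        tracker.track('train/loss', torch.rand(8))  # per-batch values, reduced later
    tracker.next_epoch()                            # reduces and appends to the history

print(tracker['train/loss'])  # one reduced value per finished epoch
```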
+ """ + self.reduce_all(strict=False) + self.epoch += 1 + + def state_dict(self): + state = { + 'epoch': self.epoch, + 'histories': dict(self.histories), + 'reducers': {name: reducer.state_dict() for name, reducer in self.reducers.items()}, + } + return state + + def load_state_dict(self, state): + self.epoch = state['epoch'] + self.histories = state['histories'] + self.reducers = {} + for name, reducer_state in state['reducers'].items(): + self.reducers[name] = MetricReducer() + self.reducers[name].load_state_dict(reducer_state) + + def __str__(self): + s = 'MetricTracker(' + for name, history in self.histories.items(): + s += f'\n {name}: {history}' + if len(self.histories) > 0: + s += '\n)' + else: + s += ')' + return s diff --git a/dmlcloud/pipeline.py b/dmlcloud/pipeline.py new file mode 100644 index 0000000..34b2ebd --- /dev/null +++ b/dmlcloud/pipeline.py @@ -0,0 +1,293 @@ +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional, Sequence, Union + +import torch +import torch.distributed as dist +from omegaconf import OmegaConf +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader, Dataset + +from dmlcloud.util.wandb import wandb_is_initialized, wandb_set_startup_timeout +from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path +from .metrics import MetricTracker, Reduction +from .stage import Stage +from .util.distributed import local_rank +from .util.logging import add_log_handlers, experiment_header, general_diagnostics, IORedirector + + +class TrainingPipeline: + def __init__(self, cfg: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None): + if cfg is None: + self.cfg = OmegaConf.create() + elif not isinstance(cfg, OmegaConf): + self.cfg = OmegaConf.create(cfg) + else: + self.cfg = cfg + + self.name = name + + self.logger = logging.getLogger('dmlcloud') + self.checkpoint_dir = None + self.io_redirector = None + self.resumed = None + self.tracker = MetricTracker() + self.device = None + self.start_time = None + self.stop_time = None + self.current_stage = None + + self.wandb = False + self._wandb_initalizer = None + + self.stages = [] + self.datasets = {} + self.models = {} + self.optimizers = {} + self.schedulers = {} + + @property + def checkpointing_enabled(self): + return self.checkpoint_dir is not None + + def register_model( + self, + name: str, + model: torch.nn.Module, + use_ddp: bool = True, + save_latest: bool = True, + save_interval: Optional[int] = None, + save_best: bool = False, + best_metric: str = 'val/loss', + verbose: bool = True, + ): + if name in self.models: + raise ValueError(f'Model with name {name} already exists') + if use_ddp: + model = DistributedDataParallel(model, broadcast_buffers=False) + model = model.to(self.device) + self.models[name] = model + + if verbose: + msg = f'Model "{name}":\n' + msg += f' - Parameters: {sum(p.numel() for p in model.parameters())/1e6:.1f} kk\n' + msg += f' - DDP: {use_ddp}\n' + msg += f' - {model}' + self.logger.info(msg) + + def register_optimizer(self, name: str, optimizer, scheduler=None): + if name in self.optimizers: + raise ValueError(f'Optimizer with name {name} already exists') + self.optimizers[name] = optimizer + if scheduler is not None: + self.schedulers[name] = scheduler + + def register_dataset(self, name: str, dataset: Union[DataLoader, Dataset, Sequence], verbose: bool = True): + if name in self.datasets: + raise ValueError(f'Dataset with name {name} already exists') + + 
self.datasets[name] = dataset + if verbose: + msg = f'Dataset "{name}":\n' + msg += f' - Batches (Total): ~{len(dataset) * dist.get_world_size()}\n' + msg += f' - Batches (/Worker): {len(dataset)}\n' + self.logger.info(msg) + + def append_stage(self, stage: Stage, max_epochs: Optional[int] = None, name: Optional[str] = None): + if not isinstance(stage, Stage): + raise ValueError('stage must be a Stage object') + + stage.pipeline = self + stage.max_epochs = max_epochs + stage.name = name + self.stages.append(stage) + + def enable_checkpointing( + self, + root: str, + resume: bool = True, + ): + if self.checkpointing_enabled: + raise ValueError('Checkpointing already enabled') + + path = None + if resume and CheckpointDir(root).is_valid: + path = root + self.resumed = True + elif resume and find_slurm_checkpoint(root): + path = find_slurm_checkpoint(root) + self.resumed = True + if path is None: + path = generate_checkpoint_path(root=root, name=self.name, creation_time=self.start_time) + self.resumed = False + self.checkpoint_dir = CheckpointDir(path) + + def enable_wandb( + self, + project: str | None = None, + entity: str | None = None, + group: str | None = None, + tags: List[str] | None = None, + startup_timeout: int = 360, + **kwargs, + ): + import wandb # import now to avoid potential long import times later on + + self.wandb = True + + def initializer(): + wandb_set_startup_timeout(startup_timeout) + wandb.init( + config=OmegaConf.to_container(self.cfg, resolve=True), + name=self.name, + entity=entity, + project=project, + group=group, + tags=tags, + **kwargs, + ) + + self._wandb_initalizer = initializer + + def track_reduce( + self, + name: str, + value: torch.Tensor, + step: Optional[int] = None, + reduction: Reduction = Reduction.MEAN, + dim: Optional[List[int]] = None, + reduce_globally: bool = True, + ): + if name not in self.tracker: + self.tracker.register_metric(name, reduction, dim, reduce_globally) + + self.tracker.track(name, value) + + def track( + self, + name: str, + value: Any, + step: Optional[int] = None, + ): + if name not in self.tracker: + self.tracker.register_metric(name) + + self.tracker.track(name, value) + + def run(self): + """ + Starts the training and runs all registered stages. + """ + with _RunGuard(self): + self._pre_run() + for stage in self.stages: + stage.run() + self._post_run() + + def pre_run(self): + pass + + def post_run(self): + pass + + def resume_run(self): + pass + + def _pre_run(self): + if len(self.stages) == 0: + raise ValueError('No stages defined. Use append_stage() to add stages to the pipeline.') + + if not dist.is_initialized(): + raise ValueError( + 'Default process group not initialized! Call torch.distributed.init_process_group() first.' 
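As the error message above indicates, the pipeline expects the default torch.distributed process group to exist before run() is called. A hedged sketch of how a pipeline might be assembled (the stage, config values, and checkpoint root are made up; torchrun-style environment variables are assumed):

```python
from pathlib import Path

import torch
import torch.distributed as dist
from dmlcloud.pipeline import TrainingPipeline
from dmlcloud.stage import Stage

# Sketch: minimal pipeline assembly. DummyStage and the 'lr' entry are
# placeholders for illustration only.
class DummyStage(Stage):
    def run_epoch(self):
        self.track_reduce('loss', torch.rand(1))   # placeholder work

def main():
    dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo')

    pipeline = TrainingPipeline(cfg={'lr': 1e-3}, name='example-run')
    pipeline.enable_checkpointing(Path('checkpoints'), resume=True)
    pipeline.append_stage(DummyStage(), max_epochs=3)
    pipeline.run()

if __name__ == '__main__':
    main()
```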
+ ) + + if torch.cuda.is_available(): + if local_rank() is None: + self.device = torch.device('cuda') + else: + self.device = torch.device('cuda', local_rank()) + else: + self.device = torch.device('cpu') + + if self.checkpointing_enabled: + self._init_checkpointing() + + if self.wandb: + self._wandb_initalizer() + + self.start_time = datetime.now() + + add_log_handlers(self.logger) + header = '\n' + experiment_header(self.name, self.checkpoint_dir, self.start_time) + self.logger.info(header) + + if self.resumed: + self._resume_run() + + diagnostics = general_diagnostics() + diagnostics += '\n* CONFIG:\n' + OmegaConf.to_yaml(self.cfg) + self.logger.info(diagnostics) + + self.pre_run() + + def _init_checkpointing(self): + if not self.checkpoint_dir.is_valid: + self.checkpoint_dir.create() + self.checkpoint_dir.save_config(self.cfg) + self.io_redirector = IORedirector(self.checkpoint_dir.log_file) + self.io_redirector.install() + + def _resume_run(self): + self.logger.info(f'Resuming training from checkpoint: {self.checkpoint_dir}') + self.resume_run() + + def _post_run(self): + self.stop_time = datetime.now() + self.logger.info(f'Finished training in {self.stop_time - self.start_time} ({self.stop_time})') + if self.checkpointing_enabled: + self.logger.info(f'Outputs have been saved to {self.checkpoint_dir}') + self.post_run() + + def _pre_epoch(self): + pass + + def _post_epoch(self): + if self.wandb: + import wandb + + metrics = {name: self.tracker[name][-1] for name in self.tracker} + wandb.log(metrics) + + def _cleanup(self, exc_type, exc_value, traceback): + """ + Called by _RunGuard to ensure that the pipeline is properly cleaned up + """ + if exc_type is KeyboardInterrupt: + self.logger.info('------- Training interrupted by user -------') + elif exc_type is not None: + self.logger.error( + '------- Training failed with an exception -------', exc_info=(exc_type, exc_value, traceback) + ) + + if self.wandb: + import wandb + + if wandb_is_initialized(): + wandb.finish(exit_code=0 if exc_type is None else 1) + + if self.io_redirector is not None: + self.io_redirector.uninstall() + + return False + + +class _RunGuard: + def __init__(self, pipeline): + self.pipeline = pipeline + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + return self.pipeline._cleanup(exc_type, exc_value, traceback) diff --git a/dmlcloud/stage.py b/dmlcloud/stage.py new file mode 100644 index 0000000..3fa9d4f --- /dev/null +++ b/dmlcloud/stage.py @@ -0,0 +1,280 @@ +import sys +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +import torch +from progress_table import ProgressTable + +from .metrics import MetricTracker, Reduction +from .util.distributed import is_root, root_only + + +class Stage: + """ + Hook Points: + - pre_stage() + - post_stage() + - pre_epoch() + - post_epoch() + """ + + def __init__(self): + self.pipeline = None # set by the pipeline + self.max_epochs = None # set by the pipeline + self.name = None # set by the pipeline + + self.start_time = None + self.stop_time = None + self.epoch_start_time = None + self.epoch_stop_time = None + self.current_epoch = 1 + self._stop_requested = False + + self.metric_prefix = None + self.table = None + + @property + def tracker(self) -> MetricTracker: + return self.pipeline.tracker + + @property + def logger(self): + return self.pipeline.logger + + @property + def device(self): + return self.pipeline.device + + def track_reduce( + self, + name: str, + value: torch.Tensor, + step: 
Optional[int] = None, + reduction: Reduction = Reduction.MEAN, + dim: Optional[List[int]] = None, + reduce_globally: bool = True, + prefixed: bool = True, + ): + if prefixed and self.metric_prefix: + name = f'{self.metric_prefix}/{name}' + self.pipeline.track_reduce(name, value, step, reduction, dim, reduce_globally) + + def track(self, name: str, value, step: Optional[int] = None, prefixed: bool = True): + if prefixed and self.metric_prefix: + name = f'{self.metric_prefix}/{name}' + self.pipeline.track(name, value, step) + + def stop_stage(self): + self._stop_requested = True + + def pre_stage(self): + """ + Executed before the stage starts. + Use this method to setup aby stage-specific data sets or models. + """ + pass + + def post_stage(self): + """ + Executed after the stage finishes. + Use this method to clean up any stage-specific resources or to save any intermediate results/artifacts. + """ + pass + + def pre_epoch(self): + """ + Executed before each epoch. + """ + pass + + def post_epoch(self): + """ + Executed after each epoch and after the metrics have been reduced. + """ + pass + + def run_epoch(self): + """ + Train the model for one epoch. Must be implemented by subclasses. + """ + raise NotImplementedError() + + def table_columns(self) -> List[Union[str, Dict[str, Any]]]: + """ + Override this method to customize the metrics displayed in the progress table. + + Should return a list containing either strings or dicts. + If a string, it will be used as both the display name and the metric name. + If a dict, it should contain a 'name' key and a 'metric' key. + The 'name' key will be used as the display name, and the 'metric' key will be used as the metric name. + Additional keys are forwarded to the ProgressTable.add_column method. + If 'metric' is None, then the user is responsible for updating the column manually. + """ + columns = [ + {'name': 'Epoch', 'metric': 'misc/epoch'}, + {'name': 'Time/Epoch', 'metric': None}, + ] + if self.max_epochs is not None: + columns.append({'name': 'ETA', 'metric': None}) + return columns + + def run(self): + """ + Runs this stage. Either until max_epochs are reached, or until stop_stage() is called. 
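A minimal sketch of a custom Stage built on the hooks above; the class name, data, and metric name are made up for illustration:

```python
import torch
from dmlcloud.stage import Stage

# Sketch of a custom stage using the Stage hooks defined above.
class NoiseStage(Stage):
    def pre_stage(self):
        self.data = [torch.randn(4, 8) for _ in range(16)]   # stage-specific setup

    def run_epoch(self):
        for batch in self.data:
            loss = (batch ** 2).mean()        # stand-in for a real training step
            self.track_reduce('loss', loss)   # reduced at the end of the epoch

    def table_columns(self):
        columns = super().table_columns()
        columns.append({'name': 'Loss', 'metric': 'loss'})
        return columns
```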
+ """ + self._pre_stage() + while self.max_epochs is None or self.current_epoch <= self.max_epochs: + self._pre_epoch() + self.run_epoch() + self._post_epoch() + if self._stop_requested: + break + self._post_stage() + + def _pre_stage(self): + self.start_time = datetime.now() + self.table = ProgressTable(file=sys.stdout) + self._setup_table() + + if len(self.pipeline.stages) > 1: + self.logger.info(f'\n========== STAGE: {self.name} ==========') + + self.pre_stage() + + for handler in self.logger.handlers: + handler.flush() + + self.table._print_header() + + def _post_stage(self): + self.stop_time = datetime.now() + if is_root(): + self.table.close() + if len(self.pipeline.stages) > 1: + self.logger.info(f'Finished stage in {self.stop_time - self.start_time}') + self.post_stage() + + def _pre_epoch(self): + self.epoch_start_time = datetime.now() + self.pre_epoch() + self.pipeline._pre_epoch() + + def _post_epoch(self): + self.epoch_stop_time = datetime.now() + self._reduce_metrics() + self.post_epoch() + self.pipeline._post_epoch() + self._update_table() + self.current_epoch += 1 + + def _reduce_metrics(self): + self.track(name='misc/epoch', value=self.current_epoch, prefixed=False) + self.track( + name='misc/epoch_time', value=(self.epoch_stop_time - self.epoch_start_time).total_seconds(), prefixed=False + ) + self.tracker.next_epoch() + + @root_only + def _setup_table(self): + for column_dct in self._metrics(): + display_name = column_dct.pop('name') + column_dct.pop('metric') + self.table.add_column(display_name, **column_dct) + + @root_only + def _update_table(self): + self.table.update('Epoch', self.current_epoch) + self.table.update('Time/Epoch', (datetime.now() - self.start_time) / self.current_epoch) + if self.max_epochs is not None: + self.table.update( + 'ETA', (datetime.now() - self.start_time) / self.current_epoch * (self.max_epochs - self.current_epoch) + ) + for column_dct in self._metrics(): + display_name = column_dct['name'] + metric_name = column_dct['metric'] + if metric_name is not None: + self.table.update(display_name, self.tracker[metric_name][-1]) + self.table.next_row() + + def _metrics(self): + metrics = [] + for column in self.table_columns(): + if isinstance(column, str): + metrics.append({'name': column, 'metric': column}) + elif isinstance(column, dict): + if 'name' not in column: + raise ValueError('Column dict must contain a "name" key') + if 'metric' not in column: + raise ValueError('Column dict must contain a "metric" key') + metrics.append(column) + else: + raise ValueError(f'Invalid column: {column}. Must be a string or a dict.') + return metrics + + +class TrainValStage(Stage): + def __init__(self): + super().__init__() + self.is_train = True + + def run_epoch(self): + self.train_epoch() + self.val_epoch() + + def step(self, batch) -> torch.Tensor: + raise NotImplementedError() + + def train_step(self, batch): + return self.step(batch) + + def val_step(self, batch): + return self.step(batch) + + def train_epoch(self): + self.is_train = True + self.metric_prefix = 'train' + + train_ds = self.pipeline.datasets.get('train') + if train_ds is None: + raise ValueError( + 'No "train" dataset found in pipeline. Use register_dataset("train", ...) to register a dataset.'
+ ) + + if hasattr(train_ds, 'sampler') and hasattr(train_ds.sampler, 'set_epoch'): + train_ds.sampler.set_epoch(self.current_epoch) + + for batch in train_ds: + for optimizer in self.pipeline.optimizers.values(): + optimizer.zero_grad() + + loss = self.train_step(batch) + loss.backward() + + for optimizer in self.pipeline.optimizers.values(): + optimizer.step() + + self.track_reduce('loss', loss) + + for scheduler in self.pipeline.schedulers.values(): + scheduler.step() + + @torch.no_grad() + def val_epoch(self): + self.is_train = False + self.metric_prefix = 'val' + + val_ds = self.pipeline.datasets.get('val') + if val_ds is None: + raise ValueError( + 'No "val" dataset found in pipeline. Use register_dataset("val", ...) to register a dataset.' + ) + + for batch in val_ds: + loss = self.val_step(batch) + self.track_reduce('loss', loss) + + def table_columns(self): + columns = super().table_columns() + columns.insert(1, {'name': '[Train] Loss', 'metric': 'train/loss'}) + columns.insert(2, {'name': '[Val] Loss', 'metric': 'val/loss'}) + return columns diff --git a/dmlcloud/training/__init__.py b/dmlcloud/training/__init__.py deleted file mode 100644 index 9f7f750..0000000 --- a/dmlcloud/training/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .classification import ClassificationTrainer -from .trainer import BaseTrainer - -__all__ = [ - 'BaseTrainer', - 'ClassificationTrainer', -] -assert __all__ == sorted(__all__) diff --git a/dmlcloud/training/checkpoint.py b/dmlcloud/training/checkpoint.py deleted file mode 100644 index 32b7bbf..0000000 --- a/dmlcloud/training/checkpoint.py +++ /dev/null @@ -1,125 +0,0 @@ -import json -import logging -import os -from datetime import datetime - -import horovod.torch as hvd -from wandb.sdk.lib.runid import generate_id - -from dmlcloud.config import DefaultConfig - - -class ExtendedJSONEncoder(json.JSONEncoder): - """ - JSONEncoder subclass that serializes classes and functions as well (by their name). 
- """ - - def default(self, o): - if isinstance(o, type): - return f'' - elif callable(o): - return f'' - - try: - return super().default(o) - except TypeError: - return str(o) - - -class ExtendedJSONDecoder(json.JSONDecoder): - pass - - -def get_config_path(model_dir): - return model_dir / 'config.json' - - -def get_checkpoint_path(model_dir): - return model_dir / 'checkpoint.pth' - - -def get_slurm_id(): - return os.environ.get('SLURM_JOB_ID') - - -def find_old_checkpoint(base_dir): - slurm_id = get_slurm_id() - checkpoints = list(base_dir.glob(f'*-{slurm_id} *')) - - if slurm_id and len(checkpoints) == 1: - # if there exists multiple possible checkpoints, we don't know which one to resume - # usually only happens for interactive sessions - model_dir = checkpoints[0] - job_id = model_dir.name.split(' ')[0] - else: - job_id = None - model_dir = None - - return model_dir, job_id - - -def sanitize_filename(filename): - return filename.replace('/', '_') - - -def generate_job_id(): - slurm_id = get_slurm_id() - job_id = slurm_id if slurm_id else generate_id() - date_str = datetime.now().strftime('%Y%m%d_%H%M%S') - job_id = f'{date_str}-{job_id}' - return job_id - - -def create_project_dir(base_dir, config): - job_id = hvd.broadcast_object(generate_job_id(), name='job_id') - dir_name = job_id - if config.name: - dir_name += ' ' + sanitize_filename(config.name) - - model_dir = hvd.broadcast_object(base_dir / dir_name, name='model_dir') - - if hvd.rank() == 0: - os.makedirs(model_dir) - save_config(get_config_path(model_dir), config) - - return model_dir, job_id - - -def resume_project_dir(config): - config.model_dir, config.job_id = find_old_checkpoint(config.checkpoint_dir) - is_resumed = config.model_dir is not None - if is_resumed: - parsed_dct = load_config_dct(get_config_path(config.model_dir)) - consistency_check(parsed_dct, config) - logging.info(f'Resuming run from {config.model_dir}') - else: - config.model_dir, config.job_id = create_project_dir(config.checkpoint_dir, config) - logging.info(f'Created run directory {config.model_dir}') - - return is_resumed - - -def consistency_check(parsed_dct, config): - parsed_cfg = DefaultConfig(parsed_dct) - if parsed_cfg.git_hash != config.git_hash: - msg = 'Git hash of resumed run does not match current git hash.\n' - msg += f'Current git hash: {config.git_hash}\n' - msg += f'Git hash of resumed run: {parsed_cfg.git_hash}' - logging.warning(msg) - - if parsed_cfg.command_line != config.command_line: - msg = 'Command line of resumed run does not match current command line.\n' - msg += f'Current command line: {config.command_line}\n' - msg += f'Command line of resumed run: {parsed_cfg.command_line}' - logging.warning(msg) - - -def save_config(path, config): - with open(path, 'w') as file: - json.dump(config.as_dictionary(), file, cls=ExtendedJSONEncoder, indent=4) - - -def load_config_dct(path): - with open(path) as file: - dct = json.load(file, cls=ExtendedJSONDecoder) - return dct diff --git a/dmlcloud/training/classification.py b/dmlcloud/training/classification.py deleted file mode 100644 index 4e13ee3..0000000 --- a/dmlcloud/training/classification.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch -from torch import nn - -from dmlcloud.util import accuracy, top5_error -from .trainer import BaseTrainer - - -class ClassificationTrainer(BaseTrainer): - def forward_step(self, batch_idx, batch): - X, label = (tensor.to(self.device, non_blocking=True) for tensor in batch) - pred = self.model(X) - - with torch.no_grad(): - self.log_metric('acc', 
accuracy(pred, label)) - self.log_metric('top5_error', top5_error(pred, label)) - - return self.loss_fn(pred, label) - - def create_loss(self): - return nn.CrossEntropyLoss() - - def metric_names(self): - return ['train/acc', 'val/acc', 'val/top5_error'] diff --git a/dmlcloud/training/metrics.py b/dmlcloud/training/metrics.py deleted file mode 100644 index 17282ee..0000000 --- a/dmlcloud/training/metrics.py +++ /dev/null @@ -1,128 +0,0 @@ -import csv - -import horovod.torch as hvd -import torch - -Var = 'var' -Std = 'std' -Statistics = 'statistics' - - -class Metric: - def __init__(self, name, reduction=hvd.Average, allreduce=True): - self.name = name - self.reduction = reduction - self.batch_values = [] - self.allreduce = allreduce - assert not allreduce or reduction is not None, 'Cannot allreduce without reduction' - - def _reduce(self, value, dim=None): - if value.dim() == 0: - return value - elif self.reduction == hvd.Average: - return value.mean(dim=dim) - elif self.reduction == hvd.Sum: - return value.sum(dim=dim) - elif self.reduction == hvd.Min: - return value.min(dim=dim)[0] - elif self.reduction == hvd.Max: - return value.max(dim=dim)[0] - elif self.reduction == hvd.Product: - return value.prod(dim=dim) - elif self.reduction == Var: - return value.var(dim=dim) - elif self.reduction == Std: - return value.std(dim=dim) - else: - raise ValueError(f'Unknown reduction {self.reduction}') - - def add_batch_value(self, value): - if isinstance(value, torch.Tensor): - value = value.detach().cpu() # this is very important to avoid memory leaks and for performance - - if self.reduction is None: - self.batch_values.append(value) - else: - tensor = torch.as_tensor(value, device='cpu') - tensor = self._reduce(tensor, dim=0) - self.batch_values.append(tensor) - - def reduce_locally(self): - if self.reduction is None: - return self.batch_values[0] - tensor = torch.stack(self.batch_values) - tensor = self._reduce(tensor, dim=0) - return tensor - - def reduce(self): - tensor = self.reduce_locally() - if self.allreduce: - if self.reduction is Var: - return hvd.allreduce_( - tensor, op=hvd.Average, name=f'metric/{self.name}' - ) # not properly bessel corrected, but close enough - elif self.reduction is Std: - var = hvd.allreduce(tensor**2, op=hvd.Average, name=f'metric/{self.name}') - return torch.sqrt(var) - else: - return hvd.allreduce_(tensor, op=self.reduction, name=f'metric/{self.name}') - else: - return tensor - - -class MetricSaver: - def __init__(self, epochs=None): - self.epochs = epochs or [] - self.current_metrics = {} - - @property - def last(self): - return self.epochs[-1] - - def get_metrics(self, name): - return [epoch[name] for epoch in self.epochs] - - def reduce(self): - reduced = {} - for name, metric in self.current_metrics.items(): - reduced[name] = metric.reduce() - self.epochs.append(reduced) - self.current_metrics = {} - - def log_metric(self, name, value, reduction=hvd.Average, allreduce=True): - if reduction == Statistics: - self.log_metric(f'{name}/mean', value, reduction=hvd.Average, allreduce=allreduce) - self.log_metric(f'{name}/std', value, reduction=Std, allreduce=allreduce) - self.log_metric(f'{name}/min', value, reduction=hvd.Min, allreduce=allreduce) - self.log_metric(f'{name}/max', value, reduction=hvd.Max, allreduce=allreduce) - else: - if name not in self.current_metrics: - self.current_metrics[name] = Metric(name, reduction, allreduce) - metric = self.current_metrics[name] - metric.add_batch_value(value) - - def log_python_object(self, name, value): - 
self.log_metric(name, value, reduction=None, allreduce=False) - - def scalar_metrics(self, with_epoch=False): - scalars = [] - for epoch, metrics in enumerate(self.epochs): - dct = {} - if with_epoch: - dct['epoch'] = epoch + 1 - - for name, value in metrics.items(): - if isinstance(value, torch.Tensor) and value.dim() == 0: - dct[name] = value.item() - elif not isinstance(value, torch.Tensor): - dct[name] = value - scalars.append(dct) - - return scalars - - def scalars_to_csv(self, path): - with open(path, 'w') as file: - scalar_metrics = self.scalar_metrics(with_epoch=True) - writer = csv.DictWriter(file, fieldnames=scalar_metrics[0].keys()) - writer.writeheader() - writer.writerows(scalar_metrics) diff --git a/dmlcloud/training/scaling.py b/dmlcloud/training/scaling.py deleted file mode 100644 index 0e07823..0000000 --- a/dmlcloud/training/scaling.py +++ /dev/null @@ -1,50 +0,0 @@ -import horovod.torch as hvd - - -def scale_lr(lr, per_worker_batch_size, base_batch_size, use_adasum, use_gpu): - if use_adasum and hvd.nccl_built() and use_gpu: - lr_scaling = hvd.local_size() # gpu adasum needs scaling by local size - elif use_adasum: - lr_scaling = 1.0 # cpu adasum doesn't need per_batch_scaling - else: - lr_scaling = hvd.size() - lr_scaling *= per_worker_batch_size / base_batch_size - return lr * lr_scaling, lr_scaling - - -def scale_beta1(beta1, per_worker_batch_size, base_batch_size): - factor = hvd.size() * per_worker_batch_size / base_batch_size - return beta1**factor - - -def scale_beta2(beta2, per_worker_batch_size, base_batch_size): - factor = hvd.size() * per_worker_batch_size / base_batch_size - return beta2**factor - - -def scale_param_group(param_group, config, use_gpu): - lr_enabled = param_group['scale_lr'] if 'scale_lr' in param_group else config.scale_lr - beta1_enabled = param_group['scale_beta1'] if 'scale_beta1' in param_group else config.scale_beta1 - beta2_enabled = param_group['scale_beta2'] if 'scale_beta2' in param_group else config.scale_beta2 - - scaled_params = dict(param_group) - if 'lr' in param_group and lr_enabled: - scaled_params['lr'], _ = scale_lr( - param_group['lr'], config.batch_size, config.base_batch_size, config.adasum, use_gpu - ) - - if 'betas' in param_group: - if beta1_enabled: - beta1 = scale_beta1(param_group['betas'][0], config.batch_size, config.base_batch_size) - else: - beta1 = param_group['betas'][0] - if beta2_enabled: - beta2 = scale_beta2(param_group['betas'][1], config.batch_size, config.base_batch_size) - else: - beta2 = param_group['betas'][1] - scaled_params['betas'] = (beta1, beta2) - - if 'momentum' in param_group and beta1_enabled: - scaled_params['momentum'] = scale_beta1(param_group['momentum'], config.batch_size, config.base_batch_size) - - return scaled_params diff --git a/dmlcloud/training/trainer.py b/dmlcloud/training/trainer.py deleted file mode 100644 index 3a3064d..0000000 --- a/dmlcloud/training/trainer.py +++ /dev/null @@ -1,606 +0,0 @@ -import logging -import random -import sys -from contextlib import nullcontext -from datetime import datetime, timedelta - -import horovod.torch as hvd -import numpy as np -import torch -import wandb -from progress_table import ProgressTable -from torch.cuda.amp import autocast, GradScaler -from torch.optim.lr_scheduler import ChainedScheduler, LinearLR - -from dmlcloud.util import ( - hvd_is_initialized, - hvd_print_worker, - project_dir, - script_path, - setup_horovod, - wandb_is_initialized, - wandb_set_startup_timeout, -) -from .checkpoint import resume_project_dir -from .metrics 
import MetricSaver -from .scaling import scale_lr, scale_param_group -from .util import ( - git_hash, - global_grad_norm, - log_config, - log_delimiter, - log_diagnostics, - log_git, - log_model, - setup_logging, -) - - -class TrainerInterface: - """ - These methods must be implemented for each experiment - """ - - def create_loss(self): - """ - Returns a loss function. - Will be available as self.loss_fn. - """ - return None - - def create_dataset(self): - """ - Returns a tuple of (train_dl, val_dl). - Will be available as self.train_dl and self.val_dl. - These shall be iterators that yield batches. - """ - raise NotImplementedError() - - def create_model(self): - """ - Returns a torch.nn.Module. - Will be available as self.model. - If you need multiple networks, e.g. for GANs, wrap them in a nn.Module. - """ - raise NotImplementedError() - - def create_optimizer(self, params, lr): - """ - Returns an optimizer. - Will be available as self.optimizer. - """ - raise NotImplementedError() - - def create_scheduler(self): - """ - Returns a scheduler or None. - """ - return None - - def forward_step(self, batch_idx, batch): - """ - Performs a forward pass and returns the loss. - """ - raise NotImplementedError() - - -class BaseTrainer(TrainerInterface): - @staticmethod - def root_only(fn): - """ - Decorator for methods that should only be called on the root rank. - """ - - def wrapper(self, *args, **kwargs): - if self.is_root: - return fn(self, *args, **kwargs) - - return wrapper - - def __init__(self, config, val_loss_name='loss'): - self.cfg = config - self.val_loss_name = val_loss_name - self.reset() - - def reset(self): - self.initialized = False - self.is_resumed = False - self.train_metrics = MetricSaver() - self.val_metrics = MetricSaver() - self.misc_metrics = MetricSaver() - self.epoch = 1 - self.mode = 'train' - - @property - def use_checkpointing(self): - return self.model_dir is not None - - @property - def is_gpu(self): - return self.device.type == 'cuda' - - @property - def is_root(self): - return hvd.rank() == 0 - - @property - def is_train(self): - return self.mode == 'train' - - @property - def is_eval(self): - return not self.is_train - - @property - def model_dir(self): - return self.cfg.model_dir - - @property - def job_id(self): - return self.cfg.job_id - - def setup_all(self, use_checkpointing=True, use_wandb=True, print_diagnostics=True): - if self.initialized: - raise ValueError('Trainer already initialized! 
Call reset() first.') - - if not hvd_is_initialized(): - setup_horovod() - - self.seed() - self.setup_general() - - if use_checkpointing: - self.is_resumed = resume_project_dir(self.cfg) - - if use_wandb: - self.setup_wandb() - - if print_diagnostics: - self.print_diagnositcs() - - self.setup_loss() - self.setup_dataset() - self.setup_model() - self.setup_optimizer() - self.resume_training() - - if print_diagnostics: - log_config(self.cfg) - - self.setup_table() - - hvd.broadcast_parameters(self.model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(self.optimizer, root_rank=0) - - self.initialized = True - - def print_diagnositcs(self): - log_delimiter() - logging.info(f'Script path: {script_path() or "N/A"}') - logging.info(f'Project dir: {project_dir() or "N/A"}') - log_git() - log_diagnostics(self.device) - - @root_only - def setup_table(self): - self.table = ProgressTable(columns=self.metric_names(), print_row_on_update=False) - - def setup_general(self): - setup_logging() - - if torch.cuda.is_available(): - self.device = torch.device('cuda', hvd.local_rank()) - else: - self.device = torch.device('cpu') - - torch.set_num_threads(8) - self.cfg.git_hash = git_hash() - - def seed(self): - if self.cfg.seed is None: - seed = int.from_bytes(random.randbytes(4), byteorder='little') - self.cfg.seed = hvd.broadcast_object(seed) - - np.random.seed(self.cfg.seed) - random.seed(self.cfg.seed) - torch.manual_seed(self.cfg.seed) - torch.cuda.manual_seed(self.cfg.seed) - - @root_only - def setup_wandb(self): - wandb_set_startup_timeout(600) - exp_name = self.cfg.wb_name if self.cfg.wb_name else self.cfg.name - wandb.init( - project=self.cfg.wb_project, - name=exp_name, - tags=self.cfg.wb_tags, - dir=self.model_dir, - id=self.job_id, - resume='must' if self.is_resumed else 'never', - config=self.cfg.as_dictionary(), - ) - - def setup_dataset(self): - logging.info('Creating dataset') - hvd.barrier() - ts = datetime.now() - if self.is_root: - self.train_dl, self.val_dl = self.create_dataset() - hvd.barrier() - else: - hvd.barrier() # wait until rank 0 has created the dataset (e.g. downloaded it) - self.train_dl, self.val_dl = self.create_dataset() - logging.info(f'Dataset creation took {(datetime.now() - ts).total_seconds():.1f}s') - - if hasattr(self.train_dl, 'dataset') and hasattr(self.train_dl.dataset, '__len__'): - train_samples = f'{len(self.train_dl.dataset)}' - else: - train_samples = 'N/A' - train_sizes = hvd.allgather(torch.tensor([len(self.train_dl)]), name='train_dataset_size') - train_sizes = [t.item() for t in train_sizes] - msg = 'Train dataset:' - msg += f'\n\t* Batches: {train_sizes[0]}' - msg += f'\n\t* Batches (total): {sum(train_sizes)}' - msg += f'\n\t* Samples (calculated): {sum(train_sizes) * self.cfg.batch_size}' - msg += f'\n\t* Samples (raw): {train_samples}' - logging.info(msg) - if len(set(train_sizes)) > 1 and self.is_root: - logging.warning(f'!!! 
Uneven train dataset batches: {train_sizes}') - - if self.val_dl is not None: - if hasattr(self.val_dl, 'dataset') and hasattr(self.val_dl.dataset, '__len__'): - val_samples = f'{len(self.val_dl.dataset)}' - else: - val_samples = 'N/A' - - val_sizes = hvd.allgather(torch.tensor([len(self.val_dl)]), name='val_dataset_size') - val_sizes = [t.item() for t in val_sizes] - msg = 'Train dataset:' - msg += f'\n\t* Batches: {val_sizes[0]}' - msg += f'\n\t* Batches (total): {sum(val_sizes)}' - msg += f'\n\t* Samples (calculated): {sum(val_sizes) * self.cfg.batch_size}' - msg += f'\n\t* Samples (raw): {val_samples}' - logging.info(msg) - if len(set(val_sizes)) > 1 and self.is_root: - logging.warning(f'!!! Uneven val dataset batches: {val_sizes}') - - log_delimiter() - - def setup_model(self): - logging.info('Creating model') - self.model = self.create_model().to(self.device) - log_model(self.model) - if self.is_root and self.use_checkpointing: - with open(self.model_dir / 'model.txt', 'w') as f: - f.write(str(self.model)) - log_delimiter() - - def setup_loss(self): - self.loss_fn = self.create_loss() - - def setup_optimizer(self): - logging.info('Creating optimizer') - optimizer = self.create_optimizer(self.model.parameters(), self.cfg.base_lr) - lr_scale_factor = self.scale_optimizer(optimizer) - self.optimizer = hvd.DistributedOptimizer( - optimizer, named_parameters=self.model.named_parameters(), op=hvd.Adasum if self.cfg.adasum else hvd.Average - ) - - schedulers = [] - if self.cfg.rampup_epochs: - linear_warmup = LinearLR( - self.optimizer, start_factor=1 / lr_scale_factor, end_factor=1.0, total_iters=self.cfg.rampup_epochs - ) - schedulers.append(linear_warmup) - - user_scheduler = self.create_scheduler() - if isinstance(user_scheduler, list): - schedulers.extend(user_scheduler) - elif user_scheduler is not None: - schedulers.append(user_scheduler) - - self.scheduler = ChainedScheduler(schedulers) - self.scaler = GradScaler(enabled=self.cfg.mixed) - - def scale_optimizer(self, optimizer): - use_gpu = self.device.type == 'cuda' - _, lr_scale_factor = scale_lr( - optimizer.defaults['lr'], self.cfg.batch_size, self.cfg.base_batch_size, self.cfg.adasum, use_gpu - ) - logging.info(f'LR Scale Factor: {lr_scale_factor}') - logging.info('Param-Groups:') - for i, param_group in enumerate(optimizer.param_groups): - param_group_cpy = dict(param_group) - scaled_params = scale_param_group(param_group, self.cfg, use_gpu) - scaled_params_cpy = dict(scaled_params) - optimizer.param_groups[i] = scaled_params - del param_group_cpy['params'] - del scaled_params_cpy['params'] - logging.info(f'[P{i}] Pre-scaled: {param_group_cpy}') - logging.info(f'[P{i}] Scaled: {scaled_params_cpy}') - log_delimiter() - return lr_scale_factor - - def load_state_dict(self, state_dict): - self.epoch = state_dict['epoch'] - self.train_metrics = MetricSaver(state_dict['train_metrics']) - self.val_metrics = MetricSaver(state_dict['val_metrics']) - self.misc_metrics = MetricSaver(state_dict['misc_metrics']) - self.model.load_state_dict(state_dict['model_state']) - self.optimizer.load_state_dict(state_dict['optimizer_state']) - self.scheduler.load_state_dict(state_dict['scheduler_state']) - self.scaler.load_state_dict(state_dict['scaler_state']) - - def state_dict(self): - state_dict = { - 'epoch': self.epoch, - 'train_metrics': self.train_metrics.epochs, - 'val_metrics': self.val_metrics.epochs, - 'misc_metrics': self.misc_metrics.epochs, - 'model_state': self.model.state_dict(), - 'optimizer_state': self.optimizer.state_dict(), - 
'scheduler_state': self.scheduler.state_dict(), - 'scaler_state': self.scaler.state_dict(), - } - return state_dict - - def load_checkpoint(self, path): - state_dict = torch.load(path, map_location=self.device) - self.load_state_dict(state_dict) - self.epoch += 1 - - def resume_training(self): - if not self.use_checkpointing: - return - - cp_path = self.model_dir / 'checkpoint.pt' - if cp_path.exists(): - self.load_checkpoint(cp_path) - logging.info(f'Loaded checkpoint from {cp_path}') - logging.info( - f'Continuing training at epoch {self.epoch}, previous loss: {self.train_metrics.last["loss"]:.3f}' - ) - elif self.is_resumed: - logging.critical('No checkpoint found!') - sys.exit(1) - - @root_only - def save_checkpoint(self): - if not self.use_checkpointing: - return - - checkpoint_path = self.model_dir / 'checkpoint.pt' - best_path = self.model_dir / 'best.pt' - - torch.save(self.state_dict(), checkpoint_path) - if self.is_best_epoch(): - torch.save(self.state_dict(), best_path) - if wandb_is_initialized(): - wandb.save(str(best_path), policy='now', base_path=str(self.model_dir)) - - self.train_metrics.scalars_to_csv(self.model_dir / 'train_metrics.csv') - self.val_metrics.scalars_to_csv(self.model_dir / 'val_metrics.csv') - self.misc_metrics.scalars_to_csv(self.model_dir / 'misc_metrics.csv') - - def is_best_epoch(self): - best_val_loss = min(self.val_metrics.get_metrics(self.val_loss_name)) - return self.val_metrics.last[self.val_loss_name] == best_val_loss - - def _update_table_entries(self): - metric_names = self.metric_names() - if isinstance(metric_names, list): - metric_names = {name: name for name in metric_names} - for display_name, metric_name in metric_names.items(): - splits = metric_name.split('/', 1) - value = None - if metric_name == 'Epoch': - value = str(self.epoch) - elif metric_name == 'ETA': - value = str(self.misc_metrics.last['eta']) - elif metric_name == 'ms/batch': - value = f'{self.misc_metrics.last["ms_per_batch"]:.1f}' - elif metric_name == 'time/epoch': - value = str(self.misc_metrics.last['time_per_epoch']) - elif len(splits) == 2: - group, key = splits[0], splits[1] - metrics = self.train_metrics if group == 'train' else self.val_metrics - if key in metrics.last: - value = metrics.last[key] - else: - raise ValueError(f'Invalid metric name: {metric_name}') - self.table[display_name] = value - self.table.next_row() - - @root_only - def log_epoch(self): - self._update_table_entries() - if wandb_is_initialized(): - self.log_wandb() - - def log_wandb(self): - metrics = {} - for key, value in self.train_metrics.scalar_metrics()[-1].items(): - metrics[f'train/{key}'] = value - - for key, value in self.val_metrics.scalar_metrics()[-1].items(): - metrics[f'val/{key}'] = value - - for key, value in self.misc_metrics.scalar_metrics()[-1].items(): - metrics[f'misc/{key}'] = value - - wandb.log(metrics) - if self.is_best_epoch(): - wandb.run.summary['best/epoch'] = self.epoch - for key, value in metrics.items(): - if not key.startswith('misc'): - wandb.run.summary[f'best/{key}'] = value - - def forward_step(self, batch_idx, batch): - raise NotImplementedError() - - def switch_mode(self, train=True): - if train: - self.model.train() - self.current_metrics = self.train_metrics - self.mode = 'train' - else: - self.model.eval() - self.current_metrics = self.val_metrics - self.mode = 'eval' - - def pre_train(self): - self.epoch_train_start = datetime.now() - - def post_train(self): - self.epoch_train_end = datetime.now() - - def train_epoch(self, max_steps=None): - 
self.pre_train() - self.switch_mode(train=True) - - # Do this now, and not later, to immidiately show that a new epoch has started - if self.is_root and 'Epoch' in self.table.columns: - self.table['Epoch'] = self.epoch - - if hasattr(self.train_dl, 'sampler') and hasattr(self.train_dl.sampler, 'set_epoch'): - self.train_dl.sampler.set_epoch(self.epoch) - - nan_ctx_manager = torch.autograd.detect_anomaly() if self.cfg.check_nans else nullcontext() - for batch_idx, batch in enumerate(self.train_dl): - if max_steps and batch_idx >= max_steps: - break - - self.optimizer.zero_grad() - - with nan_ctx_manager: - # forward pass - with autocast(enabled=self.cfg.mixed): - loss = self.forward_step(batch_idx, batch) - # backward pass - self.scaler.scale(loss).backward() # scale loss and, in turn, gradients to prevent underflow - - if loss.isnan() and not self.scaler.is_enabled(): - logging.critical( - 'Got NaN loss but mixed precision training is disabled! This might be due to NaN values in the data or from diverging training.' - ) - sys.exit(1) - - self.optimizer.synchronize() # make sure all async allreduces are done - self.scaler.unscale_(self.optimizer) # now, unscale gradients again - - if self.cfg.log_gradients: - norm = global_grad_norm(self.model.parameters()) - self.log_metric('grad_norm', norm, allreduce=False, reduction='statistics') - if self.cfg.clip_gradients: - self.log_metric('grad_norm/n_clipped', norm > self.cfg.clip_gradients, hvd.Sum, allreduce=False) - - if self.cfg.clip_gradients: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.clip_gradients) - - with self.optimizer.skip_synchronize(): # we already synchronized manually - self.scaler.step(self.optimizer) - self.scaler.update() # adjust gradient scaling based on number of infs/nans - - if not torch.isnan(loss): # mixed-precision might produce nan steps - self.log_metric('loss', loss) - self.misc_metrics.log_metric('n_nan', 0, hvd.Sum) - else: - self.misc_metrics.log_metric('n_nan', 1, hvd.Sum) - - self.misc_metrics.log_metric('n_steps', 1, hvd.Sum, allreduce=False) - self.misc_metrics.log_metric('n_total_batches', 1, hvd.Sum, allreduce=True) - - hvd.join() # prevents hangup on allreduce() due to uneven sharding - - self.misc_metrics.log_python_object('lr', self.scheduler.get_last_lr()[0]) - for k, v in self.scaler.state_dict().items(): - self.misc_metrics.log_python_object(f'scaler/{k}', v) - - if self.scheduler is not None: - self.scheduler.step() - - self.post_train() - - def pre_eval(self): - self.epoch_eval_start = datetime.now() - - def post_eval(self): - self.epoch_eval_end = datetime.now() - - n_remaining = self.cfg.epochs - self.epoch - per_epoch = (self.epoch_eval_end - self.start_time) / self.epoch - per_epoch -= timedelta(microseconds=per_epoch.microseconds) - eta = n_remaining * per_epoch - self.misc_metrics.log_python_object('eta', str(eta)) - self.misc_metrics.log_python_object('time_per_epoch', str(per_epoch)) - - n_train_batches = self.misc_metrics.current_metrics['n_total_batches'].reduce().item() - per_step = (self.epoch_train_end - self.epoch_train_start) / n_train_batches - per_step = per_step.total_seconds() * 1000 - self.misc_metrics.log_python_object('ms_per_batch', per_step) - - self.reduce_metrics() - self.log_epoch() - self.save_checkpoint() - self.epoch += 1 - - def reduce_metrics(self): - self.train_metrics.reduce() - self.val_metrics.reduce() - self.misc_metrics.reduce() - - def evaluate_epoch(self, max_steps=None): - self.pre_eval() - self.switch_mode(train=False) - - if 
self.val_dl is not None: - for batch_idx, batch in enumerate(self.val_dl): - if max_steps and batch_idx >= max_steps: - break - - with torch.no_grad(): - loss = self.forward_step(batch_idx, batch).item() - self.log_metric('loss', loss) - - hvd.join() # prevents hangup on allreduce() due to uneven sharding - - self.post_eval() - - def pre_training(self): - hvd_print_worker('READY') - self.start_time = datetime.now() - logging.info('Starting training...') - - def post_training(self): - if self.is_root: - self.table.close() - logging.info('Training finished.') - - def train(self, max_steps=None, use_checkpointing=True, use_wandb=True, print_diagnostics=True): - if not self.initialized: - self.setup_all( - use_checkpointing=use_checkpointing, use_wandb=use_wandb, print_diagnostics=print_diagnostics - ) - - self.pre_training() - while self.epoch <= self.cfg.epochs: - self.train_epoch(max_steps) - self.evaluate_epoch(max_steps) - self.post_training() - - def log_metric(self, name, value, reduction=hvd.Average, allreduce=True): - self.current_metrics.log_metric(name, value, reduction, allreduce) - - def log_python_object(self, name, value): - self.current_metrics.log_python_object(name, value, reduction=None, allreduce=False) - - def metric_names(self): - """ - Returns a list or dictionary with custom metrics that are displayed during training. - If a dictionary is returned, the keys are used as display names. - """ - columns = ['Epoch', 'ETA', 'train/loss'] - if self.val_dl is not None: - columns += [f'val/{self.val_loss_name}'] - columns += ['ms/batch', 'time/epoch'] - return columns diff --git a/dmlcloud/training/util.py b/dmlcloud/training/util.py deleted file mode 100644 index 6fc4006..0000000 --- a/dmlcloud/training/util.py +++ /dev/null @@ -1,116 +0,0 @@ -import json -import logging -import os -import sys - -import horovod.torch as hvd -import torch -from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype - -from dmlcloud.util import git_diff, git_hash, hvd_print_worker -from .checkpoint import ExtendedJSONEncoder - - -def setup_logging(): - root_logger = logging.getLogger() - root_logger.setLevel(logging.INFO if hvd.rank() == 0 else logging.WARNING) - - if root_logger.hasHandlers(): - return - - stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setLevel(logging.DEBUG) - stdout_handler.setFormatter(logging.Formatter()) - stderr_handler = logging.StreamHandler() - stderr_handler.setLevel(logging.WARNING) - stderr_handler.setFormatter(logging.Formatter()) - - root_logger.addHandler(stdout_handler) - root_logger.addHandler(stderr_handler) - - -def delimiter(n=40, newline=True): - delim = '-' * n - if newline: - delim += '\n' - return delim - - -def log_delimiter(n=40): - logging.info(delimiter(n, newline=False)) - - -def log_diagnostics(device): - msg = f'Training distributed on {hvd.size()} workers/gpus\n' - msg += delimiter() - msg += f'SLURM_JOB_ID = {os.environ.get("SLURM_JOB_ID")}\n' - msg += f'SLURM_STEP_ID = {os.environ.get("SLURM_STEP_ID")}\n' - msg += f'SLURM_STEP_NODELIST = {os.environ.get("SLURM_STEP_NODELIST")}\n' - msg += f'SLURM_TASKS_PER_NODE = {os.environ.get("SLURM_TASKS_PER_NODE")}\n' - msg += f'SLURM_STEP_GPUS = {os.environ.get("SLURM_STEP_GPUS")}\n' - msg += f'SLURM_GPUS_ON_NODE = {os.environ.get("SLURM_GPUS_ON_NODE")}\n' - msg += f'SLURM_CPUS_PER_TASK = {os.environ.get("SLURM_CPUS_PER_TASK")}\n' - msg += f'SLURM_CPU_BIND_LIST = {os.environ.get("SLURM_CPU_BIND_LIST")}\n' - msg += delimiter() - msg += f'MPI built: {hvd.mpi_built()}\n' - 
msg += f'NCCL built: {hvd.nccl_built() > 0}\n' - msg += f'Gloo built: {hvd.gloo_built()}\n' - msg += f'CUDA built: {hvd.cuda_built()}\n' - msg += f'DDL built: {hvd.ddl_built()}\n' - msg += f'ROCm built: {hvd.rocm_built()}\n' - msg += f'oneCCL built: {hvd.ccl_built()}\n' - msg += delimiter() - msg += f'MPI enabled: {hvd.mpi_enabled()}\n' - msg += f'Gloo enabled: {hvd.gloo_enabled()}\n' - msg += delimiter() - msg += f'CUDA_VISIBLE_DEVICES = {os.environ.get("CUDA_VISIBLE_DEVICES")}\n' - msg += f'Device count: {torch.cuda.device_count()}' - logging.info(msg) - hvd_print_worker(f'Using {device}') - log_delimiter() - - -def log_config(config): - msg = 'CONFIG:\n' - msg += json.dumps(config.dct, indent=4, cls=ExtendedJSONEncoder) + '\n' - msg += delimiter(newline=False) - logging.info(msg) - - -def log_git(): - msg = f'Git Hash: {git_hash() or "N/A"}\n' - msg += f'Git Diff:\n{git_diff() or "N/A"}\n' - msg += delimiter() - logging.info(msg) - - -def log_model(model): - n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - n_non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad) - msg = f'# trainable parameters: {n_trainable_params/1e6:.1f}M\n' - msg += f'# non-trainable parameters: {n_non_trainable_params/1e6:.1f}M' - logging.info(msg) - - -def global_grad_norm(parameters, norm_type=2.0): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - grads = [p.grad for p in parameters if p.grad is not None] - norm_type = float(norm_type) - if len(grads) == 0: - return torch.tensor(0.0) - - first_device = grads[0].device - grouped_grads = _group_tensors_by_device_and_dtype([[g.detach() for g in grads]]) - - if norm_type == torch.inf: - norms = [g.detach().abs().max().to(first_device) for g in grads] - total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) - else: - norms = [] - for (device, _), [grads] in grouped_grads.items(): - norms.extend([torch.norm(g, norm_type) for g in grads]) - - total_norm = torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type) - - return total_norm diff --git a/dmlcloud/util/__init__.py b/dmlcloud/util/__init__.py index 327b882..e69de29 100644 --- a/dmlcloud/util/__init__.py +++ b/dmlcloud/util/__init__.py @@ -1,25 +0,0 @@ -from dmlcloud.util.project import project_dir, run_in_project, script_dir, script_path -from .evaluation import accuracy, top5_error -from .git import git_diff, git_hash -from .horovod import hvd_allreduce, hvd_is_initialized, hvd_print_worker, setup_horovod, shard_indices -from .util import EnumAction -from .wandb import wandb_is_initialized, wandb_set_startup_timeout - -__all__ = [ - 'accuracy', - 'top5_error', - 'git_diff', - 'git_hash', - 'EnumAction', - 'hvd_is_initialized', - 'hvd_print_worker', - 'hvd_allreduce', - 'setup_horovod', - 'shard_indices', - 'wandb_is_initialized', - 'wandb_set_startup_timeout', - 'run_in_project', - 'script_dir', - 'script_path', - 'project_dir', -] diff --git a/dmlcloud/util/util.py b/dmlcloud/util/argparse.py similarity index 80% rename from dmlcloud/util/util.py rename to dmlcloud/util/argparse.py index 844bf24..41fc4cb 100644 --- a/dmlcloud/util/util.py +++ b/dmlcloud/util/argparse.py @@ -1,17 +1,5 @@ import argparse import enum -import os - -import wandb - - -def set_wandb_startup_timeout(seconds: int): - assert isinstance(seconds, int) - os.environ['WANDB__SERVICE_WAIT'] = f'{seconds}' - - -def is_wandb_initialized(): - return wandb.run is not None class EnumAction(argparse.Action): diff 
--git a/dmlcloud/util/distributed.py b/dmlcloud/util/distributed.py new file mode 100644 index 0000000..13768fe --- /dev/null +++ b/dmlcloud/util/distributed.py @@ -0,0 +1,183 @@ +import os +from contextlib import contextmanager + +import numpy as np +import torch.distributed as dist + +from .tcp import find_free_port, get_local_ips + + +def is_root(): + return dist.get_rank() == 0 + + +def root_only(fn): + """ + Decorator for methods that should only be called on the root rank. + """ + + def wrapper(*args, **kwargs): + if is_root(): + return fn(*args, **kwargs) + + return wrapper + + +@contextmanager +def root_first(): + """ + Context manager that ensures that the root rank executes the code first before all other ranks + """ + if is_root(): + try: + yield + finally: + dist.barrier() + else: + dist.barrier() + try: + yield + finally: + pass + + +def mpi_local_comm(): + try: + from mpi4py import MPI + + comm = MPI.COMM_WORLD + local_comm = comm.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL) + return local_comm + except ImportError: + return None + + +def local_rank(): + if 'LOCAL_RANK' in os.environ: + return int(os.environ["LOCAL_RANK"]) + local_comm = mpi_local_comm() + if local_comm is not None: + return local_comm.Get_rank() + else: + return None + + +def local_size(): + if 'LOCAL_WORLD_SIZE' in os.environ: + return int(os.environ["LOCAL_WORLD_SIZE"]) + local_comm = mpi_local_comm() + if local_comm is not None: + return local_comm.Get_size() + else: + return None + + +def print_worker(msg, barrier=True, flush=True): + if barrier: + dist.barrier() + print(f'Worker {dist.get_rank()} ({dist.get_group_rank()}.{dist.get_process_group_ranks()}): {msg}', flush=flush) + if barrier: + dist.barrier() + + +def shard_indices(n, rank, size, shuffle=True, drop_remainder=False, seed=0): + indices = np.arange(n) + + if shuffle: + np.random.Generator(np.random.MT19937(seed)).shuffle(indices) + + if drop_remainder: + indices = indices[: n - n % size] + + return indices[rank::size] + + +def init_process_group_dummy(): + """ + Initializes the process group with a single process. + Uses HashStore under the hood. Useful for applications that + only run on a single gpu. + """ + store = dist.HashStore() + dist.init_process_group(store=store, rank=0, world_size=1, backend='gloo') + + +def init_process_group_MPI(ip_idx=0, port=None, **kwargs): + """ + This method setups up the distributed backend using MPI, even + if torch was not built with MPI support. For this to work, you + need to have mpi4py installed and the root rank must be reachable + via TCP. + + If port is None, we will automatically try to find a free port. + + ip_idx can be used to specify which IP address to use if the root + has multiple IP addresses. The default is 0, which means the first. + + kwargs are passed to torch.distributed.init_process_group. + """ + from mpi4py import MPI + + comm = MPI.COMM_WORLD + rank = comm.Get_rank() + size = comm.Get_size() + + port = find_free_port() if port is None and rank == 0 else None + + if rank == 0: + ip = get_local_ips()[ip_idx] + else: + ip = None + + ip = comm.bcast(ip, root=0) + port = comm.bcast(port, root=0) + url = f'tcp://{ip}:{port}' + + comm.Barrier() + + dist.init_process_group( + init_method=url, + world_size=size, + rank=rank, + **kwargs, + ) + + return rank, size + + +def init_process_group_auto(ip_idx=0, port=None, **kwargs): + """ + Tries to initialize torch.distributed in the following order: + 1. 
If the MASTER_PORT environment variable is set, use environment variable initialization + 2. If a MPI context is available, e.g. from slurm or mpirun, use MPI to exchange ip addresses (see init_process_group_MPI) + 3. Otherwise, use a single process group (see init_process_group_dummy) + """ + + # determine init method + method = 'dummy' + if os.environ.get('MASTER_PORT'): + method = 'env' + else: + try: + from mpi4py import MPI + + if MPI.COMM_WORLD.Get_size() > 1: + method = 'MPI' + except ImportError: + pass + + if method == 'env': + dist.init_process_group(init_method='env://', **kwargs) + elif method == 'MPI': + init_process_group_MPI(ip_idx=ip_idx, port=port, **kwargs) + else: + init_process_group_dummy() + + +def deinitialize_torch_distributed(): + """ + Deinitializes the torch distributed framework. + At the time of writing, `dist.destroy_process_group()` is not well documented. + Hence, this function. + """ + dist.destroy_process_group() diff --git a/dmlcloud/util/evaluation.py b/dmlcloud/util/evaluation.py deleted file mode 100644 index cd70d72..0000000 --- a/dmlcloud/util/evaluation.py +++ /dev/null @@ -1,11 +0,0 @@ -def accuracy(pred, label): - if pred.shape[1] == 1: - return ((pred > 0).float() == label).float().mean() - else: - return (pred.argmax(dim=1) == label).float().mean() - - -def top5_error(pred, label): - top5_indices = pred.topk(5, dim=1)[1] - top5_error = 1 - (top5_indices == label.unsqueeze(1)).float().max(dim=1)[0].mean() - return top5_error diff --git a/dmlcloud/util/git.py b/dmlcloud/util/git.py index 4b38dcd..1e499af 100644 --- a/dmlcloud/util/git.py +++ b/dmlcloud/util/git.py @@ -1,9 +1,7 @@ -from .project import run_in_project, script_path +from .project import run_in_project def git_hash(short=False): - if script_path() is None: - return None if short: process = run_in_project(['git', 'rev-parse', '--short', 'HEAD']) else: @@ -12,7 +10,5 @@ def git_hash(short=False): def git_diff(): - if script_path() is None: - return None process = run_in_project(['git', 'diff', '-U0', '--no-color', 'HEAD']) return process.stdout.decode('utf-8').strip() diff --git a/dmlcloud/util/horovod.py b/dmlcloud/util/horovod.py deleted file mode 100644 index ce2a5d1..0000000 --- a/dmlcloud/util/horovod.py +++ /dev/null @@ -1,56 +0,0 @@ -import os -import sys - -import horovod.torch as hvd -import numpy as np -import torch - - -def hvd_print_worker(msg, barrier=True, flush=True): - if barrier: - hvd.barrier() - print(f'Worker {hvd.rank()} ({hvd.cross_rank()}.{hvd.local_rank()}): {msg}', flush=flush) - if barrier: - hvd.barrier() - - -def setup_horovod(print_status=True): - hvd.init() - n_tasks = int(os.environ.get('SLURM_NTASKS', 0)) - if n_tasks > 1 and hvd.size() == 1: - print( - 'CRITICAL: Horovod only sees a single task! Run "horovodrun --check-build" an verify that MPI is supported. Terminating...' 
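The torch.distributed helpers added above replace the horovod utilities that this diff deletes further down. A minimal usage sketch, assuming only the functions defined in dmlcloud/util/distributed.py; download_data() is a hypothetical placeholder for a step that must run exactly once:

import torch.distributed as dist

from dmlcloud.util.distributed import init_process_group_auto, is_root, root_first, root_only


@root_only
def report(msg):
    # Executes on rank 0 only; on every other rank the wrapper returns None without calling the body.
    print(msg)


def download_data():
    # Hypothetical placeholder, e.g. fetching a dataset that must not be downloaded concurrently.
    pass


def main():
    init_process_group_auto()  # env:// if MASTER_PORT is set, else MPI via mpi4py, else a 1-process dummy group
    with root_first():         # rank 0 runs the body before the other ranks leave the barrier
        download_data()
    report(f'initialized {dist.get_world_size()} ranks, is_root={is_root()}')


if __name__ == '__main__':
    main()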
- ) - sys.exit(1) - - if print_status: - hvd_print_worker('STARTED') - - hvd.barrier() # make sure that all processes are running at this point - # this is very important, otherwise subsequent broadcast operations might time out - - -def hvd_is_initialized(): - try: - hvd.size() - return True - except ValueError: - return False - - -def hvd_allreduce(val, *args, **kwargs): - tensor = torch.as_tensor(val) - reduced = hvd.allreduce(tensor, *args, **kwargs) - return reduced.cpu().numpy() - - -def shard_indices(n, rank, size, shuffle=True, drop_remainder=False, seed=0): - indices = np.arange(n) - - if shuffle: - np.random.Generator(np.random.MT19937(seed)).shuffle(indices) - - if drop_remainder: - indices = indices[: n - n % size] - - return indices[rank::size] diff --git a/dmlcloud/util/logging.py b/dmlcloud/util/logging.py new file mode 100644 index 0000000..4ca4a91 --- /dev/null +++ b/dmlcloud/util/logging.py @@ -0,0 +1,169 @@ +import logging +import os +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +import torch +import torch.distributed as dist + +import dmlcloud +from . import slurm +from .git import git_hash +from .thirdparty import try_get_version + + +class IORedirector: + """ + Context manager to redirect stdout and stderr to a file. + Data is written to the file and the original streams. + """ + + class Stdout: + def __init__(self, parent): + self.parent = parent + + def write(self, data): + self.parent.file.write(data) + self.parent.stdout.write(data) + + def flush(self): + self.parent.file.flush() + self.parent.stdout.flush() + + class Stderr: + def __init__(self, parent): + self.parent = parent + + def write(self, data): + self.parent.file.write(data) + self.parent.stderr.write(data) + + def flush(self): + self.parent.file.flush() + self.parent.stderr.flush() + + def __init__(self, log_file: Path): + self.path = log_file + self.file = None + self.stdout = None + self.stderr = None + + def install(self): + if self.file is not None: + return + + self.file = self.path.open('a') + self.stdout = sys.stdout + self.stderr = sys.stderr + self.stdout.flush() + self.stderr.flush() + + sys.stdout = self.Stdout(self) + sys.stderr = self.Stderr(self) + + def uninstall(self): + self.stdout.flush() + self.stderr.flush() + + sys.stdout = self.stdout + sys.stderr = self.stderr + + self.file.close() + + def __enter__(self): + self.install() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.uninstall() + + +def add_log_handlers(logger: logging.Logger): + if logger.hasHandlers(): + return + + logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING) + + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setLevel(logging.DEBUG) + stdout_handler.addFilter(lambda record: record.levelno < logging.WARNING) + stdout_handler.setFormatter(logging.Formatter()) + logger.addHandler(stdout_handler) + + stderr_handler = logging.StreamHandler() + stderr_handler.setLevel(logging.WARNING) + stderr_handler.setFormatter(logging.Formatter()) + logger.addHandler(stderr_handler) + + +def experiment_header( + name: str | None, + checkpoint_dir: str | None, + date: datetime, +) -> str: + msg = f'............... 
Experiment: {name if name else "N/A"} ...............\n' + msg += f'- Date: {date}\n' + msg += f'- Checkpoint Dir: {checkpoint_dir if checkpoint_dir else "N/A"}\n' + msg += f'- Training on {dist.get_world_size()} GPUs\n' + return msg + + +def general_diagnostics() -> str: + msg = '* GENERAL:\n' + msg += f' - argv: {sys.argv}\n' + msg += f' - cwd: {Path.cwd()}\n' + + msg += f' - host (root): {os.environ.get("HOSTNAME")}\n' + msg += f' - user: {os.environ.get("USER")}\n' + msg += f' - git-hash: {git_hash()}\n' + msg += f' - conda-env: {os.environ.get("CONDA_DEFAULT_ENV", "N/A")}\n' + msg += f' - sys-prefix: {sys.prefix}\n' + msg += f' - backend: {dist.get_backend()}\n' + msg += f' - cuda: {torch.cuda.is_available()}\n' + + if torch.cuda.is_available(): + msg += '* GPUs (root):\n' + nvsmi = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode() + for line in nvsmi.splitlines(): + msg += f' - {line}\n' + + msg += '* VERSIONS:\n' + msg += f' - python: {sys.version}\n' + msg += f' - dmlcloud: {dmlcloud.__version__}\n' + msg += f' - cuda: {torch.version.cuda}\n' + try: + msg += ' - ' + Path('/proc/driver/nvidia/version').read_text().splitlines()[0] + '\n' + except (FileNotFoundError, IndexError): + pass + + msg += f' - torch: {torch.__version__}\n' + if try_get_version('torchvision'): + msg += f' - torchvision: {try_get_version("torchvision")}\n' + if try_get_version('torchtext'): + msg += f' - torchtext: {try_get_version("torchtext")}\n' + if try_get_version('torchaudio'): + msg += f' - torchaudio: {try_get_version("torchaudio")}\n' + if try_get_version('einops'): + msg += f' - einops: {try_get_version("einops")}\n' + if try_get_version('numpy'): + msg += f' - numpy: {try_get_version("numpy")}\n' + if try_get_version('pandas'): + msg += f' - pandas: {try_get_version("pandas")}\n' + if try_get_version('xarray'): + msg += f' - xarray: {try_get_version("xarray")}\n' + if try_get_version('sklearn'): + msg += f' - sklearn: {try_get_version("sklearn")}\n' + + if 'SLURM_JOB_ID' in os.environ: + msg += '* SLURM:\n' + msg += f' - SLURM_JOB_ID = {slurm.slurm_job_id()}\n' + msg += f' - SLURM_STEP_ID = {slurm.slurm_step_id()}\n' + msg += f' - SLURM_STEP_NODELIST = {os.environ.get("SLURM_STEP_NODELIST")}\n' + msg += f' - SLURM_TASKS_PER_NODE = {os.environ.get("SLURM_TASKS_PER_NODE")}\n' + msg += f' - SLURM_STEP_GPUS = {os.environ.get("SLURM_STEP_GPUS")}\n' + msg += f' - SLURM_GPUS_ON_NODE = {os.environ.get("SLURM_GPUS_ON_NODE")}\n' + msg += f' - SLURM_CPUS_PER_TASK = {os.environ.get("SLURM_CPUS_PER_TASK")}' + + return msg diff --git a/dmlcloud/util/project.py b/dmlcloud/util/project.py index 87f2a26..6bcde04 100644 --- a/dmlcloud/util/project.py +++ b/dmlcloud/util/project.py @@ -34,11 +34,11 @@ def is_setuptools_cli_script(module): def script_path(): """ Returns the path to the script or module that was executed. - If python runs in interactive mode, or if "-c" command line option was used, returns None. + If python runs in interactive mode, or if "-c" command line option was used, raises a RuntimeError. 
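A short sketch of the IORedirector and add_log_handlers helpers introduced in dmlcloud/util/logging.py above. It assumes torch.distributed is already initialized (add_log_handlers queries dist.get_rank()) and that the parent directory of the log file exists; the path itself is illustrative:

import logging
from pathlib import Path

from dmlcloud.util.distributed import init_process_group_auto
from dmlcloud.util.logging import IORedirector, add_log_handlers

init_process_group_auto()

logger = logging.getLogger('dmlcloud')
add_log_handlers(logger)  # INFO on rank 0, WARNING elsewhere; sub-WARNING records go to stdout, the rest to stderr

log_file = Path('checkpoints/run-1/log.txt')  # illustrative; the directory must already exist
with IORedirector(log_file):
    print('this line is written to the console and appended to log.txt')
    logger.info('handlers attached above keep writing to the original console streams')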
""" main = sys.modules['__main__'] if not hasattr(main, '__file__'): - return None + raise RuntimeError('script_path() is not supported in interactive mode') if is_setuptools_cli_script(main): stack = traceback.extract_stack() @@ -50,44 +50,28 @@ def script_path(): return Path(main.__file__).resolve() -def script_path_available(): - try: - script_path() - return True - except RuntimeError: - return False - - def script_dir(): """ Returns the directory containing the script or module that was executed. - If python runs in interactive mode, or if "-c" command line option was used, returns None. + If python runs in interactive mode, or if "-c" command line option was used, then raises RuntimeError. """ - path = script_path() - if path is None: - return None - else: - return path.parent + return script_path().parent + def project_dir(): """ Returns the top-level directory containing the script or module that was executed. - If python runs in interactive mode, or if "-c" command line option was used, returns None. + If python runs in interactive mode, or if "-c" command line option was used, then raises RuntimeError. """ cur_dir = script_dir() - if cur_dir is None: - return None while (cur_dir / '__init__.py').exists(): cur_dir = cur_dir.parent return cur_dir -def run_in_project(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **kwargs): +def run_in_project(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs): """ Runs a command in the project directory and returns the output. - If python runs in interactive mode, or if "-c" command line option was used, raises RuntimeError. """ cwd = project_dir() - if cwd is None: - raise RuntimeError("Cannot run in project directory: script path not available") return subprocess.run(cmd, cwd=cwd, stdout=stdout, stderr=stderr, **kwargs) diff --git a/dmlcloud/util/slurm.py b/dmlcloud/util/slurm.py new file mode 100644 index 0000000..a29bbc7 --- /dev/null +++ b/dmlcloud/util/slurm.py @@ -0,0 +1,13 @@ +import os + + +def slurm_job_id(): + return os.environ.get('SLURM_JOB_ID') + + +def slurm_step_id(): + return os.environ.get('SLURM_STEP_ID') + + +def slurm_available(): + return slurm_job_id() is not None diff --git a/dmlcloud/util/tcp.py b/dmlcloud/util/tcp.py new file mode 100644 index 0000000..478a62b --- /dev/null +++ b/dmlcloud/util/tcp.py @@ -0,0 +1,18 @@ +import socket + + +def find_free_port(): + """ + Returns a free port on the local machine. + """ + with socket.socket() as s: + s.bind(('', 0)) + return s.getsockname()[1] + + +def get_local_ips(): + """ + Returns the IP addresses of the local machine. 
+ """ + hostname = socket.gethostname() + return socket.gethostbyname_ex(hostname)[2] diff --git a/dmlcloud/util/thirdparty.py b/dmlcloud/util/thirdparty.py new file mode 100644 index 0000000..9693023 --- /dev/null +++ b/dmlcloud/util/thirdparty.py @@ -0,0 +1,18 @@ +import importlib +from types import ModuleType +from typing import Optional + + +def try_import(name: str) -> Optional[ModuleType]: + try: + return importlib.import_module(name) + except ImportError: + return None + + +def try_get_version(name: str) -> Optional[str]: + module = try_import(name) + if module is not None: + return str(module.__version__) + else: + return None diff --git a/dmlcloud/util/wandb.py b/dmlcloud/util/wandb.py index d1aa881..a45f946 100644 --- a/dmlcloud/util/wandb.py +++ b/dmlcloud/util/wandb.py @@ -1,7 +1,5 @@ import os -import wandb - def wandb_set_startup_timeout(seconds: int): assert isinstance(seconds, int) @@ -9,4 +7,6 @@ def wandb_set_startup_timeout(seconds: int): def wandb_is_initialized(): + import wandb + return wandb.run is not None diff --git a/examples/barebone_mnist.py b/examples/barebone_mnist.py new file mode 100644 index 0000000..1125b1f --- /dev/null +++ b/examples/barebone_mnist.py @@ -0,0 +1,95 @@ +import sys + +sys.path.insert(0, './') + +import torch +from dmlcloud.pipeline import TrainingPipeline +from dmlcloud.stage import Stage +from dmlcloud.util.distributed import init_process_group_auto, is_root, root_first +from torch import nn +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + + +class MNISTStage(Stage): + def pre_stage(self): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + + with root_first(): + train_dataset = datasets.MNIST(root='data', train=True, download=is_root(), transform=transform) + self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + self.train_loader = DataLoader(train_dataset, batch_size=32, sampler=self.train_sampler) + + val_dataset = datasets.MNIST(root='data', train=False, download=is_root(), transform=transform) + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) + self.val_loader = DataLoader(val_dataset, batch_size=32, sampler=val_sampler) + + self.model = nn.Sequential( + nn.Conv2d(1, 16, 3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 16, 3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Flatten(), + nn.Linear(784, 10), + ).to(self.pipeline.device) + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3) + self.loss = nn.CrossEntropyLoss() + + def run_epoch(self): + self._train_epoch() + self._val_epoch() + + def _train_epoch(self): + self.model.train() + self.metric_prefix = 'train' + self.train_sampler.set_epoch(self.current_epoch) + + for img, target in self.train_loader: + img, target = img.to(self.pipeline.device), target.to(self.pipeline.device) + + self.optimizer.zero_grad() + output = self.model(img) + loss = self.loss(output, target) + loss.backward() + self.optimizer.step() + + self._log_metrics(img, target, output, loss) + + @torch.no_grad() + def _val_epoch(self): + self.model.eval() + self.metric_prefix = 'val' + + for img, target in self.val_loader: + img, target = img.to(self.pipeline.device), target.to(self.pipeline.device) + + output = self.model(img) + loss = self.loss(output, target) + + self._log_metrics(img, target, output, loss) + + def _log_metrics(self, img, target, output, loss): + self.track_reduce('loss', loss) + 
self.track_reduce('accuracy', (output.argmax(1) == target).float().mean()) + + def table_columns(self): + columns = super().table_columns() + columns.insert(1, {'name': '[Train] Loss', 'metric': 'train/loss'}) + columns.insert(2, {'name': '[Val] Loss', 'metric': 'val/loss'}) + columns.insert(3, {'name': '[Train] Acc.', 'metric': 'train/accuracy'}) + columns.insert(4, {'name': '[Val] Acc.', 'metric': 'val/accuracy'}) + return columns + + +def main(): + init_process_group_auto() + + pipeline = TrainingPipeline() + pipeline.append_stage(MNISTStage(), max_epochs=3) + pipeline.run() + + +if __name__ == '__main__': + main() diff --git a/examples/mnist.py b/examples/mnist.py new file mode 100644 index 0000000..61c03bd --- /dev/null +++ b/examples/mnist.py @@ -0,0 +1,70 @@ +import sys + +sys.path.insert(0, './') + +import torch +from dmlcloud.pipeline import TrainingPipeline +from dmlcloud.stage import TrainValStage +from dmlcloud.util.distributed import init_process_group_auto, is_root, root_first +from torch import nn +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + + +class MNISTStage(TrainValStage): + def pre_stage(self): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + + with root_first(): + train_dataset = datasets.MNIST(root='data', train=True, download=is_root(), transform=transform) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + self.pipeline.register_dataset('train', DataLoader(train_dataset, batch_size=32, sampler=train_sampler)) + + val_dataset = datasets.MNIST(root='data', train=False, download=is_root(), transform=transform) + val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False) + self.pipeline.register_dataset('val', DataLoader(val_dataset, batch_size=32, sampler=val_sampler)) + + model = nn.Sequential( + nn.Conv2d(1, 16, 3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Conv2d(16, 16, 3, padding=1), + nn.ReLU(), + nn.MaxPool2d(2), + nn.Flatten(), + nn.Linear(784, 10), + ) + self.pipeline.register_model('cnn', model) + + self.pipeline.register_optimizer('adam', torch.optim.Adam(model.parameters(), lr=1e-3)) + + self.loss = nn.CrossEntropyLoss() + + def step(self, batch) -> torch.Tensor: + img, target = batch + img, target = img.to(self.device), target.to(self.device) + + output = self.pipeline.models['cnn'](img) + loss = self.loss(output, target) + + self.track_reduce('accuracy', (output.argmax(1) == target).float().mean()) + return loss + + def table_columns(self): + columns = super().table_columns() + columns.insert(-2, {'name': '[Val] Acc.', 'metric': 'val/accuracy'}) + columns.insert(-2, {'name': '[Train] Acc.', 'metric': 'train/accuracy'}) + return columns + + +def main(): + init_process_group_auto() + pipeline = TrainingPipeline(name='mnist') + pipeline.enable_checkpointing('checkpoints', resume=False) + pipeline.enable_wandb() + pipeline.append_stage(MNISTStage(), max_epochs=3) + pipeline.run() + + +if __name__ == '__main__': + main() diff --git a/pyproject.toml b/pyproject.toml index 6dee53a..8af9a4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,15 +8,15 @@ authors = [ {name = "Sebastian Hoffmann"} ] description = "Distributed torch training using horovod and slurm" -requires-python = ">=3.9" +requires-python = ">=3.10" license = {file = "LICENSE"} -keywords = ["pytorch", "horovod", "slurm", "distributed", "training"] +keywords = ["pytorch", "torch.distributed", "slurm", "distributed training", 
"deep learning"] classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: BSD License", "Operating System :: MacOS", "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dynamic = ["version", "readme", "dependencies"] @@ -32,4 +32,4 @@ dependencies = {file = ["requirements.txt"]} [tool.black] skip-string-normalization = true line-length = 120 -target-version = ["py39"] +target-version = ["py310"] diff --git a/requirements.txt b/requirements.txt index 0c25c84..d522863 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ torch -wandb numpy -progress_table +progress_table>=0.1.20,<1.0.0 +omegaconf diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..4384df9 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,9 @@ +import pytest +from dmlcloud.util.distributed import deinitialize_torch_distributed, init_process_group_dummy + + +@pytest.fixture +def torch_distributed(): + init_process_group_dummy() + yield + deinitialize_torch_distributed() diff --git a/test/test_metrics.py b/test/test_metrics.py new file mode 100644 index 0000000..4f276e8 --- /dev/null +++ b/test/test_metrics.py @@ -0,0 +1,202 @@ +import sys + +sys.path.insert(0, './') + +import pytest +import torch +from dmlcloud.metrics import MetricReducer, MetricTracker, Reduction + + +class TestMetricReducer: + def test_local_reduction(self): + reducer = MetricReducer(reduction=Reduction.MIN, globally=False) + reducer.append(torch.tensor([1, 2, 3], dtype=torch.float)) + reducer.append(torch.tensor([-1, -2, -3], dtype=torch.float)) + reducer.append(torch.tensor([1, 7, 10], dtype=torch.float)) + + assert reducer.reduce_locally().item() == -3 + assert reducer.reduce_globally().item() == -3 + + reducer.reduction = Reduction.MAX + assert reducer.reduce_locally().item() == 10 + assert reducer.reduce_globally().item() == 10 + + reducer.reduction = Reduction.SUM + assert reducer.reduce_locally().item() == 18 + assert reducer.reduce_globally().item() == 18 + + reducer.reduction = Reduction.MEAN + assert reducer.reduce_locally().item() == 2 + assert reducer.reduce_globally().item() == 2 + + def test_global_reduction(self, torch_distributed): + reducer = MetricReducer(reduction=Reduction.MIN, globally=True) + reducer.append(torch.tensor([1, 2, 3], dtype=torch.float)) + reducer.append(torch.tensor([-1, -2, -3], dtype=torch.float)) + reducer.append(torch.tensor([1, 7, 10], dtype=torch.float)) + + assert reducer.reduce_locally().item() == -3 + assert reducer.reduce_globally().item() == -3 + + reducer.reduction = Reduction.MAX + assert reducer.reduce_locally().item() == 10 + assert reducer.reduce_globally().item() == 10 + + reducer.reduction = Reduction.SUM + assert reducer.reduce_locally().item() == 18 + assert reducer.reduce_globally().item() == 18 + + reducer.reduction = Reduction.MEAN + assert reducer.reduce_locally().item() == 2 + assert reducer.reduce_globally().item() == 2 + + def test_partial_reduction(self): + tensor = torch.tensor([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]], dtype=torch.float) # shape: 2x2x3 + print(tensor.shape) + + reducer = MetricReducer(reduction=Reduction.MIN, globally=False, dim=[1, 2]) + reducer.append(tensor) + result = reducer.reduce_locally() + assert result.shape == (2,) + assert result[0].item() == 1 + assert result[1].item() == 1 + + reducer = MetricReducer(reduction=Reduction.SUM, globally=False, dim=2) + 
reducer.append(tensor) + result = reducer.reduce_locally() + assert result.shape == (2, 2) + assert result[0, 0].item() == 6 + assert result[0, 1].item() == 15 + assert result[1, 0].item() == 6 + assert result[1, 1].item() == 15 + + def test_serialization(self): + reducer = MetricReducer(reduction=Reduction.MIN, dim=(1, 2, 3)) + reducer.append(torch.tensor([1, 2, 3])) + state_dict = reducer.state_dict() + + new_reducer = MetricReducer() + new_reducer.load_state_dict(state_dict) + assert new_reducer.reduction == Reduction.MIN + assert new_reducer.dim == [1, 2, 3] + assert new_reducer.values == reducer.values + + +class TestMetricTracker: + def test_dictionary(self): + tracker = MetricTracker() + assert len(tracker) == 0 + + tracker.register_metric('A') + tracker.register_metric('B', reduction=Reduction.MEAN, globally=False) + assert len(tracker) == 2 + + assert 'A' in tracker + assert 'B' in tracker + assert 'C' not in tracker + + assert isinstance(tracker['A'], list) + assert len(tracker['A']) == 0 + + def test_is_reduced_metric(self): + tracker = MetricTracker() + tracker.register_metric('A') + tracker.register_metric('B', reduction=Reduction.MEAN, globally=False) + + assert not tracker.is_reduced_metric('A') + assert tracker.is_reduced_metric('B') + + def test_epoch_filling(self): + tracker = MetricTracker() + tracker.register_metric('A') + + tracker.next_epoch() + assert len(tracker['A']) == 1 and tracker['A'][0] is None + assert tracker.epoch == 2 + + tracker.next_epoch() + assert len(tracker['A']) == 2 and tracker['A'][1] is None + assert tracker.epoch == 3 + + tracker.register_metric('B', reduction=Reduction.MEAN, globally=False) + assert len(tracker['B']) == 2 and tracker['B'][1] is None + + def test_track(self): + tracker = MetricTracker() + tracker.register_metric('A') + + tracker.track('A', 1) + with pytest.raises(ValueError): # haven't progressed the epoch yet + tracker.track('A', 42) + tracker.next_epoch() + + tracker.track('A', 42) + + tracker.register_metric('B', reduction=Reduction.MEAN, globally=False) + tracker.track('B', 2.0) + tracker.track('B', 4.0) + tracker.track('B', 1.0) + tracker.track('B', 1.0) + + tracker.next_epoch() + assert tracker['A'] == [1, 42] + assert tracker['B'] == [None, torch.tensor(2.0)] + + def test_str(self): + tracker = MetricTracker() + tracker.register_metric('A') + tracker.register_metric('B', reduction=Reduction.MEAN, globally=False) + tracker.track('A', 1) + print(str(tracker)) + + def test_manual_reduction(self): + tracker = MetricTracker() + tracker.register_metric('A') + tracker.register_metric('B', reduction=Reduction.SUM, globally=False) + tracker.track('B', 1.0) + tracker.track('B', 2.0) + tracker.track('B', 3.0) + tracker.reduce_all(prefix='B') + + assert tracker.has_value('B') + assert not tracker.has_value('A') + assert tracker.current_value('B').item() == 6.0 + assert tracker.current_value('A') is None + assert tracker['B'] == [] + + with pytest.raises(ValueError): + tracker.reduce_all(prefix='B') + + # does not throw, nor modify value + tracker.reduce_all(prefix='B', strict=False) + assert tracker.current_value('B').item() == 6.0 + assert tracker['B'] == [] + + # advances epoch + tracker.next_epoch() + assert tracker['B'] == [torch.tensor(6.0)] + assert tracker['A'] == [None] + assert tracker.current_value('B') is None + + def test_serialization(self): + tracker1 = MetricTracker() + tracker1.register_metric('A') + tracker1.register_metric('B', reduction=Reduction.MEAN, globally=False) + + tracker1.track('A', 1) + 
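A hedged sketch of the MetricTracker epoch workflow exercised by the tests above; the integration with TrainingPipeline is not part of this diff, so this only mirrors the raw API:

from dmlcloud.metrics import MetricTracker, Reduction

tracker = MetricTracker()
tracker.register_metric('loss', reduction=Reduction.MEAN, globally=False)
tracker.register_metric('epoch_idx')            # unreduced: at most one tracked value per epoch

for epoch in range(1, 3):
    tracker.track('epoch_idx', epoch)
    for step in range(10):
        tracker.track('loss', float(step))      # reduced metrics may be tracked many times per epoch
    tracker.next_epoch()                        # reduces 'loss' and appends one entry per metric

print(tracker['epoch_idx'])  # [1, 2]
print(tracker['loss'])       # two per-epoch means, tensor(4.5) each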
tracker1.track('B', torch.randn(3, 2)) + tracker1.next_epoch() + tracker1.track('A', 2) + tracker1.track('B', torch.randn(3, 2)) + + state_dict = tracker1.state_dict() + tracker2 = MetricTracker() + tracker2.load_state_dict(state_dict) + assert tracker2.epoch == tracker1.epoch + assert 'A' in tracker2 and 'B' in tracker2 + assert tracker2['A'] == tracker1['A'] + assert tracker2['B'] == tracker1['B'] + + +if __name__ == '__main__': + sys.exit(pytest.main([__file__])) diff --git a/test/test_sharding.py b/test/test_sharding.py index 34bcb80..53a9006 100644 --- a/test/test_sharding.py +++ b/test/test_sharding.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from dmlcloud.util import shard_indices +from dmlcloud.util.distributed import shard_indices from numpy.testing import assert_array_equal diff --git a/test/test_smoke.py b/test/test_smoke.py index e9c8c1f..eb89c2a 100644 --- a/test/test_smoke.py +++ b/test/test_smoke.py @@ -2,8 +2,8 @@ import pytest import torch -from dmlcloud.config import DefaultConfig -from dmlcloud.training import BaseTrainer, ClassificationTrainer +from dmlcloud.pipeline import TrainingPipeline +from dmlcloud.stage import TrainValStage class DummyDataset(torch.utils.data.Dataset): @@ -14,51 +14,32 @@ def __getitem__(self, idx): return torch.randn(10), torch.randint(0, 10, size=(1,)).item() -class SmokeTrainer(BaseTrainer): - def create_dataset(self): - train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4) - val_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4) - return train_dl, val_dl +class DummyStage(TrainValStage): + def pre_stage(self): + self.model = torch.nn.Linear(10, 10) + self.pipeline.register_model('linear', self.model) - def create_model(self): - return torch.nn.Linear(10, 10) + self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3) + self.pipeline.register_optimizer('sgd', self.optimizer) - def create_loss(self): - return torch.nn.CrossEntropyLoss() + self.pipeline.register_dataset('train', torch.utils.data.DataLoader(DummyDataset(), batch_size=4)) + self.pipeline.register_dataset('val', torch.utils.data.DataLoader(DummyDataset(), batch_size=4)) - def create_optimizer(self, params, lr): - return torch.optim.SGD(params, lr=lr) + self.loss = torch.nn.CrossEntropyLoss() - def forward_step(self, batch_idx, batch): - x, y = (tensor.to(self.device) for tensor in batch) - pred = self.model(x) - loss = self.loss_fn(pred, y) + def step(self, batch): + x, y = batch + x, y = x.to(self.device), y.to(self.device) + output = self.model(x) + loss = self.loss(output, y) return loss -class SmokeClassificationTrainer(ClassificationTrainer): - def create_dataset(self): - train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4) - val_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4) - return train_dl, val_dl - - def create_model(self): - return torch.nn.Linear(10, 10) - - def create_optimizer(self, params, lr): - return torch.optim.SGD(params, lr=lr) - - class TestSmoke: - def test_smoke(self): - cfg = DefaultConfig() - trainer = SmokeTrainer(cfg) - trainer.train(use_checkpointing=False, use_wandb=False, print_diagnostics=False) - - def test_classification_smoke(self): - cfg = DefaultConfig() - trainer = SmokeClassificationTrainer(cfg) - trainer.train(use_checkpointing=False, use_wandb=False, print_diagnostics=False) + def test_smoke(self, torch_distributed): + pipeline = TrainingPipeline() + pipeline.append_stage(DummyStage(), max_epochs=1) + pipeline.run() if __name__ == '__main__':
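Finally, a hedged note on launching the examples and tests above: init_process_group_auto() picks env:// initialization when MASTER_PORT is set (as torchrun does), falls back to the mpi4py rendezvous when an MPI context with more than one rank is detected, and otherwise creates a single-process dummy group. The exact launcher flags below are illustrative only:

    python examples/mnist.py                         # single process, HashStore-backed dummy group
    torchrun --nproc_per_node=4 examples/mnist.py    # torchrun sets MASTER_PORT, so env:// is used
    srun python examples/mnist.py                    # Slurm/MPI job, ranks exchange the TCP endpoint via mpi4py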