diff --git a/dmlcloud/__init__.py b/dmlcloud/__init__.py index fc6dcd7..51a7ca7 100644 --- a/dmlcloud/__init__.py +++ b/dmlcloud/__init__.py @@ -2,20 +2,136 @@ Hello world """ +__version__ = "0.3.3" + +################################### +# Sub Packages +################################### + import dmlcloud.data as data import dmlcloud.git as git import dmlcloud.slurm as slurm -from dmlcloud.core import * -from dmlcloud.core import __all__ as _core_all +__all__ = [ + 'data', + 'git', + 'slurm', +] -__version__ = "0.3.3" -__all__ = list(_core_all) +################################### +# Top-level API +################################### + + +# Pipeline + +from .core.pipeline import Pipeline -# Packages __all__ += [ - 'data', - 'git', - 'slurm', + Pipeline, +] + +# Stage + +from .core.stage import Stage + +__all__ += [ + Stage, +] + +# Callbacks + +from .core.callbacks import StageCallback + +__all__ += [ + 'StageCallback', +] + +# Distributed helpers + +from .core.distributed import ( + all_gather_object, + broadcast_object, + deinitialize_torch_distributed, + gather_object, + has_environment, + has_mpi, + has_slurm, + init, + is_root, + local_node, + local_rank, + local_world_size, + rank, + root_first, + root_only, + world_size, +) + +__all__ += [ + has_slurm, + has_environment, + has_mpi, + is_root, + root_only, + root_first, + rank, + world_size, + local_rank, + local_world_size, + local_node, + all_gather_object, + gather_object, + broadcast_object, + init, + deinitialize_torch_distributed, +] + +# Metrics + +from .core.metrics import Tracker, TrainingHistory + +__all__ += [ + Tracker, + TrainingHistory, +] + + +from .core.logging import ( + critical, + debug, + error, + flush_logger, + info, + log, + logger, + print_root, + print_worker, + reset_logger, + setup_logger, + warning, +) + +__all__ += [ + logger, + setup_logger, + reset_logger, + flush_logger, + print_root, + print_worker, + log, + debug, + info, + warning, + error, + critical, +] + +from .core.model import count_parameters, scale_lr, wrap_ddp + +__all__ += [ + wrap_ddp, + scale_lr, + count_parameters, ] diff --git a/dmlcloud/core/__init__.py b/dmlcloud/core/__init__.py index d16c8e0..e69de29 100644 --- a/dmlcloud/core/__init__.py +++ b/dmlcloud/core/__init__.py @@ -1,24 +0,0 @@ -from .pipeline import * -from .stage import * -from .distributed import * -from .metrics import * -from .logging import * -from .model import * - -__all__ = [] - -# Main classes -__all__ += pipeline.__all__ -__all__ += stage.__all__ - -# Ditributed helpers -__all__ += distributed.__all__ - -# Metrics -__all__ += metrics.__all__ - -# Logging -__all__ += logging.__all__ - -# Model helpers -__all__ += model.__all__ diff --git a/dmlcloud/core/callbacks.py b/dmlcloud/core/callbacks.py index 158e8b3..f3f9ae4 100644 --- a/dmlcloud/core/callbacks.py +++ b/dmlcloud/core/callbacks.py @@ -2,7 +2,7 @@ import sys from datetime import datetime, timedelta from pathlib import Path -from typing import Callable, Optional, Union +from typing import Callable, Optional, TYPE_CHECKING, Union import torch from progress_table import ProgressTable @@ -11,11 +11,14 @@ from . import logging as dml_logging from .distributed import is_root +if TYPE_CHECKING: + from .stage import Stage + __all__ = [ 'TimedeltaFormatter', 'StageCallback', - 'TimreCallback', + 'TimerCallback', 'TableCallback', 'ReduceMetricsCallback', 'CsvCallback', @@ -94,11 +97,8 @@ def post_epoch(self, stage: 'Stage'): stage.log('misc/epoch_time', (stage.epoch_end_time - self.epoch_start_time).total_seconds(), prefixed=False) stage.log('misc/total_time', (stage.epoch_end_time - self.start_time).total_seconds(), prefixed=False) - eta = ( - (stage.epoch_end_time - self.start_time) - / (stage.current_epoch + 1) - * (stage.max_epochs - stage.current_epoch - 1) - ) + average_epoch_time = (stage.epoch_end_time - self.start_time) / (stage.current_epoch + 1) + eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1) stage.log('misc/eta', eta.total_seconds(), prefixed=False) if len(stage.pipe.stages) > 1: diff --git a/dmlcloud/core/metrics.py b/dmlcloud/core/metrics.py index 31b5971..0ee2244 100644 --- a/dmlcloud/core/metrics.py +++ b/dmlcloud/core/metrics.py @@ -1,5 +1,4 @@ from collections import namedtuple -from enum import Enum from typing import Any, Union import numpy as np @@ -105,12 +104,12 @@ def next_step(self): raise ValueError(f'Metric {name} does not have a value for step {self.num_steps}') for name, value in self._current_values.items(): - if type(value) == ArrayLike: + if type(value) == ArrayLike: # noqa value = np.as_array(value) if name not in self._metrics: self._metrics[name] = [value] - self._dtypes[name] = value.dtype if type(value) == ArrayLike else object + self._dtypes[name] = value.dtype if type(value) == ArrayLike else object # noqa else: self._metrics[name].append(value) diff --git a/dmlcloud/core/model.py b/dmlcloud/core/model.py index 16a4548..94da345 100644 --- a/dmlcloud/core/model.py +++ b/dmlcloud/core/model.py @@ -57,7 +57,7 @@ def wrap_ddp( module, broadcast_buffers=False, device_ids=device_ids, find_unused_parameters=find_unused_parameters ) if verbose: - msg = f'* MODEL:\n' + msg = '* MODEL:\n' msg += f' - Parameters: {count_parameters(module) / 1e6:.1f} kk\n' msg += f' - {module}' dml_logging.info(msg) diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py index 38c70d2..57f832f 100644 --- a/dmlcloud/core/pipeline.py +++ b/dmlcloud/core/pipeline.py @@ -1,13 +1,11 @@ -import logging import warnings from datetime import datetime, timedelta from functools import cached_property -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Dict, List, Optional, Sequence, Union import torch import torch.distributed as dist from omegaconf import OmegaConf -from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader, Dataset from dmlcloud.util.wandb import wandb, wandb_is_initialized, wandb_set_startup_timeout diff --git a/dmlcloud/core/stage.py b/dmlcloud/core/stage.py index 22cd41d..5d43f49 100644 --- a/dmlcloud/core/stage.py +++ b/dmlcloud/core/stage.py @@ -1,9 +1,12 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, TYPE_CHECKING from . import logging as dml_logging -from .callbacks import CsvCallback, ReduceMetricsCallback, StageCallback, TableCallback, TimerCallback +from .callbacks import ReduceMetricsCallback, TableCallback, TimerCallback from .metrics import Tracker, TrainingHistory +if TYPE_CHECKING: + from .callbacks import StageCallback + __all__ = [ 'Stage', ] @@ -22,7 +25,7 @@ def __init__(self, name: str = None, epochs: int = 1): self.name = name or self.__class__.__name__ self.max_epochs = epochs - self.callbacks: List[StageCallback] = [] + self.callbacks: list[StageCallback] = [] self.pipe = None # set by the pipeline @@ -72,7 +75,7 @@ def epoch_end_time(self): def table(self): return self._table_callback.table - def add_callback(self, callback: StageCallback): + def add_callback(self, callback: 'StageCallback'): """ Adds a callback to this stage. @@ -96,11 +99,11 @@ def add_metric(self, name, metric): def add_column( self, name: str, - metric: Optional[str] = None, - formatter: Optional[Callable] = None, - width: Optional[int] = None, - color: Optional[str] = None, - alignment: Optional[str] = None, + metric: str | None = None, + formatter: Callable | None = None, + width: int | None = None, + color: str | None = None, + alignment: str | None = None, ): """ Adds a column to the table. diff --git a/dmlcloud/data/__init__.py b/dmlcloud/data/__init__.py index c0874b9..f70422d 100644 --- a/dmlcloud/data/__init__.py +++ b/dmlcloud/data/__init__.py @@ -1,13 +1,11 @@ """Contains helpers for distributed data processing and loading.""" -from .sharding import * -from .xarray import * -from .interleave import * -from .dataset import * - __all__ = [] # Sharding + +from .sharding import chunk_and_shard_indices, shard_indices, shard_sequence + __all__ += [ 'shard_indices', 'shard_sequence', @@ -15,6 +13,9 @@ ] # Dataset + +from .dataset import BatchDataset, DownstreamDataset, PrefetchDataset, ShardedSequenceDataset + __all__ += [ 'ShardedSequenceDataset', 'DownstreamDataset', @@ -23,6 +24,9 @@ ] # Interleave + +from .interleave import interleave_batches, interleave_dict_batches + __all__ += [ 'interleave_batches', 'interleave_dict_batches', @@ -30,6 +34,9 @@ # Xarray + +from .xarray import sharded_xr_dataset, ShardedXrDataset + __all__ += [ 'sharded_xr_dataset', 'ShardedXrDataset',