
Commit

fix: linting
sehoffmann committed Dec 27, 2024
1 parent 974c9ef commit 75482be
Showing 8 changed files with 159 additions and 60 deletions.
132 changes: 124 additions & 8 deletions dmlcloud/__init__.py
@@ -2,20 +2,136 @@
 Hello world
 """
 
+__version__ = "0.3.3"
+
+###################################
+# Sub Packages
+###################################
+
 import dmlcloud.data as data
 import dmlcloud.git as git
 import dmlcloud.slurm as slurm
 
-from dmlcloud.core import *
-from dmlcloud.core import __all__ as _core_all
+__all__ = [
+    'data',
+    'git',
+    'slurm',
+]
 
-__version__ = "0.3.3"
+###################################
+# Top-level API
+###################################
 
-__all__ = list(_core_all)
 
-# Packages
+# Pipeline
+
+from .core.pipeline import Pipeline
+
 __all__ += [
-    'data',
-    'git',
-    'slurm',
+    Pipeline,
 ]
+
+# Stage
+
+from .core.stage import Stage
+
+__all__ += [
+    Stage,
+]
+
+# Callbacks
+
+from .core.callbacks import StageCallback
+
+__all__ += [
+    'StageCallback',
+]
+
+# Distributed helpers
+
+from .core.distributed import (
+    all_gather_object,
+    broadcast_object,
+    deinitialize_torch_distributed,
+    gather_object,
+    has_environment,
+    has_mpi,
+    has_slurm,
+    init,
+    is_root,
+    local_node,
+    local_rank,
+    local_world_size,
+    rank,
+    root_first,
+    root_only,
+    world_size,
+)
+
+__all__ += [
+    has_slurm,
+    has_environment,
+    has_mpi,
+    is_root,
+    root_only,
+    root_first,
+    rank,
+    world_size,
+    local_rank,
+    local_world_size,
+    local_node,
+    all_gather_object,
+    gather_object,
+    broadcast_object,
+    init,
+    deinitialize_torch_distributed,
+]
+
+# Metrics
+
+from .core.metrics import Tracker, TrainingHistory
+
+__all__ += [
+    Tracker,
+    TrainingHistory,
+]
+
+
+from .core.logging import (
+    critical,
+    debug,
+    error,
+    flush_logger,
+    info,
+    log,
+    logger,
+    print_root,
+    print_worker,
+    reset_logger,
+    setup_logger,
+    warning,
+)
+
+__all__ += [
+    logger,
+    setup_logger,
+    reset_logger,
+    flush_logger,
+    print_root,
+    print_worker,
+    log,
+    debug,
+    info,
+    warning,
+    error,
+    critical,
+]
+
+from .core.model import count_parameters, scale_lr, wrap_ddp
+
+__all__ += [
+    wrap_ddp,
+    scale_lr,
+    count_parameters,
+]
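Note: several of the `__all__` blocks in the new file append bare objects (`Pipeline`, `Stage`, `Tracker`, and the distributed, logging, and model helpers) where the other blocks correctly use strings. Star-imports iterate `__all__` as names, so `from dmlcloud import *` raises `TypeError` on the first non-string entry; those lists need quoting in a follow-up. Assuming the exported names behave as they are named, a minimal usage sketch of the flattened top-level API:

```python
# Hypothetical usage sketch; call semantics are assumed from the exported
# names above, not taken from dmlcloud documentation.
import dmlcloud as dml

dml.init()  # assumed: initializes torch.distributed from SLURM/MPI/env vars

if dml.is_root():  # rank-0 check from .core.distributed
    dml.info(f'world_size={dml.world_size()}')  # logging helper from .core.logging

pipe = dml.Pipeline()  # re-export of .core.pipeline.Pipeline; constructor args not shown here
```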
24 changes: 0 additions & 24 deletions dmlcloud/core/__init__.py
@@ -1,24 +0,0 @@
-from .pipeline import *
-from .stage import *
-from .distributed import *
-from .metrics import *
-from .logging import *
-from .model import *
-
-__all__ = []
-
-# Main classes
-__all__ += pipeline.__all__
-__all__ += stage.__all__
-
-# Ditributed helpers
-__all__ += distributed.__all__
-
-# Metrics
-__all__ += metrics.__all__
-
-# Logging
-__all__ += logging.__all__
-
-# Model helpers
-__all__ += model.__all__
14 changes: 7 additions & 7 deletions dmlcloud/core/callbacks.py
@@ -2,7 +2,7 @@
 import sys
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 from progress_table import ProgressTable
@@ -11,11 +11,14 @@
 from . import logging as dml_logging
 from .distributed import is_root
 
+if TYPE_CHECKING:
+    from .stage import Stage
+
 
 __all__ = [
     'TimedeltaFormatter',
     'StageCallback',
-    'TimreCallback',
+    'TimerCallback',
     'TableCallback',
     'ReduceMetricsCallback',
     'CsvCallback',
@@ -94,11 +97,8 @@ def post_epoch(self, stage: 'Stage'):
         stage.log('misc/epoch_time', (stage.epoch_end_time - self.epoch_start_time).total_seconds(), prefixed=False)
         stage.log('misc/total_time', (stage.epoch_end_time - self.start_time).total_seconds(), prefixed=False)
 
-        eta = (
-            (stage.epoch_end_time - self.start_time)
-            / (stage.current_epoch + 1)
-            * (stage.max_epochs - stage.current_epoch - 1)
-        )
+        average_epoch_time = (stage.epoch_end_time - self.start_time) / (stage.current_epoch + 1)
+        eta = average_epoch_time * (stage.max_epochs - stage.current_epoch - 1)
         stage.log('misc/eta', eta.total_seconds(), prefixed=False)
 
         if len(stage.pipe.stages) > 1:
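The new `TYPE_CHECKING` guard is the standard fix for an import cycle: callbacks.py needs `Stage` only in annotations, while stage.py (below) imports callback classes at runtime. A generic sketch of the pattern, with hypothetical module names:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright); skipped at
    # runtime, so mutually importing modules cannot cycle on import.
    from b_module import Thing


def handle(obj: 'Thing') -> None:
    # The quoted annotation avoids a runtime name lookup, matching the
    # quoted 'Stage' annotation in post_epoch() above.
    ...
```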
5 changes: 2 additions & 3 deletions dmlcloud/core/metrics.py
@@ -1,5 +1,4 @@
 from collections import namedtuple
-from enum import Enum
 from typing import Any, Union
 
 import numpy as np
@@ -105,12 +104,12 @@ def next_step(self):
             raise ValueError(f'Metric {name} does not have a value for step {self.num_steps}')
 
         for name, value in self._current_values.items():
-            if type(value) == ArrayLike:
+            if type(value) == ArrayLike:  # noqa
                 value = np.as_array(value)
 
             if name not in self._metrics:
                 self._metrics[name] = [value]
-                self._dtypes[name] = value.dtype if type(value) == ArrayLike else object
+                self._dtypes[name] = value.dtype if type(value) == ArrayLike else object  # noqa
             else:
                 self._metrics[name].append(value)
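The `# noqa` markers above silence the linter rather than fix the underlying problems: if `ArrayLike` is a typing alias (as the name suggests), `type(value) == ArrayLike` can never be true, and `np.as_array` is not a NumPy function (`np.asarray` is). A sketch of a runtime-checkable variant — an illustration of the fix, not the library's actual code:

```python
import numpy as np


def _coerce(value):
    # isinstance against concrete runtime types, instead of comparing
    # type(value) to a typing alias (which is never equal to a type object)
    if isinstance(value, (np.ndarray, list, tuple)):
        return np.asarray(value)  # np.asarray exists; np.as_array does not
    return value
```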
2 changes: 1 addition & 1 deletion dmlcloud/core/model.py
@@ -57,7 +57,7 @@ def wrap_ddp(
         module, broadcast_buffers=False, device_ids=device_ids, find_unused_parameters=find_unused_parameters
     )
     if verbose:
-        msg = f'* MODEL:\n'
+        msg = '* MODEL:\n'
         msg += f'  - Parameters: {count_parameters(module) / 1e6:.1f} kk\n'
         msg += f'  - {module}'
         dml_logging.info(msg)
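The `f'* MODEL:\n'` change fixes pyflakes rule F541 (f-string without placeholders): the `f` prefix was inert. Note also that dividing `count_parameters(module)` by `1e6` yields millions, which the message labels "kk". The helper itself is not shown in this diff; a plausible sketch, with the trainable-only restriction being an assumption:

```python
import torch.nn as nn


def count_parameters(module: nn.Module) -> int:
    # Plausible implementation sketch; dmlcloud's actual helper may count
    # all parameters rather than only trainable ones.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)
```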
4 changes: 1 addition & 3 deletions dmlcloud/core/pipeline.py
@@ -1,13 +1,11 @@
-import logging
 import warnings
 from datetime import datetime, timedelta
-from functools import cached_property
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Dict, List, Optional, Sequence, Union
 
 import torch
 import torch.distributed as dist
 from omegaconf import OmegaConf
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import DataLoader, Dataset
 
 from dmlcloud.util.wandb import wandb, wandb_is_initialized, wandb_set_startup_timeout
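These removals are mechanical fixes for pyflakes rule F401 (imported but unused), assuming a pyflakes-compatible linter such as ruff or flake8 — the repository's lint configuration is not shown in this diff. A minimal reproduction of what the rule flags:

```python
# F401: 'logging' imported but unused — nothing below references it.
import logging

# F401 also fires per name: `Any` would be flagged here because only
# `Dict` is used in an annotation.
from typing import Any, Dict


def describe(mapping: Dict[str, int]) -> str:
    return ', '.join(f'{k}={v}' for k, v in mapping.items())
```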
21 changes: 12 additions & 9 deletions dmlcloud/core/stage.py
@@ -1,9 +1,12 @@
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, TYPE_CHECKING
 
 from . import logging as dml_logging
-from .callbacks import CsvCallback, ReduceMetricsCallback, StageCallback, TableCallback, TimerCallback
+from .callbacks import ReduceMetricsCallback, TableCallback, TimerCallback
 from .metrics import Tracker, TrainingHistory
 
+if TYPE_CHECKING:
+    from .callbacks import StageCallback
+
 __all__ = [
     'Stage',
 ]
@@ -22,7 +25,7 @@ def __init__(self, name: str = None, epochs: int = 1):
         self.name = name or self.__class__.__name__
         self.max_epochs = epochs
 
-        self.callbacks: List[StageCallback] = []
+        self.callbacks: list[StageCallback] = []
 
         self.pipe = None  # set by the pipeline
@@ -72,7 +75,7 @@ def epoch_end_time(self):
     def table(self):
         return self._table_callback.table
 
-    def add_callback(self, callback: StageCallback):
+    def add_callback(self, callback: 'StageCallback'):
         """
         Adds a callback to this stage.
@@ -96,11 +99,11 @@ def add_metric(self, name, metric):
     def add_column(
         self,
         name: str,
-        metric: Optional[str] = None,
-        formatter: Optional[Callable] = None,
-        width: Optional[int] = None,
-        color: Optional[str] = None,
-        alignment: Optional[str] = None,
+        metric: str | None = None,
+        formatter: Callable | None = None,
+        width: int | None = None,
+        color: str | None = None,
+        alignment: str | None = None,
     ):
         """
         Adds a column to the table.
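The typing changes modernize the annotations: `List[...]` becomes the built-in generic `list[...]` (PEP 585, Python ≥ 3.9), `Optional[X]` becomes `X | None` (PEP 604, Python ≥ 3.10), and `StageCallback` becomes a quoted forward reference so the `TYPE_CHECKING`-only import is never resolved at runtime. One caveat: attribute annotations like `self.callbacks: list[StageCallback] = []` are evaluated at runtime unless quoted or under `from __future__ import annotations`, so the unquoted form above relies on that future import being in effect. A condensed sketch of the style:

```python
from typing import TYPE_CHECKING, Callable

if TYPE_CHECKING:
    from callbacks import StageCallback  # hypothetical module; checker-only import


class Stage:
    def __init__(self) -> None:
        # Quoted element type: seen by type checkers, never looked up at runtime.
        self.callbacks: list['StageCallback'] = []

    def add_column(
        self,
        name: str,
        metric: str | None = None,          # PEP 604 union, Python >= 3.10
        formatter: Callable | None = None,  # bare Callable, as in the diff
        width: int | None = None,
    ) -> None:
        ...
```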
17 changes: 12 additions & 5 deletions dmlcloud/data/__init__.py
@@ -1,20 +1,21 @@
"""Contains helpers for distributed data processing and loading."""

from .sharding import *
from .xarray import *
from .interleave import *
from .dataset import *

__all__ = []

# Sharding

from .sharding import chunk_and_shard_indices, shard_indices, shard_sequence

__all__ += [
'shard_indices',
'shard_sequence',
'chunk_and_shard_indices',
]

# Dataset

from .dataset import BatchDataset, DownstreamDataset, PrefetchDataset, ShardedSequenceDataset

__all__ += [
'ShardedSequenceDataset',
'DownstreamDataset',
@@ -23,13 +24,19 @@
 ]
 
 # Interleave
+
+from .interleave import interleave_batches, interleave_dict_batches
+
 __all__ += [
     'interleave_batches',
     'interleave_dict_batches',
 ]
 
 
 # Xarray
+
+from .xarray import sharded_xr_dataset, ShardedXrDataset
+
 __all__ += [
     'sharded_xr_dataset',
     'ShardedXrDataset',
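Replacing the star-imports with explicit per-name imports keeps the package's public surface statically resolvable: pyflakes no longer raises F403/F405 (star import used / name may be undefined from star imports), and `__all__` stays in sync with what is actually imported. A minimal sketch of the pattern:

```python
# Explicit re-export pattern, as adopted above (single name for brevity).
from .sharding import shard_indices

__all__ = []
__all__ += [
    'shard_indices',
]

# Downstream, both spellings now resolve unambiguously for tools and readers:
#   from dmlcloud.data import shard_indices
#   import dmlcloud.data as data; data.shard_indices(...)
```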
