Skip to content

Commit

Permalink
chore: TrainingPipeline -> Pipeline, .pipeline -> .pipe
Browse files (browse the repository at this point in the history)
  • Loading branch information
sehoffmann committed Dec 18, 2024
1 parent 4775c02 commit fcd7042
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 22 deletions.
8 changes: 4 additions & 4 deletions dmlcloud/core/pipeline.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,16 +14,16 @@
from ..util.logging import experiment_header, general_diagnostics, IORedirector
from . import logging as dml_logging
from .checkpoint import CheckpointDir, find_slurm_checkpoint, generate_checkpoint_path
from .distributed import all_gather_object, broadcast_object, init, is_root, local_rank, root_only
from .distributed import all_gather_object, broadcast_object, init, local_rank, root_only
from .stage import Stage


__all__ = [
'TrainingPipeline',
'Pipeline',
]


class TrainingPipeline:
class Pipeline:
def __init__(self, config: Optional[Union[OmegaConf, Dict]] = None, name: Optional[str] = None):
if config is None:
self.config = OmegaConf.create()
Expand Down Expand Up @@ -82,7 +82,7 @@ def append_stage(self, stage: Stage, max_epochs: Optional[int] = None, name: Opt
if not isinstance(stage, Stage):
raise ValueError('stage must be a Stage object')

stage.pipeline = self
stage.pipe = self
stage.max_epochs = max_epochs
stage.name = name
self.stages.append(stage)
Expand Down
14 changes: 7 additions & 7 deletions dmlcloud/core/stage.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -24,7 +24,7 @@ class Stage:
"""

def __init__(self):
self.pipeline = None # set by the pipeline
self.pipe = None # set by the pipeline
self.max_epochs = None # set by the pipeline
self.name = None # set by the pipeline

Expand All @@ -44,11 +44,11 @@ def __init__(self):

@property
def device(self):
return self.pipeline.device
return self.pipe.device

@property
def config(self):
return self.pipeline.config
return self.pipe.config

@property
def current_epoch(self):
Expand Down Expand Up @@ -120,7 +120,7 @@ def run(self):

def _pre_stage(self):
self.start_time = datetime.now()
if len(self.pipeline.stages) > 1:
if len(self.pipe.stages) > 1:
dml_logging.info(f'\n========== STAGE: {self.name} ==========')

self.table = ProgressTable(file=sys.stdout if is_root() else DevNullIO())
Expand All @@ -132,14 +132,14 @@ def _pre_stage(self):

dml_logging.flush_logger()

self.pipeline.barrier(self.barrier_timeout)
self.pipe.barrier(self.barrier_timeout)

def _post_stage(self):
self.table.close()
self.post_stage()
self.pipeline.barrier(self.barrier_timeout)
self.pipe.barrier(self.barrier_timeout)
self.stop_time = datetime.now()
if len(self.pipeline.stages) > 1:
if len(self.pipe.stages) > 1:
dml_logging.info(f'Finished stage in {self.stop_time - self.start_time}')

def _pre_epoch(self):
Expand Down
9 changes: 3 additions & 6 deletions doc/dmlcloud.rst
Original file line number | Diff line number | Diff line change
Expand Up @@ -11,9 +11,8 @@ Core
.. autosummary::
:toctree: generated

TrainingPipeline
Pipeline
Stage
TrainValStage


torch.distributed helpers
Expand Down Expand Up @@ -73,7 +72,5 @@ Metric Tracking
.. autosummary::
:toctree: generated

MetricReducer
MetricTracker
reduce_tensor
Reduction
TrainingHistory
Tracker
10 changes: 5 additions & 5 deletions examples/barebone_mnist.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -71,7 +71,7 @@ def _train_epoch(self):
self.train_sampler.set_epoch(self.current_epoch)

for img, target in self.train_loader:
img, target = img.to(self.pipeline.device), target.to(self.pipeline.device)
img, target = img.to(self.device), target.to(self.device)

self.optimizer.zero_grad()
output = self.model(img)
Expand All @@ -89,7 +89,7 @@ def _val_epoch(self):
self.metric_prefix = 'val'

for img, target in self.val_loader:
img, target = img.to(self.pipeline.device), target.to(self.pipeline.device)
img, target = img.to(self.device), target.to(self.device)

output = self.model(img)
loss = self.loss(output, target)
Expand All @@ -100,9 +100,9 @@ def _val_epoch(self):


def main():
pipeline = dml.TrainingPipeline()
pipeline.append_stage(MNISTStage(), max_epochs=3)
pipeline.run()
pipe = dml.Pipeline()
pipe.append_stage(MNISTStage(), max_epochs=3)
pipe.run()


if __name__ == '__main__':
Expand Down

0 comments on commit fcd7042

Please sign in to comment.