Skip to content

Commit

Permalink
chore: simplified pipeline creation
Browse files Browse the repository at this point in the history
  • Loading branch information
sehoffmann committed Dec 18, 2024
1 parent fcd7042 commit 792a232
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
4 changes: 1 addition & 3 deletions dmlcloud/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,11 @@ def register_dataset(self, name: str, dataset: Union[DataLoader, Dataset, Sequen
msg += ' - Batches (/Worker): N/A\n'
dml_logging.info(msg)

def append_stage(self, stage: Stage, max_epochs: Optional[int] = None, name: Optional[str] = None):
def append(self, stage: Stage):
if not isinstance(stage, Stage):
raise ValueError('stage must be a Stage object')

stage.pipe = self
stage.max_epochs = max_epochs
stage.name = name
self.stages.append(stage)

def enable_checkpointing(
Expand Down
7 changes: 4 additions & 3 deletions dmlcloud/core/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ class Stage:
- post_epoch()
"""

def __init__(self):
def __init__(self, name: str = None, epochs: int = 1):
self.name = name or self.__class__.__name__
self.max_epochs = epochs

self.pipe = None # set by the pipeline
self.max_epochs = None # set by the pipeline
self.name = None # set by the pipeline

self.history = TrainingHistory()
self.tracker = Tracker()
Expand Down
2 changes: 1 addition & 1 deletion examples/barebone_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _val_epoch(self):

def main():
pipe = dml.Pipeline()
pipe.append_stage(MNISTStage(), max_epochs=3)
pipe.append(MNISTStage(epochs=3))
pipe.run()


Expand Down

0 comments on commit 792a232

Please sign in to comment.