diff --git a/dmlcloud/core/pipeline.py b/dmlcloud/core/pipeline.py index 8c103a8..cc42fdb 100644 --- a/dmlcloud/core/pipeline.py +++ b/dmlcloud/core/pipeline.py @@ -78,13 +78,11 @@ def register_dataset(self, name: str, dataset: Union[DataLoader, Dataset, Sequen msg += ' - Batches (/Worker): N/A\n' dml_logging.info(msg) - def append_stage(self, stage: Stage, max_epochs: Optional[int] = None, name: Optional[str] = None): + def append(self, stage: Stage): if not isinstance(stage, Stage): raise ValueError('stage must be a Stage object') stage.pipe = self - stage.max_epochs = max_epochs - stage.name = name self.stages.append(stage) def enable_checkpointing( diff --git a/dmlcloud/core/stage.py b/dmlcloud/core/stage.py index cd9e48a..97502f4 100644 --- a/dmlcloud/core/stage.py +++ b/dmlcloud/core/stage.py @@ -23,10 +23,11 @@ class Stage: - post_epoch() """ - def __init__(self): + def __init__(self, name: str = None, epochs: int = 1): + self.name = name or self.__class__.__name__ + self.max_epochs = epochs + self.pipe = None # set by the pipeline - self.max_epochs = None # set by the pipeline - self.name = None # set by the pipeline self.history = TrainingHistory() self.tracker = Tracker() diff --git a/examples/barebone_mnist.py b/examples/barebone_mnist.py index 2ce5b67..17eb4f7 100644 --- a/examples/barebone_mnist.py +++ b/examples/barebone_mnist.py @@ -101,7 +101,7 @@ def _val_epoch(self): def main(): pipe = dml.Pipeline() - pipe.append_stage(MNISTStage(), max_epochs=3) + pipe.append(MNISTStage(epochs=3)) pipe.run()