diff --git a/kauldron/train/train_lib.py b/kauldron/train/train_lib.py index 4554e433..e6169719 100644 --- a/kauldron/train/train_lib.py +++ b/kauldron/train/train_lib.py @@ -97,6 +97,7 @@ def train_impl( num_train_steps=trainer.num_train_steps, stop_after_steps=trainer.stop_after_steps, profiler=trainer.profiler, + train_ds=trainer.train_ds, ): with timer.exclude_from_step_stats(): if ckpt.should_save(i): @@ -155,6 +156,7 @@ def _enum_steps_with_hooks( num_train_steps: Optional[int], stop_after_steps: Optional[int], profiler: profile_utils.Profiler, + train_ds: data.Pipeline, ) -> Iterator[int]: """Enumerate over the train dataset. @@ -169,18 +171,34 @@ def _enum_steps_with_hooks( num_train_steps: Same as `trainer.num_train_steps` stop_after_steps: Same as `trainer.stop_after_steps` profiler: Same as `trainer.profiler` + train_ds: Same as `trainer.train_ds` Yields: step: Step number batch: Example batch """ - # TODO(epot): Currently, setting `num_train_steps=None` will fail. Instead - # should use `len(ds)` or check `num_epoch is not None` - if num_train_steps is None: + train_num_epochs = train_ds.num_epochs + if train_num_epochs and num_train_steps: raise ValueError( - "`trainer.num_train_steps is None`. Please provide a value." + "Both `trainer.num_train_steps` and `trainer.train_ds.num_epochs` have" + " been defined. Please only define one of them." ) + try: + ds_len = len(train_ds) + except TypeError: + ds_len = None + if num_train_steps is None and (ds_len is None or train_num_epochs is None): + raise TypeError( + "`trainer.num_train_steps is None` and `len(trainer.train_ds) is None`" + " or `trainer.train_ds.num_epochs is None`. Users must specify either" + " the number of training steps or the number of epochs together with" + " dataset length." + ) + + if train_num_epochs: + num_train_steps = train_num_epochs * ds_len + total_steps = num_train_steps + 1 if stop_after_steps is not None: total_steps = min(total_steps, stop_after_steps)