Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support setting num_epochs instead of num_train_steps. #652

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions kauldron/train/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def train_impl(
num_train_steps=trainer.num_train_steps,
stop_after_steps=trainer.stop_after_steps,
profiler=trainer.profiler,
train_ds=trainer.train_ds,
):
with timer.exclude_from_step_stats():
if ckpt.should_save(i):
Expand Down Expand Up @@ -155,6 +156,7 @@ def _enum_steps_with_hooks(
num_train_steps: Optional[int],
stop_after_steps: Optional[int],
profiler: profile_utils.Profiler,
train_ds: data.Pipeline,
) -> Iterator[int]:
"""Enumerate over the train dataset.

Expand All @@ -169,18 +171,34 @@ def _enum_steps_with_hooks(
num_train_steps: Same as `trainer.num_train_steps`
stop_after_steps: Same as `trainer.stop_after_steps`
profiler: Same as `trainer.profiler`
train_ds: Same as `trainer.train_ds`

Yields:
step: Step number
batch: Example batch
"""
# TODO(epot): Currently, setting `num_train_steps=None` will fail. Instead
# should use `len(ds)` or check `num_epoch is not None`
if num_train_steps is None:
train_num_epochs = train_ds.num_epochs
if train_num_epochs and num_train_steps:
raise ValueError(
"`trainer.num_train_steps is None`. Please provide a value."
"Both `trainer.num_train_steps` and `trainer.train_ds.num_epochs` have"
" been defined. Please only define one of them."
)

try:
ds_len = len(train_ds)
except TypeError:
ds_len = None
if num_train_steps is None and (ds_len is None or train_num_epochs is None):
raise TypeError(
"`trainer.num_train_steps is None` and `len(trainer.train_ds) is None`"
" or `trainer.train_ds.num_epochs is None`. Users must specify either"
" the number of training steps or the number of epochs together with"
" dataset length."
)

if train_num_epochs:
num_train_steps = train_num_epochs * ds_len

total_steps = num_train_steps + 1
if stop_after_steps is not None:
total_steps = min(total_steps, stop_after_steps)
Expand Down