Skip to content

Commit

Permalink
Internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650927891
  • Loading branch information
aravindhm authored and The kauldron Authors committed Jul 11, 2024
1 parent ec5067d commit d2b9409
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
36 changes: 36 additions & 0 deletions kauldron/evals/eval_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# XManager API do not have API for jobs within a work-unit to communicate,
# so use files for communication.
TRAIN_COMPLETE_FILENAME = 'train_complete.txt'
EVAL_COMPLETE_FILENAME = 'eval_{}_complete.txt'


def continuous_eval(
Expand Down Expand Up @@ -80,17 +81,29 @@ def continuous_eval(
# Split evaluators
every_checkpoint_evals: list[evaluators_lib.EvaluatorBase] = []
last_checkpoint_evals: list[evaluators_lib.EvaluatorBase] = []
# BEGIN:GOOGLE_INTERNAL
best_checkpoint_evals: list[evaluators_lib.EvaluatorBase] = []
# END: GOOGLE_INTERNAL
for name in eval_names:
ev = trainer.evals[name]
if isinstance(ev.run, run_strategies.StandaloneLastCheckpoint):
last_checkpoint_evals.append(ev)
elif isinstance(ev.run, run_strategies.StandaloneEveryCheckpoint):
every_checkpoint_evals.append(ev)
# BEGIN:GOOGLE_INTERNAL
elif isinstance(ev.run, run_strategies.StandaloneBestCheckpoint):
best_checkpoint_evals.append(ev)
# END:GOOGLE_INTERNAL
else:
raise ValueError(
f'Remote eval ({name!r}) should be standalone. Got run={ev.run}'
)

# BEGIN:GOOGLE_INTERNAL
if best_checkpoint_evals:
best_checkpoint_eval_impl.validate_trainer(trainer)
# END:GOOGLE_INTERNAL

logging.info('Start evaluating...')
# Initialize the final step from the state for eval-only jobs which restore
# the step from the `init_transforms`.
Expand Down Expand Up @@ -120,11 +133,34 @@ def continuous_eval(

final_step = step

# All every_checkpoint_evals have been processed. Marks those as complete.
if trainer.workdir.exists(): # `TrainEvaluator` do not have a workdir
for ev in every_checkpoint_evals:
epath.Path(trainer.workdir).joinpath(
EVAL_COMPLETE_FILENAME.format(ev.name)
).touch()

logging.info('Running final evals...')
for ev in last_checkpoint_evals:
with tracker.catch_exception(name=ev.name, step=final_step):
aux[ev.name] = ev.evaluate(state=state, step=final_step)

# All last_checkpoint_evals have been processed. Marks those as complete.
if trainer.workdir.exists(): # `TrainEvaluator` do not have a workdir
for ev in last_checkpoint_evals:
epath.Path(trainer.workdir).joinpath(
EVAL_COMPLETE_FILENAME.format(ev.name)
).touch()

# BEGIN:GOOGLE_INTERNAL
logging.info('Running best checkpoint evals...')
if best_checkpoint_evals:
with tracker.catch_exception(name='best_checkpoint_evals', step=final_step):
aux.update(
best_checkpoint_eval_impl.eval(trainer, best_checkpoint_evals, state)
)
# END:GOOGLE_INTERNAL

tracker.maybe_reraise()

# Return the last aux
Expand Down
27 changes: 27 additions & 0 deletions kauldron/evals/run_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from __future__ import annotations

import dataclasses
from typing import Literal

from kauldron import kontext
from kauldron.xm._src import job_params


Expand Down Expand Up @@ -161,3 +163,28 @@ class StandaloneLastCheckpoint(Standalone):
If `job_group='group_name'`, all the evaluators sharing the same `job_group`
will share the same XManager job (to save resources).
"""


# BEGIN: GOOGLE_INTERNAL
@dataclasses.dataclass(kw_only=True, frozen=True)
class StandaloneBestCheckpoint(Standalone):
"""Run eval only after the last checkpoint, after train has completed.
Run as a separate XM job. All `kxm.Job` parameters are optionally supported.
Example:
```python
kd.evals.Evaluator(
run=kd.evals.StandaloneLastCheckpoint(platforms='a100=1'),
)
```
If `job_group='group_name'`, all the evaluators sharing the same `job_group`
will share the same XManager job (to save resources).
"""

best_at_eval: str = ''
best_at_metric: str = ''
best_mode: Literal['min', 'max'] = 'max'
# END: GOOGLE_INTERNAL

0 comments on commit d2b9409

Please sign in to comment.