diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index d32120d2..001e3672 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -39,6 +39,7 @@ set_double_shard_weights_config, ) from axlearn.common.checkpointer import every_n_steps_and_last_policy +from axlearn.common.checkpointer_orbax import OrbaxCheckpointer from axlearn.common.config import ( ConfigOr, FunctionConfigBase, @@ -700,11 +701,12 @@ def config_fn() -> InstantiableConfig: ) cfg.evalers[name] = evaler_cfg # Summaries and checkpoints. + cfg.checkpointer = OrbaxCheckpointer.default_config() cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set( n=save_every_n_steps or min(eval_every_n_steps, 5_000), max_step=max_step, ) - cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps) + # cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps) cfg.checkpointer.keep_last_n = 3 cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100) cfg.summary_writer.max_queue = 1000