Commit 86308df
Temp change to use Orbax checkpointer for Fuji
jiya-zhang committed Nov 27, 2024
1 parent 7b1f012 commit 86308df
Showing 1 changed file with 3 additions and 1 deletion.
axlearn/experiments/text/gpt/common.py (4 changes: 3 additions & 1 deletion)
@@ -39,6 +39,7 @@
     set_double_shard_weights_config,
 )
 from axlearn.common.checkpointer import every_n_steps_and_last_policy
+from axlearn.common.checkpointer_orbax import OrbaxCheckpointer
 from axlearn.common.config import (
     ConfigOr,
     FunctionConfigBase,
@@ -700,11 +701,12 @@ def config_fn() -> InstantiableConfig:
             )
             cfg.evalers[name] = evaler_cfg
         # Summaries and checkpoints.
+        cfg.checkpointer = OrbaxCheckpointer.default_config()
         cfg.checkpointer.save_policy = config_for_function(every_n_steps_and_last_policy).set(
             n=save_every_n_steps or min(eval_every_n_steps, 5_000),
             max_step=max_step,
         )
-        cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
+        # cfg.checkpointer.keep_every_n_steps = min(max_step, keep_every_n_steps)
         cfg.checkpointer.keep_last_n = 3
         cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
         cfg.summary_writer.max_queue = 1000
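For reference, the net effect of the change above is roughly the following checkpointer setup. This is a minimal sketch assuming the AXLearn config API as it appears in the diff; build_checkpointer_cfg is a hypothetical helper, and the step values are illustrative rather than the values computed inside config_fn from save_every_n_steps, eval_every_n_steps, and max_step.

# Minimal sketch of the post-commit checkpointer config (hypothetical helper;
# names and calls are taken from the diff above, values are illustrative).
from axlearn.common.checkpointer import every_n_steps_and_last_policy
from axlearn.common.checkpointer_orbax import OrbaxCheckpointer
from axlearn.common.config import config_for_function

def build_checkpointer_cfg(max_step: int = 100_000):
    # Swap in the Orbax-backed checkpointer in place of the default one.
    ckpt_cfg = OrbaxCheckpointer.default_config()
    # Save every 5k steps, plus a final save at max_step.
    ckpt_cfg.save_policy = config_for_function(every_n_steps_and_last_policy).set(
        n=5_000,
        max_step=max_step,
    )
    # keep_every_n_steps is commented out in this temp change, so no milestone
    # checkpoints are pinned; only the three most recent are retained.
    ckpt_cfg.keep_last_n = 3
    return ckpt_cfg

Disabling keep_every_n_steps while leaving keep_last_n = 3 means older checkpoints are not preserved at fixed intervals and only the three newest survive garbage collection, which is consistent with this being a temporary change.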
