Skip to content

Commit

Permalink
Incorporate Orbax emergency replicator checkpoint manager
Browse files Browse the repository at this point in the history
  • Loading branch information
xuefgu committed Jan 8, 2025
1 parent c57363c commit 74c2685
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
29 changes: 26 additions & 3 deletions MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
import orbax.checkpoint as ocp
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager

# pylint: disable=too-many-positional-arguments

Expand Down Expand Up @@ -91,15 +92,15 @@ def create_orbax_emergency_checkpoint_manager(
persistent_save_interval_steps: int,
orbax_logger: Optional[abstract_logger.AbstractLogger] = None,
):
"""Returns an emergency checkpoint."""
"""Returns an emergency checkpoint manager."""
flags.FLAGS.experimental_orbax_use_distributed_process_id = True
max_logging.log("Creating emergency checkpoint manager...")

options = emergency_checkpoint_manager.CheckpointManagerOptions(
local=LocalCheckpointOptions(save_interval_steps=local_save_interval_steps),
persistent=PersistentCheckpointOptions(save_interval_steps=persistent_save_interval_steps),
)
emergency_mngr = emergency_checkpoint_manager.CheckpointManager(
manager = emergency_checkpoint_manager.CheckpointManager(
local_checkpoint_dir,
epath.Path(persistent_checkpoint_dir),
global_mesh=global_mesh,
Expand All @@ -109,7 +110,29 @@ def create_orbax_emergency_checkpoint_manager(
)

max_logging.log("Emergency checkpoint manager created!")
return emergency_mngr
return manager


def create_orbax_emergency_replicator_checkpoint_manager(
    local_checkpoint_dir: str,
    save_interval_steps: int,
    global_mesh: jax.sharding.Mesh,
):
  """Builds and returns an Orbax emergency replicator checkpoint manager.

  Sets the `experimental_orbax_use_distributed_process_id` flag to True before
  the manager is constructed, and logs a message before and after construction.

  Args:
    local_checkpoint_dir: directory handed to `ReplicatorCheckpointManager` as
      its first positional argument.
    save_interval_steps: forwarded to `ReplicatorCheckpointManagerOptions` as
      `save_interval_steps`.
    global_mesh: mesh forwarded to the manager as `global_mesh`.

  Returns:
    The newly constructed `ReplicatorCheckpointManager` instance.
  """
  # The flag is flipped first, ahead of any Orbax object construction.
  flags.FLAGS.experimental_orbax_use_distributed_process_id = True
  max_logging.log("Creating emergency replicator checkpoint manager...")

  replicator = emergency_replicator_checkpoint_manager
  replicator_options = replicator.ReplicatorCheckpointManagerOptions(save_interval_steps=save_interval_steps)
  replicator_manager = replicator.ReplicatorCheckpointManager(
      local_checkpoint_dir,
      replicator_options,
      global_mesh=global_mesh,
  )

  max_logging.log("Emergency replicator checkpoint manager created!")
  return replicator_manager


def print_save_message(step, async_checkpointing):
Expand Down
27 changes: 17 additions & 10 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,16 +648,23 @@ def setup_mesh_and_model(config):
tx = optimizers.get_optimizer(config, learning_rate_schedule)
logger = checkpointing.setup_checkpoint_logger(config)
if config.enable_emergency_checkpoint:
abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager(
config.local_checkpoint_directory,
config.checkpoint_dir,
mesh,
abstract_state,
config.local_checkpoint_period,
config.checkpoint_period,
logger,
)
if config.use_replicator_service:
checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager(
config.local_checkpoint_directory,
config.local_checkpoint_period,
mesh,
)
else:
abstract_state, _, _ = max_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True)
checkpoint_manager = checkpointing.create_orbax_emergency_checkpoint_manager(
config.local_checkpoint_directory,
config.checkpoint_dir,
mesh,
abstract_state,
config.local_checkpoint_period,
config.checkpoint_period,
logger,
)
else:
# TODO(b/368121306): Remove this once zarr3 support is plumbed on the backend
use_ocdbt = config.checkpoint_storage_use_ocdbt
Expand Down

0 comments on commit 74c2685

Please sign in to comment.