From b4a00eb8957645fb40a9e9476775078d470c6b66 Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Mon, 11 Nov 2024 12:06:25 -0800
Subject: [PATCH] init process id

---
 axlearn/common/checkpointer_orbax.py      | 201 ++++++++++++++++++++--
 axlearn/common/checkpointer_orbax_test.py | 117 ++++++++++++-
 axlearn/common/launch.py                  |  20 ++-
 3 files changed, 320 insertions(+), 18 deletions(-)

diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py
index 52cccc92a..04e9cec97 100644
--- a/axlearn/common/checkpointer_orbax.py
+++ b/axlearn/common/checkpointer_orbax.py
@@ -9,17 +9,21 @@
 import copy
 import dataclasses
 import functools
+import hashlib
 import os
 import time
 from concurrent import futures
 from concurrent.futures import ThreadPoolExecutor
+from multiprocessing import Process
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import jax
+import jax.lib
 import orbax.checkpoint as ocp
 import orbax.checkpoint.experimental.emergency.checkpoint_manager as oecp
 import tensorflow as tf
 from absl import logging
+from jax._src.distributed import global_state
 from jax._src.mesh import thread_resources
 from jax.experimental.array_serialization import serialization
 
@@ -505,11 +509,181 @@ def _initialize_runtime_to_distributed_ids(timeout: int):
     )
 
 
+_PROCESS_ID_FILE_NAME: str = "process_id.txt"
+
+
+def _get_previous_process_id(local_dir: str, *, unique_str: str) -> int:
+    """Gets the previous process id from the local checkpoint directory, or -1 if not found."""
+    path = os.path.join(local_dir, _get_unique_id(unique_str), _PROCESS_ID_FILE_NAME)
+    if not fs.exists(path):
+        return -1
+
+    with fs.open(path) as f:
+        proc_id = int(f.read())
+        return proc_id
+
+
+def _dump_process_id(local_dir: str, *, unique_str: str, process_index: int):
+    """Dumps the process id to the local checkpoint directory."""
+    local_dir = os.path.join(local_dir, _get_unique_id(unique_str))
+    fs.makedirs(local_dir)
+    process_id_file = os.path.join(local_dir, _PROCESS_ID_FILE_NAME)
+    with fs.open(process_id_file, "w") as f:
+        f.write(str(process_index))
+
+
+def _get_unique_id(unique_str: str) -> str:
+    return hashlib.sha256(unique_str.encode(), usedforsecurity=False).hexdigest()
+
+
+def _init_consistent_proc_ids(
+    *,
+    distributed_coordinator: Optional[str] = None,
+    num_processes: Optional[int] = None,
+    process_id: Optional[int] = None,
+    initialization_timeout: Optional[int] = None,
+    trainer_dir: str,
+    local_ckpt_dir: str,
+):
+    """Reads the local process id file and assigns globally consistent process ids through rank 0.
+
+    During failover, healthy nodes will read their locally stored process id file, but failed
+    nodes will have lost their process ids. To assign ids that are free in the global id range
+    (i.e. 0 to num_processes - 1), we let each node report its process id (-1 if missing) to
+    rank 0, and rank 0 will figure out suitable ids to assign to each failed node. We reuse
+    JAX's distributed client to avoid writing our own coordinator.
+    """
+    jax.distributed.initialize(
+        coordinator_address=distributed_coordinator,
+        num_processes=num_processes,
+        process_id=process_id,
+        initialization_timeout=initialization_timeout,
+    )
+    timeout_in_ms = 300 * 1000
+    client: jax.lib.xla_extension.DistributedRuntimeClient = global_state.client
+    prev_process_id = _get_previous_process_id(local_ckpt_dir, unique_str=trainer_dir)
+    prefix = "axlearn/id_reassign"
+    # Local key just needs to be unique for each process.
+ local_set_key = f"{prefix}/{jax.process_index()}" + # For TPU backend, only GKE is supported for now. + if jax.default_backend() == "tpu": + # For TPUs, we have the additional requirement that process ids in slice id X must be in + # range [X * num_processes_per_slice, (X + 1) * num_processes_per_slice). Therefore, we + # first identify the healthy slices' ids and then figure out the slice ids to assign to + # failed slices. Each process in the failed slice will then get id `new_slice_id * + # num_proc_per_slice + worker_id`. + client.key_value_set( + local_set_key, + f"{os.environ['MEGASCALE_SLICE_ID']}|{prev_process_id}|{os.environ['TPU_WORKER_ID']}", + ) + client.wait_at_barrier("axlearn/id-reassign-gather-id", timeout_in_ms=timeout_in_ms) + if jax.process_index() == 0: + ids = client.key_value_dir_get(prefix) + parsed_ids: list[tuple[int, int, int]] = [] + for _, v in ids: + data = v.split("|") + assert len(data) == 3 + parsed_ids.append(tuple(int(x) for x in data)) + + num_proc_per_slice = len(str(os.environ.get("TPU_WORKER_HOSTNAMES", None)).split(",")) + failed_slices_new_ids = {} + for slice_id, prev_proc_id, _ in parsed_ids: + if prev_proc_id == -1: + failed_slices_new_ids[slice_id] = -1 + + already_assigned_slice_ids = set() + for slice_id, prev_proc_id, _ in parsed_ids: + if slice_id not in failed_slices_new_ids: + already_assigned_slice_ids.add(prev_proc_id // num_proc_per_slice) + + to_be_assigned_slice_ids = ( + set(range(int(os.environ["MEGASCALE_NUM_SLICES"]))) - already_assigned_slice_ids + ) + assert len(to_be_assigned_slice_ids) == len(failed_slices_new_ids) + for k, new_id in zip(failed_slices_new_ids.keys(), to_be_assigned_slice_ids): + failed_slices_new_ids[k] = new_id + + for (k, _), (slice_id, prev_proc_id, worker_id) in zip(ids, parsed_ids): + if (new_slice_id := failed_slices_new_ids.get(slice_id)) is not None: + client.key_value_set( + k + "/get", str(new_slice_id * num_proc_per_slice + worker_id) + ) + else: + client.key_value_set(k + "/get", str(prev_proc_id)) + elif jax.default_backend() == "gpu": + # For GPU backend, failed nodes are assigned with ids that are missing in the global id + # range with arbitrary order. + client.key_value_set(local_set_key, str(prev_process_id)) + client.wait_at_barrier("axlearn/id-reassign-gather-id", timeout_in_ms=timeout_in_ms) + if jax.process_index() == 0: + ids = client.key_value_dir_get(prefix) + to_be_assigned_proc_ids = list( + set(range(num_processes)) - set(int(value) for _, value in ids if int(value) != -1) + ) + counter = 0 + for k, value in ids: + if int(value) == -1: + client.key_value_set(k + "/get", str(to_be_assigned_proc_ids[counter])) + counter += 1 + else: + client.key_value_set(k + "/get", value) + assert counter == len(to_be_assigned_proc_ids) + else: + raise RuntimeError(f"Unsupported backend {jax.default_backend()}") + + _dump_process_id( + local_ckpt_dir, + unique_str=trainer_dir, + process_index=int( + client.blocking_key_value_get(local_set_key + "/get", timeout_in_ms=timeout_in_ms) + ), + ) + # Block to avoid coordinator exiting too early. 
+ client.wait_at_barrier("axlearn/id-reassign-finalize", timeout_in_ms=timeout_in_ms) + jax.distributed.shutdown() + + +def get_consistent_proc_id( + *, + distributed_coordinator: Optional[str] = None, + num_processes: Optional[int] = None, + process_id: Optional[int] = None, + initialization_timeout: Optional[int] = None, + trainer_dir: str, + local_ckpt_dir: str, +) -> int: + """Returns process id so that process id <-> node mapping stays the same for health nodes. + + This is required to preserve shard order for in-memory checkpoint recovery. For GPU training, + all healthy nodes will have their process id unchanged. For TPU, all nodes in the healthy + slices will have their process id unchanged. See docstring of `_init_consistent_proc_ids` for + implementation details. + """ + proc = Process( + target=_init_consistent_proc_ids, + kwargs=dict( + distributed_coordinator=distributed_coordinator, + num_processes=num_processes, + process_id=process_id, + initialization_timeout=initialization_timeout, + trainer_dir=trainer_dir, + local_ckpt_dir=local_ckpt_dir, + ), + ) + proc.start() + proc.join() + assert proc.exitcode == 0 + + proc_id = _get_previous_process_id(local_ckpt_dir, unique_str=trainer_dir) + assert proc_id != -1 + return proc_id + + class OrbaxEmergencyCheckpointer(BaseCheckpointer): """Checkpointer implementation that uses Orbax emergency checkpoint. This checkpointer is intended for multi-slice training that uses data-parallelism across - slices. Orbax emergency checkpoint works by exploiting the following properties + slices. Orbax emergency checkpoint works by exploiting the following properties: 1. Tensors are replicated across data-parallel replicas. 2. When a slice fails in a multi-slice training and failover is started, only nodes corresponding to the non-healthy slice may be restarted. Healthy nodes from healthy slices @@ -560,9 +734,12 @@ class Config(BaseCheckpointer.Config): keep_every_n_steps: If > 0, keeps at least one persistent checkpoint every N steps. local_keep_last_n: Keep this many past ckpts in local storage (e.g. node memory). This should almost always set to 1 to avoid OOM. - local_dir: Ckpt path for local storage. The content in this path must persist across - pod restarts unless the restart is caused by node failure. `local_dir` must be the - same for all processes or processes may hang. + local_dir: Ckpt base path for local storage. The content in this path must persist + across pod restarts unless the restart is caused by node failure. `local_dir` must + be the same for all processes or processes may hang. + unqiue_str: A string that's unique for the current run. Typically, this is set to + trainer_dir. Local checkpoint will be stored in local_dir/sha256(unique_str). + During init, all other folders in local_dir will be removed. save_policy: Save policy for persistent checkpoints. local_save_policy: Save policy for local checkpoints. This should be more frequent than `save_policy`. 
             Note that the data iterator will be saved with either `save_policy` or
@@ -580,6 +757,7 @@ class Config(BaseCheckpointer.Config):
             every_n_steps_policy
         ).set(n=10)
         local_dir: str = "/host-tmp/checkpoints"
+        unique_str: Required[str] = REQUIRED
         non_tensor_async_timeout_secs: int = 300
         async_timeout_secs: int = 3600
         replica_axis_index: Required[int] = REQUIRED
@@ -624,12 +802,15 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
         if jax.process_index() == 0:
             fs.makedirs(os.path.join(cfg.dir, self._NON_TENSORS_PREFIX))
             fs.makedirs(os.path.join(cfg.dir, self._TENSORS_PREFIX))
-        fs.makedirs(cfg.local_dir)
-        ocp.multihost.sync_global_processes(
-            "axlearn-persistent-dir-create", timeout=cfg.non_tensor_async_timeout_secs
-        )
+        # Clean up local checkpoints from different runs.
+        unique_id = _get_unique_id(cfg.unique_str)
+        for fd in fs.listdir(cfg.local_dir):
+            if not fd.startswith(".") and fd != unique_id:
+                fs.rmtree(os.path.join(cfg.local_dir, fd))
+        self._local_dir = os.path.join(cfg.local_dir, unique_id)
+        fs.makedirs(self._local_dir)
         # Orbax emergency ckpt requires this function to be called prior to checkpointer
-        # operations.
+        # operations. This function also serves as a barrier.
         _initialize_runtime_to_distributed_ids(cfg.non_tensor_async_timeout_secs)
         ckpt_cfg: Checkpointer.Config = Checkpointer.default_config()
         # TODO(hanzhi-zhou): this `keep_last_n` may not be what users expect since non-tensor
@@ -695,7 +876,7 @@ def _orbax_save_fn(
         # For meaning of these options, refer to
         # https://github.com/google/orbax/blob/95be2c021bc8cbf4badd83a053ff57b7a9f9b314/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py#L277
         self._tensor_manager = oecp.CheckpointManager(
-            cfg.local_dir,
+            self._local_dir,
             persistent_directory=os.path.join(cfg.dir, self._TENSORS_PREFIX),
             global_mesh=thread_resources.env.physical_mesh,
             abstract_state=self._get_abstract_state(state_with_tensors),
diff --git a/axlearn/common/checkpointer_orbax_test.py b/axlearn/common/checkpointer_orbax_test.py
index d5338a444..70cb58b72 100644
--- a/axlearn/common/checkpointer_orbax_test.py
+++ b/axlearn/common/checkpointer_orbax_test.py
@@ -13,7 +13,7 @@
 import socket
 import tempfile
 from contextlib import ExitStack, closing
-from typing import Sequence
+from typing import Optional, Sequence
 
 import jax
 import numpy as np
@@ -29,8 +29,12 @@
 from axlearn.common.checkpointer_orbax import (
     OrbaxCheckpointer,
     OrbaxEmergencyCheckpointer,
+    _dump_process_id,
+    _get_previous_process_id,
+    _init_consistent_proc_ids,
     config_for_function,
     every_n_steps_policy,
+    get_consistent_proc_id,
     ocp,
 )
 
@@ -88,10 +92,18 @@ def _all_devices_excepting_slice(
     multislice.slice_devices = slice_devices
     checkpoint_manager._all_devices_excepting_slice = _all_devices_excepting_slice
 
+    prev_process_id = get_consistent_proc_id(
+        distributed_coordinator=f"127.0.0.1:{port}",
+        num_processes=4,
+        process_id=process_id,
+        trainer_dir=persist_dir,
+        local_ckpt_dir=local_dir,
+    )
+
     jax.distributed.initialize(
         coordinator_address=f"127.0.0.1:{port}",
         num_processes=4,
-        process_id=process_id,
+        process_id=prev_process_id,
         local_device_ids=[process_id],
     )
 
@@ -100,7 +112,8 @@
     cfg.save_policy = config_for_function(every_n_steps_policy).set(n=25)
     cfg.local_save_policy = config_for_function(every_n_steps_policy).set(n=5)
     # Local checkpoint path suffix must be the same for orbax synchronization to work.
-    cfg.local_dir = os.path.join(local_dir, "checkpoints")
+    cfg.local_dir = local_dir
+    cfg.unique_str = persist_dir
     cfg.dir = persist_dir
     cfg.keep_last_n = 2
     cfg.replica_axis_index = 0
@@ -144,6 +157,37 @@ def _all_devices_excepting_slice(
     jax.distributed.shutdown()
 
 
+def _test_init_proc_id_main(
+    *,
+    distributed_coordinator: Optional[str] = None,
+    num_processes: Optional[int] = None,
+    process_id: Optional[int] = None,
+    trainer_dir: str,
+    local_ckpt_dir: str,
+    proc_per_slice: int,
+    new_idx_map: dict[int, int],
+):
+    # Fake some env vars.
+    os.environ["MEGASCALE_NUM_SLICES"] = str(num_processes // proc_per_slice)
+    os.environ["MEGASCALE_SLICE_ID"] = f"{process_id // proc_per_slice}"
+    os.environ["TPU_WORKER_ID"] = str(process_id % proc_per_slice)
+    os.environ["TPU_WORKER_HOSTNAMES"] = ",".join(["a"] * proc_per_slice)
+
+    if new_idx_map[process_id] != -1:
+        _dump_process_id(
+            local_ckpt_dir, unique_str=trainer_dir, process_index=new_idx_map[process_id]
+        )
+
+    # Force the TPU code path regardless of the actual platform.
+    jax.default_backend = lambda: "tpu"
+    _init_consistent_proc_ids(
+        distributed_coordinator=distributed_coordinator,
+        num_processes=num_processes,
+        process_id=process_id,
+        trainer_dir=trainer_dir,
+        local_ckpt_dir=local_ckpt_dir,
+    )
+
+
 class OrbaxCheckpointerTest(test_utils.TestCase):
     def test_index(self):
         """Tests that index files saved with orbax can be read with `read_index_file`."""
@@ -171,6 +215,64 @@ def test_index(self):
         )
         self.assertEqual(ref_index, test_index["index"])
 
+    # This test can also run on CPU.
+    def test_init_proc_id_tpu(self):
+        free_port = _find_free_port()
+        new_idx_map = {
+            # The first two slices are healthy, but have different slice ids during restart.
+            0: 2,
+            1: 3,
+            2: 6,
+            3: 7,
+            4: -1,  # This failed slice has one node swapped out.
+            5: 1,
+            6: -1,  # This failed slice has two nodes swapped out.
+            7: -1,
+        }
+        with ExitStack() as stack:
+            num_processes = 8
+            local_tempdirs = [
+                stack.enter_context(tempfile.TemporaryDirectory()) for _ in range(num_processes)
+            ]
+            processes = []
+            for i in range(num_processes):
+                proc = mp.Process(
+                    target=_test_init_proc_id_main,
+                    kwargs=dict(
+                        distributed_coordinator=f"127.0.0.1:{free_port}",
+                        num_processes=num_processes,
+                        process_id=i,
+                        trainer_dir="any",
+                        local_ckpt_dir=local_tempdirs[i],
+                        proc_per_slice=2,
+                        new_idx_map=new_idx_map,
+                    ),
+                )
+                proc.start()
+                processes.append(proc)
+
+            for p in processes:
+                p.join()
+                self.assertEqual(p.exitcode, 0)
+
+            new_proc_ids = [
+                _get_previous_process_id(local_dir, unique_str="any")
+                for local_dir in local_tempdirs
+            ]
+            for i in range(4):
+                self.assertEqual(new_proc_ids[i], new_idx_map[i])
+
+            if new_proc_ids[4] == 0:
+                self.assertEqual(new_proc_ids[5], 1)
+                self.assertEqual(new_proc_ids[6], 4)
+                self.assertEqual(new_proc_ids[7], 5)
+            elif new_proc_ids[4] == 4:
+                self.assertEqual(new_proc_ids[5], 5)
+                self.assertEqual(new_proc_ids[6], 0)
+                self.assertEqual(new_proc_ids[7], 1)
+            else:
+                self.fail("new proc id of proc 4 should be either 0 or 4")
+
     # This test requires 4 devices to run. Note: we cannot use skipif(jax.local_device_count() < 4)
     # because it will initialize the backend, causing the jax.distributed.initialize to fail in
     # _test_orbax_main. Using the `spawn` context from multiprocessing results in a different error.
@@ -187,18 +289,18 @@ def test_emergency_ckpt(self):
         # Populate log messages.
         _logger_init()
 
-        def start_processes() -> list[mp.Process]:
+        def start_processes(reverse_process_id: bool = False) -> list[mp.Process]:
             free_port = _find_free_port()
             processes = []
             for i in range(num_processes):
                 p = mp.Process(
                     target=_test_orbax_main,
                     args=(
-                        i,
+                        i if not reverse_process_id else num_processes - i - 1,
                         free_port,
                         persistent_tempdir,
                         local_tempdirs[i],
-                        None if i > 0 else q,
+                        q,
                     ),
                 )
                 processes.append(p)
@@ -214,7 +316,8 @@ def start_processes() -> list[mp.Process]:
         for p in processes:
             p.kill()
 
-        processes = start_processes()
+        # Reverse the process ids to verify that we are able to restore the original mapping.
+        processes = start_processes(reverse_process_id=True)
 
         try:
             for p in processes:
diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py
index 31532be2a..d8b1d35b8 100644
--- a/axlearn/common/launch.py
+++ b/axlearn/common/launch.py
@@ -41,6 +41,7 @@
 
 from absl import flags, logging
 
+from axlearn.common.checkpointer_orbax import get_consistent_proc_id
 from axlearn.common.status_server import StatusHTTPServer
 from axlearn.common.utils import get_data_dir
 from axlearn.common.utils_spmd import setup as setup_spmd
@@ -88,6 +89,12 @@
     "",
     "See the docstring of your `health_check_module`.",
 )
+flags.DEFINE_string(
+    "local_ckpt_dir",
+    "",
+    "If specified, enables local checkpointing and saves checkpoints to this "
+    "directory. See `OrbaxEmergencyCheckpointer` for more details.",
+)
 
 FLAGS = flags.FLAGS
 
@@ -107,10 +114,21 @@ def setup():
         health_check = nullcontext()
 
     with health_check:
+        if FLAGS.local_ckpt_dir:
+            process_id = get_consistent_proc_id(
+                distributed_coordinator=FLAGS.distributed_coordinator,
+                num_processes=FLAGS.num_processes,
+                process_id=FLAGS.process_id,
+                initialization_timeout=FLAGS.initialization_timeout,
+                trainer_dir=FLAGS.trainer_dir,
+                local_ckpt_dir=FLAGS.local_ckpt_dir,
+            )
+        else:
+            process_id = FLAGS.process_id
         setup_spmd(
             distributed_coordinator=FLAGS.distributed_coordinator,
             num_processes=FLAGS.num_processes,
-            process_id=FLAGS.process_id,
+            process_id=process_id,
             jax_backend=FLAGS.jax_backend,
             initialization_timeout=FLAGS.initialization_timeout,
         )
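
The reassignment rule in `_init_consistent_proc_ids` is compact: for the GPU backend, nodes
reporting -1 simply receive whichever ids are missing from range(num_processes); the TPU
backend applies the same set difference at slice granularity, after which each process in a
failed slice derives `new_slice_id * num_proc_per_slice + worker_id`. A minimal standalone
sketch of the GPU-branch rule (illustrative only; `reassign_proc_ids` is a hypothetical
helper and not part of this patch):

    def reassign_proc_ids(reported: list[int], num_processes: int) -> list[int]:
        """Fills each -1 entry with an id unused in range(num_processes).

        Healthy nodes keep their previously stored id; failed nodes (-1) take the
        leftover ids, mirroring the GPU branch of `_init_consistent_proc_ids`.
        """
        taken = {r for r in reported if r != -1}
        free = iter(sorted(set(range(num_processes)) - taken))
        return [r if r != -1 else next(free) for r in reported]

    # Two healthy nodes keep ids 2 and 0; the two failed nodes receive 1 and 3.
    assert reassign_proc_ids([2, -1, 0, -1], num_processes=4) == [2, 1, 0, 3]

In the patch itself this exchange runs through the JAX distributed KV store: each process
publishes its report under `axlearn/id_reassign/<rank>` via `key_value_set`, rank 0 gathers
the reports with `key_value_dir_get` and writes the assignment back under `<key>/get`, and
every process blocks on its own `<key>/get` entry with `blocking_key_value_get`.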