From c806456f42343e2121b04d4f510429f2ce89edbd Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Tue, 19 Nov 2024 15:27:18 -0800
Subject: [PATCH] use cpu platform

---
 axlearn/common/checkpointer_orbax_emergency.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/axlearn/common/checkpointer_orbax_emergency.py b/axlearn/common/checkpointer_orbax_emergency.py
index 72601a98e..7a0abc1c7 100644
--- a/axlearn/common/checkpointer_orbax_emergency.py
+++ b/axlearn/common/checkpointer_orbax_emergency.py
@@ -8,11 +8,11 @@
 import copy
 import functools
 import hashlib
+import multiprocessing as mp
 import os
 import time
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from multiprocessing import Process
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import jax
@@ -418,7 +418,10 @@ def get_consistent_proc_info(
         used as the global coordinator address.
     """
     start_t = time.perf_counter()
-    proc = Process(
+    platform = os.environ.get("JAX_PLATFORMS", "")
+    # Patch platform so the process doesn't waste time initializing accelerators.
+    os.environ["JAX_PLATFORMS"] = "cpu"
+    proc = mp.get_context("spawn").Process(
         target=_init_consistent_proc_ids,
         kwargs=dict(
             local_address=local_address,
@@ -431,6 +434,9 @@ def get_consistent_proc_info(
     proc.start()
     proc.join()
     assert proc.exitcode == 0
+    # Restore previous platform settings.
+    if platform != "":
+        os.environ["JAX_PLATFORMS"] = platform
 
     info = _get_previous_process_info(local_ckpt_dir, unique_str=trainer_dir)
     assert info.inv_proc_id != -1