From 438bc60290f7eb65533ad3dac25d3b18e64ae197 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Mon, 27 Jan 2025 13:28:40 -0800
Subject: [PATCH 1/5] [WIP][RFC] Required changes for integration with TorchTitan

Summary: We are not going to land this PR as is; it may be further divided into
several smaller PRs. Usage sketches for the new APIs follow the patch series.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchft/checkpointing.py |  6 ++++--
 torchft/manager.py       | 17 +++++++++++++----
 torchft/optim.py         | 10 +++++++++-
 torchft/process_group.py | 34 +++++++++++++++++++++++++++-------
 4 files changed, 53 insertions(+), 14 deletions(-)

diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py
index 48a5d51..8124099 100644
--- a/torchft/checkpointing.py
+++ b/torchft/checkpointing.py
@@ -134,8 +134,10 @@ def do_GET(self):
                    self.end_headers()

                    state_dict = ckpt_server._state_dict
-
+                    self._logger.warning("Before torch.save ===================.")
                    torch.save(state_dict, self.wfile)
+                    self._logger.warning("After torch.save ===================.")
+
                except Exception as e:
                    logger.exception(
                        f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -172,7 +174,7 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
            data = f.read()

        reader = io.BytesIO(data)
-        return torch.load(reader, weights_only=True)
+        return torch.load(reader, weights_only=False)

    def address(self) -> str:
        """
diff --git a/torchft/manager.py b/torchft/manager.py
index dc5ab30..ae4717f 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -87,8 +87,8 @@ class Manager:
    def __init__(
        self,
        pg: "ProcessGroup",
-        load_state_dict: Callable[[T], None],
-        state_dict: Callable[[], T],
+        load_state_dict: Optional[Callable[[T], None]],
+        state_dict: Optional[Callable[[], T]],
        min_replica_size: int,
        use_async_quorum: bool = True,
        timeout: timedelta = timedelta(seconds=60),
@@ -144,7 +144,6 @@ def __init__(
            transfering checkpoints to recovering replicas
        """
        self._load_state_dict = load_state_dict
-        self._state_dict = state_dict
        self._pending_state_dict: Optional[Dict[str, object]] = None
        self._use_async_quorum = use_async_quorum
        self._timeout = timeout
@@ -226,6 +225,12 @@ def __init__(
        self._participating_rank: Optional[int] = None
        self._participating_world_size: int = 0

+    def set_state_dict_fns(
+        self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
+    ) -> None:
+        self._load_state_dict = load_state_dict
+        self._user_state_dict = state_dict
+
    def shutdown(self, wait: bool = True) -> None:
        """
        Shutdown the manager and checkpoint server.
@@ -533,6 +538,7 @@ def _apply_pending_state_dict(self) -> None: assert self._pending_state_dict is not None, "checkpoint was not staged" self._load_state_dict(self._pending_state_dict["user"]) self._pending_state_dict = None + self._logger.info("Loaded state dict.") def should_commit(self, timeout: Optional[timedelta] = None) -> bool: """ @@ -602,10 +608,13 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: self._batches_committed = state_dict["batches_committed"] def _manager_state_dict(self) -> Dict[str, object]: - return { + self._logger.warning("Before state_dict ===================.") + ret = { "user": self._user_state_dict(), "torchft": self.state_dict(), } + self._logger.warning("After state_dict ===================.") + return ret def state_dict(self) -> Dict[str, int]: """ diff --git a/torchft/optim.py b/torchft/optim.py index ce24823..b26ed1e 100644 --- a/torchft/optim.py +++ b/torchft/optim.py @@ -12,7 +12,7 @@ """ -from typing import TYPE_CHECKING, Optional +from typing import Any, TYPE_CHECKING, Optional from torch.optim import Optimizer @@ -52,3 +52,11 @@ def step(self, closure: Optional[object] = None) -> None: assert closure is None, "optimizers that use closures are not supported" if self.manager.should_commit(): self.optim.step() + + @property + def param_groups(self) -> Any: + return self.optim.param_groups + + @property + def state(self) -> Any: + return self.optim.state diff --git a/torchft/process_group.py b/torchft/process_group.py index d1d2cbe..d073854 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -20,7 +20,7 @@ import queue import threading from datetime import timedelta -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union +from typing import Any, TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union import torch import torch.distributed as dist @@ -861,6 +861,8 @@ def extend_device_mesh( class ManagedDeviceMesh(DeviceMesh): + replicate_pg_singleton: Optional["ManagedProcessGroup"] + def __init__( self, mesh: Optional[DeviceMesh], @@ -889,6 +891,15 @@ def __init__( self._flatten_mesh_list: Tuple[DeviceMesh, ...] 
= tuple() self._thread_id: Optional[int] = None + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state["replicate_pg"] = None + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__dict__.update(state) + self.replicate_pg = self.replicate_pg_singleton + def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh: if isinstance(mesh_dim_names, str): if mesh_dim_names == self.replicate_dim_name: @@ -906,13 +917,14 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh return self.mesh[mesh_dim_names] else: assert isinstance(mesh_dim_names, tuple) - if self.replicate_dim_name in mesh_dim_names: + if self.replicate_dim_name not in mesh_dim_names: assert self.mesh is not None return self.mesh[mesh_dim_names] else: assert self.mesh is not None + mesh_dim_names_wo_replicate = tuple(n for n in mesh_dim_names if n != self.replicate_dim_name) return ManagedDeviceMesh( - self.mesh[mesh_dim_names], + self.mesh[mesh_dim_names_wo_replicate], mesh_dim_names, self.replicate_pg, mesh_dim_names.index(self.replicate_dim_name), @@ -947,14 +959,16 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh": return flatten_mesh def size(self, mesh_dim: Optional[int] = None) -> int: + replicate_pg_size = self.replicate_pg.size() + replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size if mesh_dim is None: if self.mesh is None: - return self.replicate_pg.size() + return replicate_pg_size else: assert self.mesh is not None - return self.mesh.size() * self.replicate_pg.size() + return self.mesh.size() * replicate_pg_size elif mesh_dim == self.replicate_dim: - return self.replicate_pg.size() + return replicate_pg_size else: assert self.mesh is not None return self.mesh.size(self._real_mesh_dim(mesh_dim)) @@ -1004,7 +1018,11 @@ def get_coordinate(self) -> Optional[List[int]]: dimensions of the mesh. If this rank is not part of the mesh, return None. """ assert self.mesh is not None - return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None + ret = self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None + if ret: + ret = ret.copy() + ret.insert(get_rank(self.replicate_pg), self.replicate_dim) + return ret def get_all_groups(self) -> List[BaseProcessGroup]: raise NotImplementedError @@ -1079,6 +1097,8 @@ def ft_init_device_mesh( # the same backend has been registered. 
replicate_pg.register(mesh_dim_names[replicate_dim]) + ManagedDeviceMesh.replicate_pg_singleton = replicate_pg + return ManagedDeviceMesh( mesh=mesh, mesh_dim_names=mesh_dim_names, From aab923919a4fc70bebd4e4170e005004782761b2 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 28 Jan 2025 16:09:40 -0800 Subject: [PATCH 2/5] Update log --- torchft/checkpointing.py | 4 ++-- torchft/manager.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index 8124099..4704661 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -134,9 +134,9 @@ def do_GET(self): self.end_headers() state_dict = ckpt_server._state_dict - self._logger.warning("Before torch.save ===================.") + logger.warning("Before torch.save ===================.") torch.save(state_dict, self.wfile) - self._logger.warning("After torch.save ===================.") + logger.warning("After torch.save ===================.") except Exception as e: logger.exception( diff --git a/torchft/manager.py b/torchft/manager.py index ae4717f..681c851 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -608,12 +608,12 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: self._batches_committed = state_dict["batches_committed"] def _manager_state_dict(self) -> Dict[str, object]: - self._logger.warning("Before state_dict ===================.") + self._logger.warn("Before state_dict ===================.") ret = { "user": self._user_state_dict(), "torchft": self.state_dict(), } - self._logger.warning("After state_dict ===================.") + self._logger.warn("After state_dict ===================.") return ret def state_dict(self) -> Dict[str, int]: From 5d88b9ecb7fc0697387070ce287de06e1fdd3d8a Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 29 Jan 2025 10:21:21 -0800 Subject: [PATCH 3/5] Remove logs --- torchft/checkpointing.py | 2 -- torchft/manager.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index 4704661..6b2d3fd 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -134,9 +134,7 @@ def do_GET(self): self.end_headers() state_dict = ckpt_server._state_dict - logger.warning("Before torch.save ===================.") torch.save(state_dict, self.wfile) - logger.warning("After torch.save ===================.") except Exception as e: logger.exception( diff --git a/torchft/manager.py b/torchft/manager.py index 681c851..2206130 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -608,12 +608,10 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: self._batches_committed = state_dict["batches_committed"] def _manager_state_dict(self) -> Dict[str, object]: - self._logger.warn("Before state_dict ===================.") ret = { "user": self._user_state_dict(), "torchft": self.state_dict(), } - self._logger.warn("After state_dict ===================.") return ret def state_dict(self) -> Dict[str, int]: From 91dd25d48be7ff317c1f2703cd036a6a69bf7158 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 29 Jan 2025 11:10:01 -0800 Subject: [PATCH 4/5] Minor modifications --- torchft/checkpointing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index 6b2d3fd..28a0a32 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -172,6 +172,8 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T: data = f.read() reader = io.BytesIO(data) + # We have 
to set weights_only to False as there are some non-tensor
+        # states like lr_scheduler.
         return torch.load(reader, weights_only=False)

    def address(self) -> str:

From 9daf023e5c11ff54902758fcd07fb781cd22c326 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Wed, 29 Jan 2025 11:10:45 -0800
Subject: [PATCH 5/5] Minor changes

---
 torchft/manager.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/torchft/manager.py b/torchft/manager.py
index 2206130..e31d35c 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -144,6 +144,7 @@ def __init__(
            transfering checkpoints to recovering replicas
        """
        self._load_state_dict = load_state_dict
+        self._user_state_dict = state_dict
        self._pending_state_dict: Optional[Dict[str, object]] = None
        self._use_async_quorum = use_async_quorum
        self._timeout = timeout
@@ -158,8 +159,6 @@ def __init__(
        world_size = world_size or int(os.environ["WORLD_SIZE"])
        self._min_replica_size = min_replica_size

-        self._user_state_dict = state_dict
-
        if checkpoint_transport is None:
            checkpoint_transport = CheckpointServer[Dict[str, T]](
                timeout=timeout,
@@ -608,11 +607,10 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
        self._batches_committed = state_dict["batches_committed"]

    def _manager_state_dict(self) -> Dict[str, object]:
-        ret = {
+        return {
            "user": self._user_state_dict(),
            "torchft": self.state_dict(),
        }
-        return ret

    def state_dict(self) -> Dict[str, int]:
        """
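Usage sketch 1: deferred state-dict registration and the optimizer pass-throughs.
The patches make the Manager's load_state_dict/state_dict callbacks Optional, add
Manager.set_state_dict_fns(), and expose param_groups/state on OptimizerWrapper. The
sketch below shows the intended call order for a TorchTitan-style trainer that builds
the model and optimizer only after the Manager exists. It is illustrative, not part of
the patches: the Gloo process group, min_replica_size=1, the toy model, and the lambda
callbacks are assumptions, and the usual torchft environment configuration (replica id,
lighthouse/store addresses, RANK/WORLD_SIZE) is taken as given.

import torch

from torchft.manager import Manager
from torchft.optim import OptimizerWrapper
from torchft.process_group import ProcessGroupGloo

pg = ProcessGroupGloo()

# With this series the state-dict callbacks may be omitted at construction
# time because the trainer has not built the model/optimizer yet ...
manager = Manager(
    pg=pg,
    load_state_dict=None,
    state_dict=None,
    min_replica_size=1,
)

model = torch.nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters())

# ... and registered later through the new set_state_dict_fns().
manager.set_state_dict_fns(
    load_state_dict=lambda sd: model.load_state_dict(sd),
    state_dict=lambda: model.state_dict(),
)

# The new param_groups/state pass-throughs let LR schedulers and checkpoint
# code treat the wrapper like a plain torch.optim.Optimizer.
optimizer = OptimizerWrapper(manager, optim)
for group in optimizer.param_groups:
    group["lr"] = 3e-4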
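Usage sketch 2: the ManagedDeviceMesh changes. __getstate__/__setstate__ together with
the class-level replicate_pg_singleton drop the live replicate process group when the
mesh is pickled and re-attach it on load, which matters now that the checkpoint server
torch.save()s user state dicts that can reference the mesh (for example through
DTensors); the __getitem__/size()/get_coordinate() updates fold the replicate dimension
into mesh queries. The construction sketch below is a guess at typical usage: the
device_type/mesh_shape keyword names are not shown in the hunks and are assumed to
mirror torch.distributed.device_mesh.init_device_mesh, and the dimension names and the
(2, 4) shape are placeholders that require a matching world size.

from torchft.manager import Manager
from torchft.process_group import ft_init_device_mesh


def build_mesh(manager: Manager):
    # Hypothetical HSDP-style mesh: 2 replica groups x 4 shards.
    mesh = ft_init_device_mesh(
        device_type="cuda",
        mesh_shape=(2, 4),  # (replicate, shard)
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,
        manager=manager,
    )
    # Slicing off the replicate dim returns the inner DeviceMesh; any slice
    # that still contains "dp_replicate" stays a ManagedDeviceMesh, and
    # size()/get_coordinate() fold the replica group into their results.
    shard_mesh = mesh["dp_shard"]
    return mesh, shard_mesh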