Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][RFC] Required changes for integration with TorchTitan #82

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions torchft/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def do_GET(self):
self.end_headers()

state_dict = ckpt_server._state_dict

torch.save(state_dict, self.wfile)

except Exception as e:
logger.exception(
f"Exception in checkpoint server when handling {self.path=}: {e}",
Expand Down Expand Up @@ -172,7 +172,9 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
data = f.read()

reader = io.BytesIO(data)
return torch.load(reader, weights_only=True)
# We have to set weights_only to False as there are some non-tensor
# states like lr_scheduler.
return torch.load(reader, weights_only=False)

def address(self) -> str:
"""
Expand Down
15 changes: 10 additions & 5 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ class Manager:
def __init__(
self,
pg: "ProcessGroup",
load_state_dict: Callable[[T], None],
state_dict: Callable[[], T],
load_state_dict: Optional[Callable[[T], None]],
state_dict: Optional[Callable[[], T]],
min_replica_size: int,
use_async_quorum: bool = True,
timeout: timedelta = timedelta(seconds=60),
Expand Down Expand Up @@ -144,7 +144,7 @@ def __init__(
transfering checkpoints to recovering replicas
"""
self._load_state_dict = load_state_dict
self._state_dict = state_dict
self._user_state_dict = state_dict
self._pending_state_dict: Optional[Dict[str, object]] = None
self._use_async_quorum = use_async_quorum
self._timeout = timeout
Expand All @@ -159,8 +159,6 @@ def __init__(
world_size = world_size or int(os.environ["WORLD_SIZE"])
self._min_replica_size = min_replica_size

self._user_state_dict = state_dict

if checkpoint_transport is None:
checkpoint_transport = CheckpointServer[Dict[str, T]](
timeout=timeout,
Expand Down Expand Up @@ -226,6 +224,12 @@ def __init__(
self._participating_rank: Optional[int] = None
self._participating_world_size: int = 0

def set_state_dict_fns(
    self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
) -> None:
    """Register the user state-dict callbacks after construction.

    Lets a Manager that was created with ``load_state_dict=None`` /
    ``state_dict=None`` receive its callbacks later, once the model and
    optimizer actually exist (the TorchTitan integration pattern).

    Args:
        load_state_dict: callback that applies a previously captured user
            state dict to the local model/optimizer.
        state_dict: callback that captures the current user state dict.
    """
    # NOTE: original annotation was ``Callable[T, None]`` which is invalid
    # typing syntax (argument types must be a list) and raises at runtime.
    self._load_state_dict = load_state_dict
    self._user_state_dict = state_dict

def shutdown(self, wait: bool = True) -> None:
"""
Shutdown the manager and checkpoint server.
Expand Down Expand Up @@ -533,6 +537,7 @@ def _apply_pending_state_dict(self) -> None:
assert self._pending_state_dict is not None, "checkpoint was not staged"
self._load_state_dict(self._pending_state_dict["user"])
self._pending_state_dict = None
self._logger.info("Loaded state dict.")

def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
"""
Expand Down
10 changes: 9 additions & 1 deletion torchft/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

"""

from typing import TYPE_CHECKING, Optional
from typing import Any, TYPE_CHECKING, Optional

from torch.optim import Optimizer

Expand Down Expand Up @@ -52,3 +52,11 @@ def step(self, closure: Optional[object] = None) -> None:
assert closure is None, "optimizers that use closures are not supported"
if self.manager.should_commit():
self.optim.step()

@property
def param_groups(self) -> Any:
    """Forward to the wrapped optimizer's ``param_groups``.

    Exposed so code that expects a plain ``torch.optim.Optimizer``
    (e.g. LR schedulers) can operate on this wrapper transparently.
    """
    return self.optim.param_groups

@property
def state(self) -> Any:
    """Forward to the wrapped optimizer's ``state``.

    Needed so optimizer-state checkpointing sees the real per-parameter
    state rather than this wrapper's own attributes.
    """
    return self.optim.state
34 changes: 27 additions & 7 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import queue
import threading
from datetime import timedelta
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
from typing import Any, TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -861,6 +861,8 @@ def extend_device_mesh(


class ManagedDeviceMesh(DeviceMesh):
replicate_pg_singleton: Optional["ManagedProcessGroup"]

def __init__(
self,
mesh: Optional[DeviceMesh],
Expand Down Expand Up @@ -889,6 +891,15 @@ def __init__(
self._flatten_mesh_list: Tuple[DeviceMesh, ...] = tuple()
self._thread_id: Optional[int] = None

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state["replicate_pg"] = None
return state

def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore pickled state.

    ``replicate_pg`` was nulled out by ``__getstate__`` (process groups
    cannot be pickled), so after restoring the attribute dict we reattach
    the class-level singleton set by ``ft_init_device_mesh``.
    """
    self.__dict__.update(state)
    # NOTE(review): assumes ft_init_device_mesh ran in this process so the
    # singleton is populated; otherwise replicate_pg is restored as None.
    self.replicate_pg = self.replicate_pg_singleton

def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh:
if isinstance(mesh_dim_names, str):
if mesh_dim_names == self.replicate_dim_name:
Expand All @@ -906,13 +917,14 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
return self.mesh[mesh_dim_names]
else:
assert isinstance(mesh_dim_names, tuple)
if self.replicate_dim_name in mesh_dim_names:
if self.replicate_dim_name not in mesh_dim_names:
assert self.mesh is not None
return self.mesh[mesh_dim_names]
else:
assert self.mesh is not None
mesh_dim_names_wo_replicate = tuple(n for n in mesh_dim_names if n != self.replicate_dim_name)
return ManagedDeviceMesh(
self.mesh[mesh_dim_names],
self.mesh[mesh_dim_names_wo_replicate],
mesh_dim_names,
self.replicate_pg,
mesh_dim_names.index(self.replicate_dim_name),
Expand Down Expand Up @@ -947,14 +959,16 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
return flatten_mesh

def size(self, mesh_dim: Optional[int] = None) -> int:
replicate_pg_size = self.replicate_pg.size()
replicate_pg_size = 1 if replicate_pg_size == 0 else replicate_pg_size
if mesh_dim is None:
if self.mesh is None:
return self.replicate_pg.size()
return replicate_pg_size
else:
assert self.mesh is not None
return self.mesh.size() * self.replicate_pg.size()
return self.mesh.size() * replicate_pg_size
elif mesh_dim == self.replicate_dim:
return self.replicate_pg.size()
return replicate_pg_size
else:
assert self.mesh is not None
return self.mesh.size(self._real_mesh_dim(mesh_dim))
Expand Down Expand Up @@ -1004,7 +1018,11 @@ def get_coordinate(self) -> Optional[List[int]]:
dimensions of the mesh. If this rank is not part of the mesh, return None.
"""
assert self.mesh is not None
return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
ret = self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
if ret:
ret = ret.copy()
ret.insert(get_rank(self.replicate_pg), self.replicate_dim)
return ret

def get_all_groups(self) -> List[BaseProcessGroup]:
raise NotImplementedError
Expand Down Expand Up @@ -1079,6 +1097,8 @@ def ft_init_device_mesh(
# the same backend has been registered.
replicate_pg.register(mesh_dim_names[replicate_dim])

ManagedDeviceMesh.replicate_pg_singleton = replicate_pg

return ManagedDeviceMesh(
mesh=mesh,
mesh_dim_names=mesh_dim_names,
Expand Down
Loading