Skip to content

Commit

Permalink
Fix import error in manager.py and switch to sync quorum mode
Browse files Browse the repository at this point in the history
  • Loading branch information
mreso committed Jan 23, 2025
1 parent 7cc11f2 commit 50b846b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
1 change: 1 addition & 0 deletions torchft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp, TCPStore

from torchft.checkpointing import CheckpointServer
Expand Down
3 changes: 1 addition & 2 deletions train_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def state_dict():
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=f"train_fsdp_{REPLICA_GROUP_ID}",
use_async_quorum=False,
)

mesh = hsdp_device_mesh(NUM_REPLICA_GROUPS, NUM_REPLICAS, "cuda" if torch.cuda.is_available() else "cpu", manager=manager)
Expand All @@ -136,8 +137,6 @@ def state_dict():

optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5))

optimizer.zero_grad()

while manager.current_step() < 500:
model.train()
for batch in tqdm(train_dataloader):
Expand Down

0 comments on commit 50b846b

Please sign in to comment.