Update on "[WIP][RFC] TorchFT integration"
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment: groups hang when a new group joins.

**Issue 1:**
~~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~

Fixed with: pytorch/torchft#83

**Issue 2:**
~~Group 0 and group 1 pass `should_commit`, but group 0 incorrectly decides it needs healing, and the healing process causes another hang.~~

Fixed with: pytorch/torchft#83
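
For context on where Issues 1 and 2 show up, here is a minimal sketch of the per-step TorchFT flow this integration follows. The `Manager` method names (`start_quorum`, `should_commit`) and the exact call sites are assumptions based on the TorchFT README, not the actual torchtitan code:

```python
# Hypothetical per-step flow, for illustration only -- the method names and
# call order are assumptions about the torchft Manager API, not torchtitan's code.
def train_step(ft_manager, model, optimizer, loss_fn, batch, labels):
    ft_manager.start_quorum()       # (assumed) join/confirm the quorum; a recovering
                                    # replica group fetches the pending state_dict
                                    # from a healthy group here ("healing")
    optimizer.zero_grad()
    loss = loss_fn(model(batch), labels)
    loss.backward()                 # gradients are averaged across replica groups
                                    # through the manager-managed process group
    if ft_manager.should_commit():  # Issue 1: both groups hung here right after
        optimizer.step()            # group 1 applied group 0's pending state_dict;
                                    # Issue 2: group 0 then wrongly entered healing
```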

**Issue 3:**
A byproduct of issues 1 and 2: group 1 keeps printing
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

***How to reproduce?***
Follow the steps in `Reproduce steps` to run 2 groups, then kill either group after both have started training. Remember to apply pytorch/torchft#83.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.

***How to reproduce?***
Follow the `Reproduce steps` to run 2 groups, then add a third group by modifying the command (new `REPLICA_GROUP_ID`, manager port, and GPUs).

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the TorchFT lighthouse (an example command is sketched after these steps)
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```
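
For step 2, the lighthouse can be started with something like the command below. The binary name and flags are taken from the TorchFT README of this era and should be treated as assumptions rather than part of this PR:
```
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
```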



[ghstack-poisoned]
fegin committed Jan 31, 2025
2 parents 4b2edcb + f7ae033 commit 15c4b33
Showing 1 changed file with 2 additions and 2 deletions.
train.py
```diff
@@ -46,6 +46,8 @@ def main(job_config: JobConfig):
     # take control of garbage collection to avoid stragglers
     gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
 
+    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
+    device_module.set_device(device)
     ft_manager = init_ft_manager(job_config)
 
     # init distributed
@@ -60,8 +62,6 @@
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
         ft_manager=ft_manager,
     )
-    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
-    device_module.set_device(device)
     utils.init_distributed(job_config)
     # initialize device memory monitor and get peak flops for MFU calculation
     device_memory_monitor = build_device_memory_monitor()
```
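
The net effect of the diff is that the device is selected before the TorchFT manager is created. A condensed view of the new ordering is below; the comment on why the move matters is an assumption, since the commit message does not state a rationale:

```python
# Condensed view of the reordered code in train.py after this commit.
device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
device_module.set_device(device)          # pick the local accelerator first, so that
                                          # anything the FT manager sets up ends up on
                                          # the right device (assumed rationale)
ft_manager = init_ft_manager(job_config)  # TorchFT manager is created only after
                                          # the device has been selected
```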
