Add _test_pg helper (#45)
* Add _test_pg helper

* update CONTRIBUTING.md
H-Huang authored Dec 18, 2024
1 parent 6d6e9a4 commit 49d2aec
Showing 2 changed files with 90 additions and 32 deletions.
14 changes: 14 additions & 0 deletions CONTRIBUTING.md
@@ -67,6 +67,20 @@ lintrunner -a

### Tests

We use `pytest` as our testing framework. To execute a specific test, use the following command:

```sh
pytest torchft/process_group_test.py -k test_device_mesh
```

To run the Rust tests:

```sh
cargo test
```

To run the entire suite of tests:

```sh
$ scripts/test.sh
```
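
pytest's `-k` flag matches a substring of test names, so any of the tests updated in this commit can be run the same way, for example:

```sh
pytest torchft/process_group_test.py -k test_gloo
```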
108 changes: 76 additions & 32 deletions torchft/process_group_test.py
@@ -6,21 +6,34 @@

import os
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple
from typing import Any, Dict, Tuple
from unittest import TestCase, skipUnless
from unittest.mock import Mock

import torch
import torch.distributed as dist
from torch import nn
from torch._C._distributed_c10d import _resolve_process_group
from torch.distributed import ReduceOp, TCPStore, Work, _functional_collectives
from torch._C._distributed_c10d import (
AllgatherOptions,
AllreduceOptions,
BroadcastOptions,
ReduceOp,
_resolve_process_group,
)
from torch.distributed import (
ReduceOp,
TCPStore,
Work,
_functional_collectives,
get_world_size,
)
from torch.distributed.device_mesh import init_device_mesh

from torchft.manager import Manager
from torchft.process_group import (
ErrorSwallowingProcessGroupWrapper,
ManagedProcessGroup,
ProcessGroup,
ProcessGroupBabyGloo,
ProcessGroupBabyNCCL,
ProcessGroupDummy,
@@ -41,6 +54,56 @@ def dummy_init_pg() -> None:
)


def _test_pg(
pg: ProcessGroup,
example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
) -> Dict[str, dist._Work]:
"""
Helper function to test a set of collective operations on a given process group.
"""

shape: torch.Size = example_tensor.shape
dtype: torch.dtype = example_tensor.dtype

# Create some dummy tensors for testing
input_tensor = example_tensor.clone()
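# allgather takes a list of output-tensor lists (one list per input tensor), hence the nesting.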
output_tensors = [
[torch.empty_like(input_tensor) for _ in range(get_world_size(pg))]
]
tensor_list = [torch.empty_like(input_tensor)]

def check_tensors(arg: Any) -> None: # pyre-ignore[2]
"""Recursively check tensors for expected shape and dtype."""
if isinstance(arg, torch.Tensor):
assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}"
assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}"
elif isinstance(arg, (list, tuple)):
for item in arg:
check_tensors(item)

# Each entry maps a collective name to the positional args passed to pg.<name>(*args)
collectives = {
"allreduce": ([input_tensor], AllreduceOptions()),
"allgather": (output_tensors, [input_tensor], AllgatherOptions()),
"broadcast": (tensor_list, BroadcastOptions()),
"broadcast_one": (input_tensor, 0),
}
works: Dict[str, dist._Work] = {}
for coll_str, args in collectives.items():
coll = getattr(pg, coll_str)
work = coll(*args)
works[coll_str] = work
work.wait()
fut = work.get_future()
fut.wait()

# Check that all tensor arguments have the expected shapes and dtypes
check_tensors(args)

print(works)
return works


class ProcessGroupTest(TestCase):
def test_gloo(self) -> None:
store = TCPStore(
@@ -53,11 +116,7 @@ def test_gloo(self) -> None:

self.assertEqual(pg.size(), 1)

at = torch.tensor([2])

a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
a_work.get_future().wait()
_test_pg(pg)

m = nn.Linear(3, 4)
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -77,10 +136,7 @@ def test_nccl(self) -> None:

self.assertEqual(pg.size(), 1)

at = torch.tensor([2], device=device)
a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
a_work.get_future().wait()
_test_pg(pg, torch.tensor([2], device=device))

m = nn.Linear(3, 4).to(device)
m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
@@ -90,9 +146,7 @@ def test_nccl(self) -> None:
store_addr = f"localhost:{store.port}/prefix2"
pg.configure(store_addr, 0, 1)

at = torch.tensor([2], device=device)
a_work = pg.allreduce([at], ReduceOp.SUM)
a_work.wait()
_test_pg(pg, torch.tensor([2], device=device))

torch.cuda.synchronize()

@@ -220,22 +274,16 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
wrapper = ErrorSwallowingProcessGroupWrapper(pg)
self.assertIs(wrapper.parent, pg)

t = torch.zeros(10)
work = wrapper.allreduce([t], ReduceOp.SUM)
self.assertIsInstance(work, _ErrorSwallowingWork)
work.wait()
fut = work.get_future()
fut.wait()
works = _test_pg(wrapper)
self.assertIsInstance(list(works.values())[0], _ErrorSwallowingWork)

err = RuntimeError("test")
wrapper.report_error(err)
self.assertEqual(wrapper.error(), err)

work = wrapper.allreduce([t], ReduceOp.SUM)
self.assertIsInstance(work, _DummyWork)
work.wait()
fut = work.get_future()
fut.wait()
works = _test_pg(wrapper)
for work in works.values():
self.assertIsInstance(work, _DummyWork)

def test_managed_process_group(self) -> None:
manager = Mock(spec=Manager)
@@ -246,12 +294,8 @@ def test_managed_process_group(self) -> None:

self.assertEqual(pg.size(), 123)

t = torch.zeros(10)
work = pg.allreduce([t], ReduceOp.SUM)
self.assertIsInstance(work, _ManagedWork)
work.wait()
fut = work.get_future()
fut.wait()
works = _test_pg(pg)
self.assertIsInstance(list(works.values())[0], _ManagedWork)

self.assertEqual(manager.report_error.call_count, 0)
self.assertEqual(manager.wrap_future.call_count, 1)
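
For reference, a minimal sketch of how a new test could reuse the `_test_pg` helper, mirroring `test_gloo` above; the store prefix and tensor shape are illustrative, and `ProcessGroupGloo` is assumed to be importable from `torchft.process_group` as in the existing tests:

```python
import torch
from torch.distributed import TCPStore

from torchft.process_group import ProcessGroupGloo
from torchft.process_group_test import _test_pg

# Single-rank Gloo process group backed by an in-process TCPStore.
store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

# _test_pg runs allreduce, allgather, broadcast, and broadcast_one, waits on
# each Work and its future, and returns the Work objects keyed by name.
works = _test_pg(pg, torch.randn((2, 3)))
print(sorted(works.keys()))
```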
