diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 7110a85..aa9325f 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -67,6 +67,20 @@ lintrunner -a
 
 ### Tests
 
+We use `pytest` as our testing framework. To execute a specific test, use the following command:
+
+```sh
+pytest torchft/process_group_test.py -k test_device_mesh
+```
+
+To run the Rust tests, run:
+
+```sh
+cargo test
+```
+
+To run the entire suite of tests:
+
 ```sh
 $ scripts/test.sh
 ```
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index a5f73e0..44e770d 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -6,21 +6,34 @@
 
 import os
 from concurrent.futures import ThreadPoolExecutor
-from typing import Tuple
+from typing import Any, Dict, Tuple
 from unittest import TestCase, skipUnless
 from unittest.mock import Mock
 
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch._C._distributed_c10d import _resolve_process_group
-from torch.distributed import ReduceOp, TCPStore, Work, _functional_collectives
+from torch._C._distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    _resolve_process_group,
+)
+from torch.distributed import (
+    ReduceOp,
+    TCPStore,
+    Work,
+    _functional_collectives,
+    get_world_size,
+)
 from torch.distributed.device_mesh import init_device_mesh
 
 from torchft.manager import Manager
 from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
     ManagedProcessGroup,
+    ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupDummy,
@@ -41,6 +54,56 @@ def dummy_init_pg() -> None:
         )
 
 
+def _test_pg(
+    pg: ProcessGroup,
+    example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
+) -> Dict[str, dist._Work]:
+    """
+    Helper function to test a set of collective operations on a given process group.
+ """ + + shape: torch.Size = example_tensor.shape + dtype: torch.dtype = example_tensor.dtype + + # Create some dummy tensors for testing + input_tensor = example_tensor.clone() + output_tensors = [ + [torch.empty_like(input_tensor) for _ in range(get_world_size(pg))] + ] + tensor_list = [torch.empty_like(input_tensor)] + + def check_tensors(arg: Any) -> None: # pyre-ignore[2] + """Recursively check tensors for expected shape and dtype.""" + if isinstance(arg, torch.Tensor): + assert arg.dtype == dtype, f"Output dtype mismatch: {arg.dtype} != {dtype}" + assert arg.shape == shape, f"Output shape mismatch: {arg.shape} != {shape}" + elif isinstance(arg, (list, tuple)): + for item in arg: + check_tensors(item) + + # Test collectives + collectives = { + "allreduce": ([input_tensor], AllreduceOptions()), + "allgather": (output_tensors, [input_tensor], AllgatherOptions()), + "broadcast": (tensor_list, BroadcastOptions()), + "broadcast_one": (input_tensor, 0), + } + works: Dict[str, dist._Work] = {} + for coll_str, args in collectives.items(): + coll = getattr(pg, coll_str) + work = coll(*args) + works[coll_str] = work + work.wait() + fut = work.get_future() + fut.wait() + + # Check that all tensor arguments have the expected shapes and dtypes + check_tensors(args) + + print(works) + return works + + class ProcessGroupTest(TestCase): def test_gloo(self) -> None: store = TCPStore( @@ -53,11 +116,7 @@ def test_gloo(self) -> None: self.assertEqual(pg.size(), 1) - at = torch.tensor([2]) - - a_work = pg.allreduce([at], ReduceOp.SUM) - a_work.wait() - a_work.get_future().wait() + _test_pg(pg) m = nn.Linear(3, 4) m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg) @@ -77,10 +136,7 @@ def test_nccl(self) -> None: self.assertEqual(pg.size(), 1) - at = torch.tensor([2], device=device) - a_work = pg.allreduce([at], ReduceOp.SUM) - a_work.wait() - a_work.get_future().wait() + _test_pg(pg, torch.tensor([2], device=device)) m = nn.Linear(3, 4).to(device) m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg) @@ -90,9 +146,7 @@ def test_nccl(self) -> None: store_addr = f"localhost:{store.port}/prefix2" pg.configure(store_addr, 0, 1) - at = torch.tensor([2], device=device) - a_work = pg.allreduce([at], ReduceOp.SUM) - a_work.wait() + _test_pg(pg, torch.tensor([2], device=device)) torch.cuda.synchronize() @@ -220,22 +274,16 @@ def test_error_swallowing_process_group_wrapper(self) -> None: wrapper = ErrorSwallowingProcessGroupWrapper(pg) self.assertIs(wrapper.parent, pg) - t = torch.zeros(10) - work = wrapper.allreduce([t], ReduceOp.SUM) - self.assertIsInstance(work, _ErrorSwallowingWork) - work.wait() - fut = work.get_future() - fut.wait() + works = _test_pg(wrapper) + self.assertIsInstance(list(works.values())[0], _ErrorSwallowingWork) err = RuntimeError("test") wrapper.report_error(err) self.assertEqual(wrapper.error(), err) - work = wrapper.allreduce([t], ReduceOp.SUM) - self.assertIsInstance(work, _DummyWork) - work.wait() - fut = work.get_future() - fut.wait() + works = _test_pg(wrapper) + for work in works.values(): + self.assertIsInstance(work, _DummyWork) def test_managed_process_group(self) -> None: manager = Mock(spec=Manager) @@ -246,12 +294,8 @@ def test_managed_process_group(self) -> None: self.assertEqual(pg.size(), 123) - t = torch.zeros(10) - work = pg.allreduce([t], ReduceOp.SUM) - self.assertIsInstance(work, _ManagedWork) - work.wait() - fut = work.get_future() - fut.wait() + works = _test_pg(pg) + self.assertIsInstance(list(works.values())[0], _ManagedWork) 
 
         self.assertEqual(manager.report_error.call_count, 0)
         self.assertEqual(manager.wrap_future.call_count, 1)
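
For context on what the new `_test_pg` helper consolidates: every collective is issued asynchronously and then completed twice, once via `Work.wait()` and once via the future returned by `Work.get_future()`. The sketch below is illustrative only and is not part of the patch; it shows that completion pattern using stock `torch.distributed` with a single-process gloo group (the `MASTER_ADDR`/`MASTER_PORT` values are arbitrary assumptions for a local run), rather than the torchft `ProcessGroup` wrappers the tests exercise.

```python
# Illustrative sketch only -- not part of the patch above. Demonstrates the
# Work/Future completion pattern that _test_pg checks for each collective,
# using a plain single-process gloo group from torch.distributed.
import os

import torch
import torch.distributed as dist

# Arbitrary local rendezvous settings (assumed values, single process).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

t = torch.randn((2, 3), dtype=torch.float32)
work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
work.wait()               # completion via the Work handle
work.get_future().wait()  # completion via the Future attached to the same Work

dist.destroy_process_group()
```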