From 74236245cb6290630159e9b70761c9a0e82a86db Mon Sep 17 00:00:00 2001 From: Maximilian Stadler Date: Wed, 22 Nov 2023 10:18:02 -0800 Subject: [PATCH 1/6] slightly re-write partition routines and add option to use coordinates or any arbitrary device mapping --- .../models/gnn_layers/distributed_graph.py | 498 +++++++++++++++--- test/models/graphcast/test_graphcast_snmg.py | 2 +- .../meshgraphnet/test_meshgraphnet_snmg.py | 5 +- 3 files changed, 425 insertions(+), 80 deletions(-) diff --git a/modulus/models/gnn_layers/distributed_graph.py b/modulus/models/gnn_layers/distributed_graph.py index c3ec474a92..29300d0b64 100644 --- a/modulus/models/gnn_layers/distributed_graph.py +++ b/modulus/models/gnn_layers/distributed_graph.py @@ -30,7 +30,8 @@ @dataclass class GraphPartition: # pragma: no cover - """Class acting as a struct to hold all relevant buffers and variables + """ + Class acting as a struct to hold all relevant buffers and variables to define a graph partition. """ @@ -39,16 +40,16 @@ class GraphPartition: # pragma: no cover device: torch.device # data structures for local graph # of this current partition rank - local_offsets: torch.Tensor = field(init=False) - local_indices: torch.Tensor = field(init=False) + local_offsets: Optional[torch.Tensor] = None + local_indices: Optional[torch.Tensor] = None num_local_src_nodes: int = 0 num_local_dst_nodes: int = 0 num_local_indices: int = 0 # mapping from local to global ID space # for this current partition rank - partitioned_src_node_ids_to_global: torch.Tensor = field(init=False) - partitioned_dst_node_ids_to_global: torch.Tensor = field(init=False) - partitioned_indices_to_global: torch.Tensor = field(init=False) + partitioned_src_node_ids_to_global: Optional[torch.Tensor] = None + partitioned_dst_node_ids_to_global: Optional[torch.Tensor] = None + partitioned_indices_to_global: Optional[torch.Tensor] = None # buffers, sizes, and ID counts to support # distributed communication primitives # number of IDs each rank potentially sends to all other ranks @@ -71,19 +72,26 @@ def __post_init__(self): self.num_indices_in_each_partition = [None] * self.partition_size -def partition_graph_nodewise( +def partition_graph_with_id_mapping( global_offsets: torch.Tensor, global_indices: torch.Tensor, + mapping_src_ids_to_ranks: torch.Tensor, + mapping_dst_ids_to_ranks: torch.Tensor, partition_size: int, partition_rank: int, device: torch.device, -): # pragma: no cover - """Utility function which partitions a global graph given as CSC structure naively - by splitting both the IDs of source and destination nodes into chunks of equal - size. Each partition rank then manages its according chunk of both source and - destination nodes. Indices are assigned to the rank such that each rank manages - all the incoming edges for all the destination nodes on the corresponding - partition rank. +) -> GraphPartition: # pragma: no cover + """ + Utility function which partitions a global graph given as CSC structure. + It partitions both the global ID spaces for source nodes and destination nodes + based on the corresponding mappings as well as the graph structure and edge IDs. + Each rank maintains both a partition of the global source and destination nodes. + In terms of graph structure, each rank manages its own local graph structure + based on its partition of destination node IDs and all edges which - from the + point of view of each destination node on a current rank - are incoming edges. + For GNN operations this means, that features from source nodes need to be exchanged + between ranks. The partitioning scheme computes necessary indices which facilitate + later communication primitives. The function performs the partitioning based on a global graph in CPU memory for each rank independently. It could be rewritten to e.g. only do it one rank and exchange the partitions or to an algorithm that also @@ -98,6 +106,10 @@ def partition_graph_nodewise( CSC offsets, can live on the CPU global_indices : torch.Tensor CSC indices, can live on the CPU + mapping_src_ids_to_ranks: torch.Tensor + maps each global ID from every source node to its partition rank + mapping_dst_ids_to_ranks: torch.Tensor + maps each global ID from every destination node to its partition rank partition_size : int number of process groups across which graph is partitioned, i.e. the number of graph partitions @@ -121,90 +133,106 @@ def partition_graph_nodewise( num_global_src_nodes = global_indices.max().item() + 1 num_global_dst_nodes = global_offsets.size(0) - 1 - # global IDs of destination nodes in this partition - dst_nodes_in_partition = None - # global IDs of source nodes in this partition - src_nodes_in_partition = None + # global IDs of in each partition + dst_nodes_in_each_partition = [None] * partition_size + src_nodes_in_each_partition = [None] * partition_size + num_dst_nodes_in_each_partition = [None] * partition_size + num_src_nodes_in_each_partition = [None] * partition_size + mapping_global_src_ids_to_local_ids = torch.zeros_like(mapping_src_ids_to_ranks) - # get distribution of destination IDs: simply divide them into equal chunks - dst_nodes_in_partition = ( - num_global_dst_nodes + partition_size - 1 - ) // partition_size - dst_offsets_in_partition = [ - rank * dst_nodes_in_partition for rank in range(partition_size + 1) - ] - dst_offsets_in_partition[-1] = min( - num_global_dst_nodes, dst_offsets_in_partition[-1] - ) + for rank in range(partition_size): + dst_nodes_in_each_partition[rank] = torch.nonzero( + mapping_dst_ids_to_ranks == rank + ).view(-1) + src_nodes_in_each_partition[rank] = torch.nonzero( + mapping_src_ids_to_ranks == partition_rank + ).view(-1) + num_nodes = dst_nodes_in_each_partition[rank].numel() + if num_nodes == 0: + raise RuntimeError( + f"Aborting partitioning, rank {rank} has 0 destination nodes to work on." + ) + num_dst_nodes_in_each_partition[rank] = num_nodes - # get distribution of source IDs: again simply divide them into equal chunks - src_nodes_in_partition = ( - num_global_src_nodes + partition_size - 1 - ) // partition_size - src_offsets_in_partition = [ - rank * src_nodes_in_partition for rank in range(partition_size + 1) - ] - src_offsets_in_partition[-1] = min( - num_global_src_nodes, src_offsets_in_partition[-1] - ) + num_nodes = src_nodes_in_each_partition[rank].numel() + num_src_nodes_in_each_partition[rank] = num_nodes + if num_nodes == 0: + raise RuntimeError( + f"Aborting partitioning, rank {rank} has 0 source nodes to work on." + ) + # create mapping of global IDs to local IDs + # as each rank is expected to operate on disting global IDs, this is expected + # to not cause any data races + ids = src_nodes_in_each_partition[rank] + mapping_global_src_ids_to_local_ids[ids] = torch.arange( + ids.numel(), dtype=mapping_global_src_ids_to_local_ids.dtype, device=mapping_global_src_ids_to_local_ids.device + ) + + graph_partition.num_src_nodes_in_each_partition = num_src_nodes_in_each_partition + graph_partition.num_dst_nodes_in_each_partition = num_dst_nodes_in_each_partition + + # create local graph structures for rank in range(partition_size): - offset_start = dst_offsets_in_partition[rank] - offset_end = dst_offsets_in_partition[rank + 1] - offsets = global_offsets[offset_start : offset_end + 1].detach().clone() - partition_indices = global_indices[offsets[0] : offsets[-1]].detach().clone() - offsets -= offsets[0].item() + offset_start = global_offsets[dst_nodes_in_each_partition[rank]].view(-1, 1) + offset_end = global_offsets[dst_nodes_in_each_partition[rank] + 1].view(-1, 1) + degree = offset_end - offset_start + local_offsets = degree.view(-1).cumsum(dim=0) + local_offsets = torch.cat( + [ + torch.Tensor([0]).to( + dtype=local_offsets.dtype, device=local_offsets.device + ), + local_offsets, + ] + ) + + # create boolean mask to find contigouus sections of global_indices + # which are taken care of current rank without using loops + tmp = torch.arange( + global_indices.numel(), dtype=global_indices.dtype, device=global_indices.device + ) + mask = (offset_start <= tmp) & (tmp < offset_end) + mask = torch.any(mask, dim=0) - global_src_ids_per_rank, inverse_mapping = partition_indices.unique( + partition_indices = global_indices[mask].detach().clone() + global_src_ids_on_rank, inverse_mapping = partition_indices.unique( sorted=True, return_inverse=True ) - local_src_ids_per_rank = torch.arange( + local_src_ids_on_rank = torch.arange( 0, - global_src_ids_per_rank.size(0), - dtype=offsets.dtype, - device=offsets.device, - ) - global_src_ids_to_gpu = global_src_ids_per_rank // src_nodes_in_partition - remote_src_ids_per_rank = ( - global_src_ids_per_rank - global_src_ids_to_gpu * src_nodes_in_partition + global_src_ids_on_rank.size(0), + dtype=local_offsets.dtype, + device=local_offsets.device, ) + remote_src_ids_on_rank = mapping_global_src_ids_to_local_ids[ + global_src_ids_on_rank + ] - indices = local_src_ids_per_rank[inverse_mapping] + indices = local_src_ids_on_rank[inverse_mapping] graph_partition.num_indices_in_each_partition[rank] = indices.size(0) if rank == partition_rank: graph_partition.num_local_indices = indices.size(0) - graph_partition.num_local_dst_nodes = offsets.size(0) - 1 - graph_partition.num_dst_nodes_in_each_partition = [ - dst_offsets_in_partition[rank + 1] - dst_offsets_in_partition[rank] - for rank in range(partition_size) - ] - graph_partition.num_local_src_nodes = global_src_ids_per_rank.size(0) - graph_partition.num_src_nodes_in_each_partition = [ - src_offsets_in_partition[rank + 1] - src_offsets_in_partition[rank] - for rank in range(partition_size) - ] - - graph_partition.partitioned_src_node_ids_to_global = range( - src_offsets_in_partition[rank], src_offsets_in_partition[rank + 1] + graph_partition.num_local_dst_nodes = num_dst_nodes_in_each_partition[rank] + graph_partition.num_local_src_nodes = local_src_ids_on_rank.size(0) + graph_partition.partitioned_src_node_ids_to_global = ( + src_nodes_in_each_partition[rank] ) - graph_partition.partitioned_dst_node_ids_to_global = range( - dst_offsets_in_partition[rank], dst_offsets_in_partition[rank + 1] + graph_partition.partitioned_dst_node_ids_to_global = ( + dst_nodes_in_each_partition[rank] ) - graph_partition.partitioned_indices_to_global = range( - global_offsets[offset_start], global_offsets[offset_end] - ) - - graph_partition.local_offsets = offsets.to(device=device) + graph_partition.partitioned_indices_to_global = partition_indices + graph_partition.local_offsets = local_offsets.to(device=device) graph_partition.local_indices = indices.to(device=device) for rank_offset in range(partition_size): - mask = global_src_ids_to_gpu == rank_offset + mask = mapping_src_ids_to_ranks[global_src_ids_on_rank] == rank_offset if partition_rank == rank_offset: # indices to send to this rank from this rank graph_partition.scatter_indices[rank] = ( - remote_src_ids_per_rank[mask] + remote_src_ids_on_rank[mask] .detach() .clone() .to(device=device, dtype=torch.int64) @@ -223,6 +251,323 @@ def partition_graph_nodewise( return graph_partition +def partition_graph_nodewise( + global_offsets: torch.Tensor, + global_indices: torch.Tensor, + partition_size: int, + partition_rank: int, + device: torch.device, +) -> GraphPartition: # pragma: no cover + """ + Utility function which partitions a global graph given as CSC structure naively + by splitting both the IDs of source and destination nodes into chunks of equal + size. Each partition rank then manages its according chunk of both source and + destination nodes. Indices are assigned to the rank such that each rank manages + all the incoming edges for all the destination nodes on the corresponding + partition rank. + The function performs the partitioning based on a global graph in CPU + memory for each rank independently. It could be rewritten to e.g. only + do it one rank and exchange the partitions or to an algorithm that also + assumes an already distributed global graph, however, we expect global + graphs to fit in CPU memory. After the partitioning, we can get rid off + the larger one in CPU memory, only keep the local graphs on each GPU, and + avoid tedious gather/scatter routines for exchanging partitions in the process. + + Parameters + ---------- + global_offsets : torch.Tensor + CSC offsets, can live on the CPU + global_indices : torch.Tensor + CSC indices, can live on the CPU + partition_size : int + number of process groups across which graph is partitioned, + i.e. the number of graph partitions + partition_rank : int + rank within process group managing the distributed graph, i.e. + the rank determining which partition the corresponding local rank + will manage + device : torch.device + device connected to the passed partition rank, i.e. the device + on which the local graph and related buffers will live on + """ + + num_global_src_nodes = global_indices.max().item() + 1 + num_global_dst_nodes = global_offsets.size(0) - 1 + num_dst_nodes_per_partition = ( + num_global_dst_nodes + partition_size - 1 + ) // partition_size + num_src_nodes_per_partition = ( + num_global_src_nodes + partition_size - 1 + ) // partition_size + + mapping_dst_ids_to_ranks = ( + torch.arange( + num_global_dst_nodes, + dtype=global_offsets.dtype, + device=global_offsets.device, + ) + // num_dst_nodes_per_partition + ) + mapping_src_ids_to_ranks = ( + torch.arange( + num_global_src_nodes, + dtype=global_offsets.dtype, + device=global_offsets.device, + ) + // num_src_nodes_per_partition + ) + + return partition_graph_with_id_mapping( + global_offsets, + global_indices, + mapping_src_ids_to_ranks, + mapping_dst_ids_to_ranks, + partition_size, + partition_rank, + device, + ) + + +def partition_graph_coordinatewise( + global_offsets: torch.Tensor, + global_indices: torch.Tensor, + src_coordinates: torch.Tensor, + dst_coordinates: torch.Tensor, + coordinate_separators_min: List[List[float]], + coordinate_separators_max: List[List[float]], + partition_size: int, + partition_rank: int, + device: torch.device, +) -> GraphPartition: # pragma: no cover + """ + Utility function which partitions a global graph given as CSC structure. + It partitions both the global ID spaces for source nodes and destination nodes + based on their corresponding coordinates. Each partition will manage points which + fulfill the boxconstraints specified by the specified coordinate separators. For each + rank one is expected to specify the minimum and maximum coordinate value for each dimension. + A partition the will manage all points for which ``min_val <= coord[d] < max_val`` holds. + Specifying either of these separation values as `None` will result in this constraint not being + considered for the division. Each rank maintains both a partition of the global source and + destination nodes resulting from this subspace division. + In terms of graph structure, each rank manages its own local graph structure + based on its partition of destination node IDs and all edges which - from the + point of view of each destination node on a current rank - are incoming edges. + For GNN operations this means, that features from source nodes need to be exchanged + between ranks. The partitioning scheme computes necessary indices which facilitate + later communication primitives. + The function performs the partitioning based on a global graph in CPU + memory for each rank independently. It could be rewritten to e.g. only + do it one rank and exchange the partitions or to an algorithm that also + assumes an already distributed global graph, however, we expect global + graphs to fit in CPU memory. After the partitioning, we can get rid off + the larger one in CPU memory, only keep the local graphs on each GPU, and + avoid tedious gather/scatter routines for exchanging partitions in the process. + + Examples + -------- + >>> import torch + >>> from modulus.models.gnn_layers import partition_graph_coordinatewise + >>> + >>> # simple graph with a degree of 2 per node + >>> num_src_nodes = 8 + >>> num_dst_nodes = 4 + >>> offsets = torch.arange(num_dst_nodes + 1, dtype=torch.int64) * 2 + >>> indices = torch.arange(num_src_nodes, dtype=torch.int64) + >>> + >>> # example with 2D coordinates + >>> # assuming partitioning a 2D problem into the 4 quadrants + >>> partition_size = 4 + >>> partition_rank = 0 + >>> coordinate_separators_min = [[0, 0], [None, 0], [None, None], [0, None]] + >>> coordinate_separators_max = [[None, None], [0, None], [0, 0], [None, 0]] + >>> device = "cuda:0" + >>> # dummy coordinates + >>> src_coordinates = torch.FloatTensor( + >>> [ + >>> [-1.0, 1.0], + >>> [1.0, 1.0], + >>> [-1.0, -1.0], + >>> [1.0, -1.0], + >>> [-2.0, 2.0], + >>> [2.0, 2.0], + >>> [-2.0, -2.0], + >>> [2.0, -2.0], + >>> ] + >>> ) + >>> dst_coordinates = torch.FloatTensor( + >>> [ + >>> [-1.0, 1.0], + >>> [1.0, 1.0], + >>> [-1.0, -1.0], + >>> [1.0, -1.0], + >>> ] + >>> ) + >>> # call partitioning routine + >>> pg = partition_graph_coordinatewise( + >>> offsets, + >>> indices, + >>> src_coordinates, + >>> dst_coordinates, + >>> coordinate_separators_min, + >>> coordinate_separators_max, + >>> partition_size, + >>> partition_rank, + >>> device, + >>> ) + GraphPartition( + partition_size=4, + partition_rank=0, + device="cuda:0", + local_offsets=tensor([0, 2], device="cuda:0"), + local_indices=tensor([0, 1], device="cuda:0"), + num_local_src_nodes=2, + num_local_dst_nodes=1, + num_local_indices=2, + partitioned_src_node_ids_to_global=tensor([1, 5]), + partitioned_dst_node_ids_to_global=tensor([1]), + partitioned_indices_to_global=tensor([2, 3]), + sizes=[[0, 1, 1, 0], [0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 1]], + scatter_indices=[ + tensor([], device="cuda:0", dtype=torch.int64), + tensor([0], device="cuda:0"), + tensor([1], device="cuda:0"), + tensor([], device="cuda:0", dtype=torch.int64), + ], + num_src_nodes_in_each_partition=[2, 2, 2, 2], + num_dst_nodes_in_each_partition=[1, 1, 1, 1], + num_indices_in_each_partition=[2, 2, 2, 2], + ) + >>> + >>> # example with lat-long coordinates + >>> # dummy coordinates + >>> src_lat = torch.FloatTensor([-75, -60, -45, -30, 30, 45, 60, 75]).view(-1, 1) + >>> dst_lat = torch.FloatTensor([-60, -30, 30, 30]).view(-1, 1) + >>> src_long = torch.FloatTensor([-135, -135, 135, 135, -45, -45, 45, 45]).view(-1, 1) + >>> dst_long = torch.FloatTensor([-135, 135, -45, 45]).view(-1, 1) + >>> src_coordinates = torch.cat([src_lat, src_long], dim=1) + >>> dst_coordinates = torch.cat([dst_lat, dst_long], dim=1) + >>> # separate sphere at equator and 0 degree longitude into 4 parts + >>> coordinate_separators_min = [ + >>> [-90, -180], + >>> [-90, 0], + >>> [0, -180], + >>> [0, 0], + >>> ] + >>> coordinate_separators_max = [ + >>> [0, 0], + >>> [0, 180], + >>> [90, 0], + >>> [90, 180], + >>> ] + >>> # call partitioning routine + >>> partition_size = 4 + >>> partition_rank = 0 + >>> device = "cuda:0" + >>> pg = partition_graph_coordinatewise( + >>> offsets, + >>> indices, + >>> src_coordinates, + >>> dst_coordinates, + >>> coordinate_separators_min, + >>> coordinate_separators_max, + >>> partition_size, + >>> partition_rank, + >>> device, + >>> ) + GraphPartition( + partition_size=4, + partition_rank=0, + device="cuda:0", + local_offsets=tensor([0, 2], device="cuda:0"), + local_indices=tensor([0, 1], device="cuda:0"), + num_local_src_nodes=2, + num_local_dst_nodes=1, + num_local_indices=2, + partitioned_src_node_ids_to_global=tensor([0, 1]), + partitioned_dst_node_ids_to_global=tensor([0]), + partitioned_indices_to_global=tensor([0, 1]), + sizes=[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]], + scatter_indices=[ + tensor([0, 1], device="cuda:0"), + tensor([], device="cuda:0", dtype=torch.int64), + tensor([], device="cuda:0", dtype=torch.int64), + tensor([], device="cuda:0", dtype=torch.int64), + ], + num_src_nodes_in_each_partition=[2, 2, 2, 2], + num_dst_nodes_in_each_partition=[1, 1, 1, 1], + num_indices_in_each_partition=[2, 2, 2, 2], + ) + + Parameters + ---------- + global_offsets : torch.Tensor + CSC offsets, can live on the CPU + global_indices : torch.Tensor + CSC indices, can live on the CPU + src_coordinates : torch.Tensor + coordinates of each source node + dst_coordinates : torch.Tensor + coordinates of each destination node + partition_size : int + number of process groups across which graph is partitioned, + i.e. the number of graph partitions + partition_rank : int + rank within process group managing the distributed graph, i.e. + the rank determining which partition the corresponding local rank + will manage + device : torch.device + device connected to the passed partition rank, i.e. the device + on which the local graph and related buffers will live on + """ + + dim = src_coordinates.size(-1) + assert dst_coordinates.size(-1) == dim + assert len(coordinate_separators_min) == partition_size + assert len(coordinate_separators_max) == partition_size + for rank in range(partition_size): + assert len(coordinate_separators_min[rank]) == dim + assert len(coordinate_separators_max[rank]) == dim + + num_global_src_nodes = global_indices.max().item() + 1 + num_global_dst_nodes = global_offsets.size(0) - 1 + + mapping_dst_ids_to_ranks = torch.zeros( + num_global_dst_nodes, dtype=global_offsets.dtype, device=global_offsets.device + ) + mapping_src_ids_to_ranks = torch.zeros( + num_global_src_nodes, + dtype=global_offsets.dtype, + device=global_offsets.device, + ) + + def _assign_ranks(mapping, coordinates): + for p in range(partition_size): + mask = torch.ones_like(mapping).to(dtype=torch.bool) + for d in range(dim): + min_val, max_val = ( + coordinate_separators_min[p][d], + coordinate_separators_max[p][d], + ) + if min_val is not None: + mask = mask & (coordinates[:, d] >= min_val) + if max_val is not None: + mask = mask & (coordinates[:, d] < max_val) + mapping[mask] = p + + _assign_ranks(mapping_src_ids_to_ranks, src_coordinates) + _assign_ranks(mapping_dst_ids_to_ranks, dst_coordinates) + + return partition_graph_with_id_mapping( + global_offsets, + global_indices, + mapping_src_ids_to_ranks, + mapping_dst_ids_to_ranks, + partition_size, + partition_rank, + device, + ) + + class DistributedGraph: def __init__( self, @@ -232,7 +577,8 @@ def __init__( graph_partition_group_name: str = None, graph_partition: Optional[GraphPartition] = None, ): # pragma: no cover - """Utility Class representing a distributed graph based on a given + """ + Utility Class representing a distributed graph based on a given partitioning of a CSC graph structure. By default, a naive node-wise partitioning scheme is applied, see ``partition_graph_nodewise`` for details on that. This class then wraps necessary communication primitives diff --git a/test/models/graphcast/test_graphcast_snmg.py b/test/models/graphcast/test_graphcast_snmg.py index 53b3f1f8ae..5cdeb776b9 100644 --- a/test/models/graphcast/test_graphcast_snmg.py +++ b/test/models/graphcast/test_graphcast_snmg.py @@ -172,7 +172,7 @@ def run_test_distributed_graphcast( @pytest.mark.multigpu -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) @pytest.mark.parametrize("do_concat_trick", [False, True]) @pytest.mark.parametrize("do_checkpointing", [False, True]) def test_distributed_graphcast(dtype, do_concat_trick, do_checkpointing): diff --git a/test/models/meshgraphnet/test_meshgraphnet_snmg.py b/test/models/meshgraphnet/test_meshgraphnet_snmg.py index e880e02d7e..7ebbca9449 100644 --- a/test/models/meshgraphnet/test_meshgraphnet_snmg.py +++ b/test/models/meshgraphnet/test_meshgraphnet_snmg.py @@ -28,7 +28,6 @@ from modulus.distributed import DistributedManager -@import_or_fail("dgl") def run_test_distributed_meshgraphnet(rank, world_size, dtype): from modulus.models.gnn_layers.utils import CuGraphCSC from modulus.models.meshgraphnet.meshgraphnet import MeshGraphNet @@ -184,8 +183,8 @@ def run_test_distributed_meshgraphnet(rank, world_size, dtype): @pytest.mark.multigpu -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) -def test_distributed_meshgraphnet(dtype): +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) +def test_distributed_meshgraphnet(dtype, pytestconfig): num_gpus = torch.cuda.device_count() assert num_gpus >= 2, "Not enough GPUs available for test" world_size = num_gpus From 4703c3b3fdeab3834b94de4624e4641e459b1ae3 Mon Sep 17 00:00:00 2001 From: Maximilian Stadler Date: Wed, 22 Nov 2023 11:08:42 -0800 Subject: [PATCH 2/6] format and add some further comments --- .../models/gnn_layers/distributed_graph.py | 41 ++++++++++++++----- .../meshgraphnet/test_meshgraphnet_snmg.py | 1 + 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/modulus/models/gnn_layers/distributed_graph.py b/modulus/models/gnn_layers/distributed_graph.py index 29300d0b64..c5e19a7c0c 100644 --- a/modulus/models/gnn_layers/distributed_graph.py +++ b/modulus/models/gnn_layers/distributed_graph.py @@ -99,6 +99,8 @@ def partition_graph_with_id_mapping( graphs to fit in CPU memory. After the partitioning, we can get rid off the larger one in CPU memory, only keep the local graphs on each GPU, and avoid tedious gather/scatter routines for exchanging partitions in the process. + Note: It is up to the user to ensure that the provided mapping is valid. In particular, + we expect each rank to receive a non-empty partition of node IDs. Parameters ---------- @@ -129,9 +131,6 @@ def partition_graph_with_id_mapping( # -------------------------------------------------------------- # initialize temporary variables used in computing the partition - # global information about node ids and edge ids - num_global_src_nodes = global_indices.max().item() + 1 - num_global_dst_nodes = global_offsets.size(0) - 1 # global IDs of in each partition dst_nodes_in_each_partition = [None] * partition_size @@ -166,7 +165,9 @@ def partition_graph_with_id_mapping( # to not cause any data races ids = src_nodes_in_each_partition[rank] mapping_global_src_ids_to_local_ids[ids] = torch.arange( - ids.numel(), dtype=mapping_global_src_ids_to_local_ids.dtype, device=mapping_global_src_ids_to_local_ids.device + ids.numel(), + dtype=mapping_global_src_ids_to_local_ids.dtype, + device=mapping_global_src_ids_to_local_ids.device, ) graph_partition.num_src_nodes_in_each_partition = num_src_nodes_in_each_partition @@ -190,7 +191,9 @@ def partition_graph_with_id_mapping( # create boolean mask to find contigouus sections of global_indices # which are taken care of current rank without using loops tmp = torch.arange( - global_indices.numel(), dtype=global_indices.dtype, device=global_indices.device + global_indices.numel(), + dtype=global_indices.dtype, + device=global_indices.device, ) mask = (offset_start <= tmp) & (tmp < offset_end) mask = torch.any(mask, dim=0) @@ -362,6 +365,8 @@ def partition_graph_coordinatewise( graphs to fit in CPU memory. After the partitioning, we can get rid off the larger one in CPU memory, only keep the local graphs on each GPU, and avoid tedious gather/scatter routines for exchanging partitions in the process. + Note: It is up to the user to ensure that the provided partition is valid. + In particular, we expect each rank to receive a non-empty partition of node IDs. Examples -------- @@ -521,12 +526,28 @@ def partition_graph_coordinatewise( """ dim = src_coordinates.size(-1) - assert dst_coordinates.size(-1) == dim - assert len(coordinate_separators_min) == partition_size - assert len(coordinate_separators_max) == partition_size + if dst_coordinates.size(-1) != dim: + raise ValueError() + if len(coordinate_separators_min) != partition_size: + a, b = len(coordinate_separators_min), partition_size + error_msg = "Expected len(coordinate_separators_min) == partition_size" + error_msg += f", but got {a} and {b} respectively" + raise ValueError(error_msg) + if len(coordinate_separators_max) != partition_size: + a, b = len(coordinate_separators_max), partition_size + error_msg = "Expected len(coordinate_separators_max) == partition_size" + error_msg += f", but got {a} and {b} respectively" + raise ValueError(error_msg) + for rank in range(partition_size): - assert len(coordinate_separators_min[rank]) == dim - assert len(coordinate_separators_max[rank]) == dim + if len(coordinate_separators_min[rank]) != dim: + a, b = len(coordinate_separators_min[rank]), dim + error_msg = f"Expected len(coordinate_separators_min[{rank}]) == dim" + error_msg += f", but got {a} and {b} respectively" + if len(coordinate_separators_max[rank]) != dim: + a, b = len(coordinate_separators_max[rank]), dim + error_msg = f"Expected len(coordinate_separators_max[{rank}]) == dim" + error_msg += f", but got {a} and {b} respectively" num_global_src_nodes = global_indices.max().item() + 1 num_global_dst_nodes = global_offsets.size(0) - 1 diff --git a/test/models/meshgraphnet/test_meshgraphnet_snmg.py b/test/models/meshgraphnet/test_meshgraphnet_snmg.py index 7ebbca9449..aeb1e83111 100644 --- a/test/models/meshgraphnet/test_meshgraphnet_snmg.py +++ b/test/models/meshgraphnet/test_meshgraphnet_snmg.py @@ -182,6 +182,7 @@ def run_test_distributed_meshgraphnet(rank, world_size, dtype): DistributedManager.cleanup() +@import_or_fail("dgl") @pytest.mark.multigpu @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) def test_distributed_meshgraphnet(dtype, pytestconfig): From d68b764679c5b19983d04145dadc659e48e9e3d7 Mon Sep 17 00:00:00 2001 From: Maximilian Stadler Date: Fri, 8 Dec 2023 08:20:12 -0800 Subject: [PATCH 3/6] address feedback, add more tests --- CHANGELOG.md | 1 + modulus/models/gnn_layers/__init__.py | 8 +- .../models/gnn_layers/distributed_graph.py | 151 +++++---- modulus/models/gnn_layers/graph.py | 4 +- .../meshgraphnet/test_meshgraphnet_snmg.py | 59 +++- test/models/test_graph_partition.py | 296 ++++++++++++++++++ 6 files changed, 457 insertions(+), 62 deletions(-) create mode 100644 test/models/test_graph_partition.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c5143313a..e8333065ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ process group config. - Updated Frechet Inception Distance to use Wasserstein 2-norm with improved stability. - Molecular Dynamics example. +- Improved usage of GraphPartition, added more flexible ways of defining a partitioned graph. ### Changed diff --git a/modulus/models/gnn_layers/__init__.py b/modulus/models/gnn_layers/__init__.py index c9906362b3..9f2bcbb591 100644 --- a/modulus/models/gnn_layers/__init__.py +++ b/modulus/models/gnn_layers/__init__.py @@ -12,5 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .distributed_graph import DistributedGraph +from .distributed_graph import ( + DistributedGraph, + GraphPartition, + partition_graph_by_coordinate_bbox, + partition_graph_nodewise, + partition_graph_with_id_mapping, +) from .graph import CuGraphCSC diff --git a/modulus/models/gnn_layers/distributed_graph.py b/modulus/models/gnn_layers/distributed_graph.py index c5e19a7c0c..7edeaf0fbf 100644 --- a/modulus/models/gnn_layers/distributed_graph.py +++ b/modulus/models/gnn_layers/distributed_graph.py @@ -13,7 +13,7 @@ # limitations under the License. -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import List, Optional import torch @@ -29,47 +29,101 @@ @dataclass -class GraphPartition: # pragma: no cover +class GraphPartition: """ - Class acting as a struct to hold all relevant buffers and variables - to define a graph partition. + Class acting as an "utility" structure to hold all relevant buffers and variables + to define a graph partition and faciliate exchange of necessary buffers for + message passing on a distributed graph. + + A global graph is assumed to be defined through a global CSC structure + defining edges between source nodes and destination nodes which are assumed + to be numbered indexed by contiguous IDs. Hence, features associated to both + nodes and edges can be represented through dense feature tables globally. + When partitioning graph and features, we distribute destination nodes and all + their incoming edges on all ranks within the partition group based on a specified + mapping. Based on this scheme, there will a be a difference between + partitioned source nodes (partitioned features) and local source node + IDs which refer to the node IDs within the local graph defined by the + destination nodes on each rank. To allow message passing, communication + primitives have to ensure to gather all corresponding features for all + local source nodes based on the applied partitioning scheme. This also + leads to the distinction of local source node IDs and remote source node + IDs on each rank where the latter simply refers to the local row ID within + the dense partitioning of node features and the former indicates the source + of a message for each edge within each local graph. + + Parameters + ---------- + partition_size : int + size of partition + partition_rank : int + local rank of this partition w.r.t. group of partitions + device : torch.device + device handle for buffers within this partition rank """ partition_size: int partition_rank: int device: torch.device - # data structures for local graph - # of this current partition rank + + # data structures defining partition + # set in after initialization or during execution + # of desired partition scheme + + # local CSR offsets defining local graph on each `partition_rank` local_offsets: Optional[torch.Tensor] = None + # local CSR indices defining local graph on each `partition_rank` local_indices: Optional[torch.Tensor] = None - num_local_src_nodes: int = 0 - num_local_dst_nodes: int = 0 - num_local_indices: int = 0 - # mapping from local to global ID space - # for this current partition rank + # number of source nodes in local graph on each `partition_rank` + num_local_src_nodes: int = -1 + # number of destination nodes in local graph on each `partition_rank` + num_local_dst_nodes: int = -1 + # number of edges in local graph on each `partition_rank` + num_local_indices: int = -1 + # mapping from local to global ID space (source node IDs) partitioned_src_node_ids_to_global: Optional[torch.Tensor] = None + # mapping from local to global ID space (destination node IDs) partitioned_dst_node_ids_to_global: Optional[torch.Tensor] = None + # mapping from local to global ID space (edge IDs) partitioned_indices_to_global: Optional[torch.Tensor] = None - # buffers, sizes, and ID counts to support - # distributed communication primitives + + # utility lists and sizes required for exchange of messages + # between graph partitions through distributed communication primitives + # number of IDs each rank potentially sends to all other ranks - sizes: List[List[int]] = field(init=False) + sizes: Optional[List[List[int]]] = None # local indices of IDs current rank sends to all other ranks - scatter_indices: List[torch.Tensor] = field(init=False) - num_src_nodes_in_each_partition: List[int] = field(init=False) - num_dst_nodes_in_each_partition: List[int] = field(init=False) - num_indices_in_each_partition: List[int] = field(init=False) + scatter_indices: Optional[List[torch.Tensor]] = None + # number of global source nodes for each `partition_rank` + num_src_nodes_in_each_partition: Optional[List[int]] = None + # number of global destination nodes for each `partition_rank` + num_dst_nodes_in_each_partition: Optional[List[int]] = None + # number of global indices for each `partition_rank` + num_indices_in_each_partition: Optional[List[int]] = None def __post_init__(self): # after partition_size has been set in __init__ - self.sizes = [ - [None for _ in range(self.partition_size)] - for _ in range(self.partition_size) - ] - self.scatter_indices = [None] * self.partition_size - self.num_src_nodes_in_each_partition = [None] * self.partition_size - self.num_dst_nodes_in_each_partition = [None] * self.partition_size - self.num_indices_in_each_partition = [None] * self.partition_size + if self.partition_size <= 0: + raise ValueError(f"Expected partition_size > 0, got {self.partition_size}") + if not (0 <= self.partition_rank < self.partition_size): + raise ValueError( + f"Expected 0 <= partition_rank < {self.partition_size}, got {self.partiton_rank}" + ) + + if self.sizes is None: + self.sizes = [ + [None for _ in range(self.partition_size)] + for _ in range(self.partition_size) + ] + + if self.scatter_indices is None: + self.scatter_indices = [None] * self.partition_size + if self.num_src_nodes_in_each_partition is None: + self.num_src_nodes_in_each_partition = [None] * self.partition_size + if self.num_dst_nodes_in_each_partition is None: + self.num_dst_nodes_in_each_partition = [None] * self.partition_size + if self.num_indices_in_each_partition is None: + self.num_indices_in_each_partition = [None] * self.partition_size def partition_graph_with_id_mapping( @@ -80,18 +134,12 @@ def partition_graph_with_id_mapping( partition_size: int, partition_rank: int, device: torch.device, -) -> GraphPartition: # pragma: no cover +) -> GraphPartition: """ Utility function which partitions a global graph given as CSC structure. It partitions both the global ID spaces for source nodes and destination nodes based on the corresponding mappings as well as the graph structure and edge IDs. - Each rank maintains both a partition of the global source and destination nodes. - In terms of graph structure, each rank manages its own local graph structure - based on its partition of destination node IDs and all edges which - from the - point of view of each destination node on a current rank - are incoming edges. - For GNN operations this means, that features from source nodes need to be exchanged - between ranks. The partitioning scheme computes necessary indices which facilitate - later communication primitives. + For more details on partitioning in general see `GraphPartition`. The function performs the partitioning based on a global graph in CPU memory for each rank independently. It could be rewritten to e.g. only do it one rank and exchange the partitions or to an algorithm that also @@ -161,7 +209,7 @@ def partition_graph_with_id_mapping( ) # create mapping of global IDs to local IDs - # as each rank is expected to operate on disting global IDs, this is expected + # as each rank is expected to operate on distint global IDs, this is expected # to not cause any data races ids = src_nodes_in_each_partition[rank] mapping_global_src_ids_to_local_ids[ids] = torch.arange( @@ -260,14 +308,11 @@ def partition_graph_nodewise( partition_size: int, partition_rank: int, device: torch.device, -) -> GraphPartition: # pragma: no cover +) -> GraphPartition: """ Utility function which partitions a global graph given as CSC structure naively by splitting both the IDs of source and destination nodes into chunks of equal - size. Each partition rank then manages its according chunk of both source and - destination nodes. Indices are assigned to the rank such that each rank manages - all the incoming edges for all the destination nodes on the corresponding - partition rank. + size. For more details on partitioning in general see `GraphPartition`. The function performs the partitioning based on a global graph in CPU memory for each rank independently. It could be rewritten to e.g. only do it one rank and exchange the partitions or to an algorithm that also @@ -331,33 +376,27 @@ def partition_graph_nodewise( ) -def partition_graph_coordinatewise( +def partition_graph_by_coordinate_bbox( global_offsets: torch.Tensor, global_indices: torch.Tensor, src_coordinates: torch.Tensor, dst_coordinates: torch.Tensor, - coordinate_separators_min: List[List[float]], - coordinate_separators_max: List[List[float]], + coordinate_separators_min: List[List[Optional[float]]], + coordinate_separators_max: List[List[Optional[float]]], partition_size: int, partition_rank: int, device: torch.device, -) -> GraphPartition: # pragma: no cover +) -> GraphPartition: """ Utility function which partitions a global graph given as CSC structure. It partitions both the global ID spaces for source nodes and destination nodes based on their corresponding coordinates. Each partition will manage points which fulfill the boxconstraints specified by the specified coordinate separators. For each rank one is expected to specify the minimum and maximum coordinate value for each dimension. - A partition the will manage all points for which ``min_val <= coord[d] < max_val`` holds. - Specifying either of these separation values as `None` will result in this constraint not being - considered for the division. Each rank maintains both a partition of the global source and + A partition the will manage all points for which ``min_val <= coord[d] < max_val`` holds. If one + of the constraints is passed as `None`, it is assumed to be non-binding and the partition is defined + by the corresponding half-space. Each rank maintains both a partition of the global source and destination nodes resulting from this subspace division. - In terms of graph structure, each rank manages its own local graph structure - based on its partition of destination node IDs and all edges which - from the - point of view of each destination node on a current rank - are incoming edges. - For GNN operations this means, that features from source nodes need to be exchanged - between ranks. The partitioning scheme computes necessary indices which facilitate - later communication primitives. The function performs the partitioning based on a global graph in CPU memory for each rank independently. It could be rewritten to e.g. only do it one rank and exchange the partitions or to an algorithm that also @@ -371,7 +410,7 @@ def partition_graph_coordinatewise( Examples -------- >>> import torch - >>> from modulus.models.gnn_layers import partition_graph_coordinatewise + >>> from modulus.models.gnn_layers import partition_graph_by_coordinate_bbox >>> >>> # simple graph with a degree of 2 per node >>> num_src_nodes = 8 @@ -408,7 +447,7 @@ def partition_graph_coordinatewise( >>> ] >>> ) >>> # call partitioning routine - >>> pg = partition_graph_coordinatewise( + >>> pg = partition_graph_by_coordinate_bbox( >>> offsets, >>> indices, >>> src_coordinates, @@ -468,7 +507,7 @@ def partition_graph_coordinatewise( >>> partition_size = 4 >>> partition_rank = 0 >>> device = "cuda:0" - >>> pg = partition_graph_coordinatewise( + >>> pg = partition_graph_by_coordinate_bbox( >>> offsets, >>> indices, >>> src_coordinates, diff --git a/modulus/models/gnn_layers/graph.py b/modulus/models/gnn_layers/graph.py index 501d9f5012..1799f07958 100644 --- a/modulus/models/gnn_layers/graph.py +++ b/modulus/models/gnn_layers/graph.py @@ -25,7 +25,7 @@ # for Python versions < 3.11 from typing_extensions import Self -from modulus.models.gnn_layers import DistributedGraph +from modulus.models.gnn_layers import DistributedGraph, GraphPartition try: from pylibcugraphops.pytorch import BipartiteCSC, StaticCSC @@ -90,6 +90,7 @@ def __init__( cache_graph: bool = True, partition_size: Optional[int] = -1, partition_group_name: Optional[str] = None, + graph_partition: Optional[GraphPartition] = None, ) -> None: self.offsets = offsets self.indices = indices @@ -121,6 +122,7 @@ def __init__( self.indices, partition_size, partition_group_name, + graph_partition=graph_partition, ) # overwrite graph information with local graph after distribution self.offsets = self.dist_graph.graph_partition.local_offsets diff --git a/test/models/meshgraphnet/test_meshgraphnet_snmg.py b/test/models/meshgraphnet/test_meshgraphnet_snmg.py index aeb1e83111..a4bd4a17c7 100644 --- a/test/models/meshgraphnet/test_meshgraphnet_snmg.py +++ b/test/models/meshgraphnet/test_meshgraphnet_snmg.py @@ -26,9 +26,13 @@ from torch.nn.parallel import DistributedDataParallel from modulus.distributed import DistributedManager +from modulus.models.gnn_layers import ( + partition_graph_by_coordinate_bbox, + partition_graph_with_id_mapping, +) -def run_test_distributed_meshgraphnet(rank, world_size, dtype): +def run_test_distributed_meshgraphnet(rank, world_size, dtype, partition_scheme): from modulus.models.gnn_layers.utils import CuGraphCSC from modulus.models.meshgraphnet.meshgraphnet import MeshGraphNet @@ -92,6 +96,51 @@ def run_test_distributed_meshgraphnet(rank, world_size, dtype): num_nodes, num_nodes, ) + + graph_partition = None + + if partition_scheme == "nodewise": + pass # nodewise is default + + elif partition_scheme == "coordinate_bbox": + src_coordinates = torch.rand((num_nodes, 1), device=offsets.device) + dst_coordinates = src_coordinates + + step_size = 1.0 / (world_size + 1) + coordinate_separators_min = [[step_size * p] for p in range(world_size)] + coordinate_separators_max = [[step_size * (p + 1)] for p in range(world_size)] + + graph_partition = partition_graph_by_coordinate_bbox( + offsets, + indices, + src_coordinates, + dst_coordinates, + coordinate_separators_min, + coordinate_separators_max, + world_size, + manager.rank, + manager.device, + ) + + elif partition_scheme == "mapping": + mapping_src_ids_to_ranks = torch.randint( + 0, world_size, (num_nodes,), device=offsets.device + ) + mapping_dst_ids_to_ranks = mapping_src_ids_to_ranks + + graph_partition = partition_graph_with_id_mapping( + offsets, + indices, + mapping_src_ids_to_ranks, + mapping_dst_ids_to_ranks, + world_size, + manager.rank, + manager.device, + ) + + else: + assert False # only schemes above are supported + graph_multi_gpu = CuGraphCSC( offsets.to(manager.device), indices.to(manager.device), @@ -99,6 +148,7 @@ def run_test_distributed_meshgraphnet(rank, world_size, dtype): num_nodes, partition_size=world_size, partition_group_name="graph_partition", + graph_partition=graph_partition, ) nfeat_single_gpu = ( @@ -184,19 +234,20 @@ def run_test_distributed_meshgraphnet(rank, world_size, dtype): @import_or_fail("dgl") @pytest.mark.multigpu +@pytest.mark.parametrize("partition_scheme", ["nodewise", "coordinate_bbox", "mapping"]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) -def test_distributed_meshgraphnet(dtype, pytestconfig): +def test_distributed_meshgraphnet(dtype, partition_scheme, pytestconfig): num_gpus = torch.cuda.device_count() assert num_gpus >= 2, "Not enough GPUs available for test" world_size = num_gpus torch.multiprocessing.spawn( run_test_distributed_meshgraphnet, - args=(world_size, dtype), + args=(world_size, dtype, partition_scheme), nprocs=world_size, start_method="spawn", ) if __name__ == "__main__": - pytest.main([__file__]) + py \ No newline at end of file diff --git a/test/models/test_graph_partition.py b/test/models/test_graph_partition.py new file mode 100644 index 0000000000..571f6e7680 --- /dev/null +++ b/test/models/test_graph_partition.py @@ -0,0 +1,296 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import pytest +import torch +from modulus.models.gnn_layers import ( + partition_graph_by_coordinate_bbox, + partition_graph_with_id_mapping, + partition_graph_nodewise, + GraphPartition, +) + + +@pytest.fixture +def global_graph(): + # simple graph with a degree of 2 per node + num_src_nodes = 8 + num_dst_nodes = 4 + offsets = torch.arange(num_dst_nodes + 1, dtype=torch.int64) * 2 + indices = torch.arange(num_src_nodes, dtype=torch.int64) + + return (offsets, indices, num_src_nodes, num_dst_nodes) + + +def assert_partitions_are_equal(a, b): + is_equal = True + + attributes = [ + "partition_size", + "partition_rank", + "device", + "num_local_src_nodes", + "num_local_dst_nodes", + "num_local_indices", + "sizes", + "num_src_nodes_in_each_partition", + "num_dst_nodes_in_each_partition", + "num_indices_in_each_partition" + ] + torch_attributes = [ + "local_offsets", + "local_indices", + "scatter_indices", + "partitioned_src_node_ids_to_global", + "partitioned_dst_node_ids_to_global", + "partitioned_indices_to_global", + ] + + for attr in attributes: + val_a, val_b = getattr(a, attr), getattr(b, attr) + error_msg = f"{attr} does not match, got {val_a} and {val_b}" + assert val_a == val_b, error_msg + + for attr in torch_attributes: + val_a, val_b = getattr(a, attr), getattr(b, attr) + error_msg = f"{attr} does not match, got {val_a} and {val_b}" + if isinstance(val_a, list): + assert isinstance(val_b, list), error_msg + assert len(val_a) == len(val_b), error_msg + for i in range(len(val_a)): + assert torch.allclose(val_a[i], val_b[i]), error_msg + else: + assert torch.allclose(val_a, val_b), error_msg + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_gp_mapping(global_graph, device): + offsets, indices, num_src_nodes, num_dst_nodes = global_graph + partition_size = 4 + partition_rank = 0 + + mapping_src_ids_to_ranks = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3]) + mapping_dst_ids_to_ranks = torch.tensor([0, 1, 2, 3]) + + pg = partition_graph_with_id_mapping( + offsets, + indices, + mapping_src_ids_to_ranks, + mapping_dst_ids_to_ranks, + partition_size, + partition_rank, + device, + ) + + pg_expected = GraphPartition( + partition_size=4, + partition_rank=0, + device=device, + local_offsets=torch.tensor([0, 2], device=device), + local_indices=torch.tensor([0, 1], device=device), + num_local_src_nodes=2, + num_local_dst_nodes=1, + num_local_indices=2, + partitioned_src_node_ids_to_global=torch.tensor([0, 4]), + partitioned_dst_node_ids_to_global=torch.tensor([0]), + partitioned_indices_to_global=torch.tensor([0, 1]), + sizes=[[1, 0, 1, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 1, 0, 1]], + scatter_indices=[ + torch.tensor([0], device=device), + torch.tensor([1], device=device), + torch.tensor([], device=device, dtype=torch.int64), + torch.tensor([], device=device, dtype=torch.int64), + ], + num_src_nodes_in_each_partition=[2, 2, 2, 2], + num_dst_nodes_in_each_partition=[1, 1, 1, 1], + num_indices_in_each_partition=[2, 2, 2, 2], + ) + + assert_partitions_are_equal(pg, pg_expected) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_gp_nodewise(global_graph, device): + offsets, indices, num_src_nodes, num_dst_nodes = global_graph + partition_size = 4 + partition_rank = 0 + + pg = partition_graph_nodewise( + offsets, + indices, + partition_size, + partition_rank, + device, + ) + + pg_expected = GraphPartition( + partition_size=4, + partition_rank=0, + device=device, + local_offsets=torch.tensor([0, 2], device=device), + local_indices=torch.tensor([0, 1], device=device), + num_local_src_nodes=2, + num_local_dst_nodes=1, + num_local_indices=2, + partitioned_src_node_ids_to_global=torch.tensor([0, 1]), + partitioned_dst_node_ids_to_global=torch.tensor([0]), + partitioned_indices_to_global=torch.tensor([0, 1]), + sizes=[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]], + scatter_indices=[ + torch.tensor([0, 1], device=device), + torch.tensor([], device=device, dtype=torch.int64), + torch.tensor([], device=device, dtype=torch.int64), + torch.tensor([], device=device, dtype=torch.int64), + ], + num_src_nodes_in_each_partition=[2, 2, 2, 2], + num_dst_nodes_in_each_partition=[1, 1, 1, 1], + num_indices_in_each_partition=[2, 2, 2, 2], + ) + + assert_partitions_are_equal(pg, pg_expected) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_gp_coordinate_bbox(global_graph, device): + offsets, indices, num_src_nodes, num_dst_nodes = global_graph + partition_size = 4 + partition_rank = 0 + coordinate_separators_min = [[0, 0], [None, 0], [None, None], [0, None]] + coordinate_separators_max = [[None, None], [0, None], [0, 0], [None, 0]] + device = "cuda:0" + src_coordinates = torch.FloatTensor( + [ + [-1.0, 1.0], + [1.0, 1.0], + [-1.0, -1.0], + [1.0, -1.0], + [-2.0, 2.0], + [2.0, 2.0], + [-2.0, -2.0], + [2.0, -2.0], + ] + ) + dst_coordinates = torch.FloatTensor( + [ + [-1.0, 1.0], + [1.0, 1.0], + [-1.0, -1.0], + [1.0, -1.0], + ] + ) + pg = partition_graph_by_coordinate_bbox( + offsets, + indices, + src_coordinates, + dst_coordinates, + coordinate_separators_min, + coordinate_separators_max, + partition_size, + partition_rank, + device, + ) + + pg_expected = GraphPartition( + partition_size=4, + partition_rank=0, + device=device, + local_offsets=torch.tensor([0, 2], device=device), + local_indices=torch.tensor([0, 1], device=device), + num_local_src_nodes=2, + num_local_dst_nodes=1, + num_local_indices=2, + partitioned_src_node_ids_to_global=torch.tensor([1, 5]), + partitioned_dst_node_ids_to_global=torch.tensor([1]), + partitioned_indices_to_global=torch.tensor([2, 3]), + sizes=[[0, 1, 1, 0], [0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 1]], + scatter_indices=[ + torch.tensor([], device=device, dtype=torch.int64), + torch.tensor([0], device=device), + torch.tensor([1], device=device), + torch.tensor([], device=device, dtype=torch.int64), + ], + num_src_nodes_in_each_partition=[2, 2, 2, 2], + num_dst_nodes_in_each_partition=[1, 1, 1, 1], + num_indices_in_each_partition=[2, 2, 2, 2], + ) + + assert_partitions_are_equal(pg, pg_expected) + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_gp_coordinate_bbox_lat_long(global_graph, device): + offsets, indices, num_src_nodes, num_dst_nodes = global_graph + src_lat = torch.FloatTensor([-75, -60, -45, -30, 30, 45, 60, 75]).view(-1, 1) + dst_lat = torch.FloatTensor([-60, -30, 30, 30]).view(-1, 1) + src_long = torch.FloatTensor([-135, -135, 135, 135, -45, -45, 45, 45]).view(-1, 1) + dst_long = torch.FloatTensor([-135, 135, -45, 45]).view(-1, 1) + src_coordinates = torch.cat([src_lat, src_long], dim=1) + dst_coordinates = torch.cat([dst_lat, dst_long], dim=1) + coordinate_separators_min = [ + [-90, -180], + [-90, 0], + [0, -180], + [0, 0], + ] + coordinate_separators_max = [ + [0, 0], + [0, 180], + [90, 0], + [90, 180], + ] + partition_size = 4 + partition_rank = 0 + device = "cuda:0" + pg = partition_graph_by_coordinate_bbox( + offsets, + indices, + src_coordinates, + dst_coordinates, + coordinate_separators_min, + coordinate_separators_max, + partition_size, + partition_rank, + device, + ) + pg_expected = GraphPartition( + partition_size=4, + partition_rank=0, + device=device, + local_offsets=torch.tensor([0, 2], device=device), + local_indices=torch.tensor([0, 1], device=device), + num_local_src_nodes=2, + num_local_dst_nodes=1, + num_local_indices=2, + partitioned_src_node_ids_to_global=torch.tensor([0, 1]), + partitioned_dst_node_ids_to_global=torch.tensor([0]), + partitioned_indices_to_global=torch.tensor([0, 1]), + sizes=[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]], + scatter_indices=[ + torch.tensor([0, 1], device=device), + torch.tensor([], device=device, dtype=torch.int64), + torch.tensor([], device=device, dtype=torch.int64), + torch.tensor([], device=device, dtype=torch.int64), + ], + num_src_nodes_in_each_partition=[2, 2, 2, 2], + num_dst_nodes_in_each_partition=[1, 1, 1, 1], + num_indices_in_each_partition=[2, 2, 2, 2], + ) + + assert_partitions_are_equal(pg, pg_expected) + + +if __name__ == "__main__": + pytest.main([__file__]) From 78d18cbdc006439d2b77a24b397b67893bb4870b Mon Sep 17 00:00:00 2001 From: Maximilian Stadler Date: Fri, 8 Dec 2023 08:21:44 -0800 Subject: [PATCH 4/6] format --- .../meshgraphnet/test_meshgraphnet_snmg.py | 2 +- test/models/test_graph_partition.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/test/models/meshgraphnet/test_meshgraphnet_snmg.py b/test/models/meshgraphnet/test_meshgraphnet_snmg.py index a4bd4a17c7..aa55602ebe 100644 --- a/test/models/meshgraphnet/test_meshgraphnet_snmg.py +++ b/test/models/meshgraphnet/test_meshgraphnet_snmg.py @@ -250,4 +250,4 @@ def test_distributed_meshgraphnet(dtype, partition_scheme, pytestconfig): if __name__ == "__main__": - py \ No newline at end of file + pytest.main([__file__]) diff --git a/test/models/test_graph_partition.py b/test/models/test_graph_partition.py index 571f6e7680..3e12b473c5 100644 --- a/test/models/test_graph_partition.py +++ b/test/models/test_graph_partition.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import random - import pytest import torch + from modulus.models.gnn_layers import ( + GraphPartition, partition_graph_by_coordinate_bbox, - partition_graph_with_id_mapping, partition_graph_nodewise, - GraphPartition, + partition_graph_with_id_mapping, ) @@ -31,13 +30,11 @@ def global_graph(): num_dst_nodes = 4 offsets = torch.arange(num_dst_nodes + 1, dtype=torch.int64) * 2 indices = torch.arange(num_src_nodes, dtype=torch.int64) - + return (offsets, indices, num_src_nodes, num_dst_nodes) def assert_partitions_are_equal(a, b): - is_equal = True - attributes = [ "partition_size", "partition_rank", @@ -48,7 +45,7 @@ def assert_partitions_are_equal(a, b): "sizes", "num_src_nodes_in_each_partition", "num_dst_nodes_in_each_partition", - "num_indices_in_each_partition" + "num_indices_in_each_partition", ] torch_attributes = [ "local_offsets", @@ -71,9 +68,9 @@ def assert_partitions_are_equal(a, b): assert isinstance(val_b, list), error_msg assert len(val_a) == len(val_b), error_msg for i in range(len(val_a)): - assert torch.allclose(val_a[i], val_b[i]), error_msg + assert torch.allclose(val_a[i], val_b[i]), error_msg else: - assert torch.allclose(val_a, val_b), error_msg + assert torch.allclose(val_a, val_b), error_msg @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) @@ -151,7 +148,7 @@ def test_gp_nodewise(global_graph, device): sizes=[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]], scatter_indices=[ torch.tensor([0, 1], device=device), - torch.tensor([], device=device, dtype=torch.int64), + torch.tensor([], device=device, dtype=torch.int64), torch.tensor([], device=device, dtype=torch.int64), torch.tensor([], device=device, dtype=torch.int64), ], From c9ac8571ea78576e064d194a914fdc4bcc1380c2 Mon Sep 17 00:00:00 2001 From: Maximilian Stadler Date: Thu, 14 Dec 2023 14:51:05 +0000 Subject: [PATCH 5/6] drop assert num_gpus == 2 in multi-gpu tests --- test/distributed/test_config.py | 2 +- test/distributed/test_distributed_fft.py | 2 +- test/distributed/test_manager.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/distributed/test_config.py b/test/distributed/test_config.py index f21bb9ae8d..3b15920d71 100644 --- a/test/distributed/test_config.py +++ b/test/distributed/test_config.py @@ -135,7 +135,7 @@ def run_distributed_model_config(rank, model_parallel_size, verbose): @pytest.mark.multigpu def test_distributed_model_config(): num_gpus = torch.cuda.device_count() - assert num_gpus == 2, "Not enough GPUs available for test" + assert num_gpus >= 2, "Not enough GPUs available for test" model_parallel_size = 2 verbose = False # Change to True for debug diff --git a/test/distributed/test_distributed_fft.py b/test/distributed/test_distributed_fft.py index 6c167ad0ca..5a9c4030e4 100644 --- a/test/distributed/test_distributed_fft.py +++ b/test/distributed/test_distributed_fft.py @@ -189,7 +189,7 @@ def run_distributed_fft(rank, model_parallel_size, verbose): @pytest.mark.multigpu def test_distributed_fft(): num_gpus = torch.cuda.device_count() - assert num_gpus == 2, "Not enough GPUs available for test" + assert num_gpus >= 2, "Not enough GPUs available for test" model_parallel_size = 2 verbose = False # Change to True for debug diff --git a/test/distributed/test_manager.py b/test/distributed/test_manager.py index 8ac7797f1e..8d3ecad092 100644 --- a/test/distributed/test_manager.py +++ b/test/distributed/test_manager.py @@ -268,7 +268,7 @@ def run_process_groups_from_config(rank, model_parallel_size, verbose): @pytest.mark.multigpu def test_process_groups_from_config(): num_gpus = torch.cuda.device_count() - assert num_gpus == 2, "Not enough GPUs available for test" + assert num_gpus >= 2, "Not enough GPUs available for test" model_parallel_size = 2 verbose = False # Change to True for debug From f5da383f65cbe83b242cd5a2ee5902aeaa8b39ce Mon Sep 17 00:00:00 2001 From: Maximilian Stadler Date: Fri, 15 Dec 2023 03:35:18 -0800 Subject: [PATCH 6/6] fix doc example --- .../models/gnn_layers/distributed_graph.py | 156 +++++++----------- 1 file changed, 60 insertions(+), 96 deletions(-) diff --git a/modulus/models/gnn_layers/distributed_graph.py b/modulus/models/gnn_layers/distributed_graph.py index 7edeaf0fbf..2165b4fda5 100644 --- a/modulus/models/gnn_layers/distributed_graph.py +++ b/modulus/models/gnn_layers/distributed_graph.py @@ -411,13 +411,11 @@ def partition_graph_by_coordinate_bbox( -------- >>> import torch >>> from modulus.models.gnn_layers import partition_graph_by_coordinate_bbox - >>> >>> # simple graph with a degree of 2 per node >>> num_src_nodes = 8 >>> num_dst_nodes = 4 >>> offsets = torch.arange(num_dst_nodes + 1, dtype=torch.int64) * 2 >>> indices = torch.arange(num_src_nodes, dtype=torch.int64) - >>> >>> # example with 2D coordinates >>> # assuming partitioning a 2D problem into the 4 quadrants >>> partition_size = 4 @@ -427,60 +425,43 @@ def partition_graph_by_coordinate_bbox( >>> device = "cuda:0" >>> # dummy coordinates >>> src_coordinates = torch.FloatTensor( - >>> [ - >>> [-1.0, 1.0], - >>> [1.0, 1.0], - >>> [-1.0, -1.0], - >>> [1.0, -1.0], - >>> [-2.0, 2.0], - >>> [2.0, 2.0], - >>> [-2.0, -2.0], - >>> [2.0, -2.0], - >>> ] - >>> ) + ... [ + ... [-1.0, 1.0], + ... [1.0, 1.0], + ... [-1.0, -1.0], + ... [1.0, -1.0], + ... [-2.0, 2.0], + ... [2.0, 2.0], + ... [-2.0, -2.0], + ... [2.0, -2.0], + ... ] + ... ) >>> dst_coordinates = torch.FloatTensor( - >>> [ - >>> [-1.0, 1.0], - >>> [1.0, 1.0], - >>> [-1.0, -1.0], - >>> [1.0, -1.0], - >>> ] - >>> ) + ... [ + ... [-1.0, 1.0], + ... [1.0, 1.0], + ... [-1.0, -1.0], + ... [1.0, -1.0], + ... ] + ... ) >>> # call partitioning routine >>> pg = partition_graph_by_coordinate_bbox( - >>> offsets, - >>> indices, - >>> src_coordinates, - >>> dst_coordinates, - >>> coordinate_separators_min, - >>> coordinate_separators_max, - >>> partition_size, - >>> partition_rank, - >>> device, - >>> ) - GraphPartition( - partition_size=4, - partition_rank=0, - device="cuda:0", - local_offsets=tensor([0, 2], device="cuda:0"), - local_indices=tensor([0, 1], device="cuda:0"), - num_local_src_nodes=2, - num_local_dst_nodes=1, - num_local_indices=2, - partitioned_src_node_ids_to_global=tensor([1, 5]), - partitioned_dst_node_ids_to_global=tensor([1]), - partitioned_indices_to_global=tensor([2, 3]), - sizes=[[0, 1, 1, 0], [0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 1]], - scatter_indices=[ - tensor([], device="cuda:0", dtype=torch.int64), - tensor([0], device="cuda:0"), - tensor([1], device="cuda:0"), - tensor([], device="cuda:0", dtype=torch.int64), - ], - num_src_nodes_in_each_partition=[2, 2, 2, 2], - num_dst_nodes_in_each_partition=[1, 1, 1, 1], - num_indices_in_each_partition=[2, 2, 2, 2], - ) + ... offsets, + ... indices, + ... src_coordinates, + ... dst_coordinates, + ... coordinate_separators_min, + ... coordinate_separators_max, + ... partition_size, + ... partition_rank, + ... device, + ... ) + >>> pg.local_offsets + tensor([0, 2], device='cuda:0') + >>> pg.local_indices + tensor([0, 1], device='cuda:0') + >>> pg.sizes + [[0, 1, 1, 0], [0, 1, 1, 0], [1, 0, 0, 1], [1, 0, 0, 1]] >>> >>> # example with lat-long coordinates >>> # dummy coordinates @@ -492,55 +473,38 @@ def partition_graph_by_coordinate_bbox( >>> dst_coordinates = torch.cat([dst_lat, dst_long], dim=1) >>> # separate sphere at equator and 0 degree longitude into 4 parts >>> coordinate_separators_min = [ - >>> [-90, -180], - >>> [-90, 0], - >>> [0, -180], - >>> [0, 0], - >>> ] + ... [-90, -180], + ... [-90, 0], + ... [0, -180], + ... [0, 0], + ... ] >>> coordinate_separators_max = [ - >>> [0, 0], - >>> [0, 180], - >>> [90, 0], - >>> [90, 180], - >>> ] + ... [0, 0], + ... [0, 180], + ... [90, 0], + ... [90, 180], + ... ] >>> # call partitioning routine >>> partition_size = 4 >>> partition_rank = 0 >>> device = "cuda:0" >>> pg = partition_graph_by_coordinate_bbox( - >>> offsets, - >>> indices, - >>> src_coordinates, - >>> dst_coordinates, - >>> coordinate_separators_min, - >>> coordinate_separators_max, - >>> partition_size, - >>> partition_rank, - >>> device, - >>> ) - GraphPartition( - partition_size=4, - partition_rank=0, - device="cuda:0", - local_offsets=tensor([0, 2], device="cuda:0"), - local_indices=tensor([0, 1], device="cuda:0"), - num_local_src_nodes=2, - num_local_dst_nodes=1, - num_local_indices=2, - partitioned_src_node_ids_to_global=tensor([0, 1]), - partitioned_dst_node_ids_to_global=tensor([0]), - partitioned_indices_to_global=tensor([0, 1]), - sizes=[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]], - scatter_indices=[ - tensor([0, 1], device="cuda:0"), - tensor([], device="cuda:0", dtype=torch.int64), - tensor([], device="cuda:0", dtype=torch.int64), - tensor([], device="cuda:0", dtype=torch.int64), - ], - num_src_nodes_in_each_partition=[2, 2, 2, 2], - num_dst_nodes_in_each_partition=[1, 1, 1, 1], - num_indices_in_each_partition=[2, 2, 2, 2], - ) + ... offsets, + ... indices, + ... src_coordinates, + ... dst_coordinates, + ... coordinate_separators_min, + ... coordinate_separators_max, + ... partition_size, + ... partition_rank, + ... device, + ... ) + >>> pg.local_offsets + tensor([0, 2], device='cuda:0') + >>> pg.local_indices + tensor([0, 1], device='cuda:0') + >>> pg.sizes + [[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]] Parameters ----------