From db41e44e6189314418d9d2e5b21a25efecf74e54 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 5 Dec 2023 06:19:00 -0800 Subject: [PATCH 01/11] adding more flexibility to shapes Signed-off-by: Thorsten Kurth --- modulus/distributed/mappings.py | 56 ++++------------------ modulus/distributed/utils.py | 58 ++++++++++++++++------- modulus/models/afno/distributed/afno.py | 4 +- modulus/models/afno/distributed/layers.py | 11 +++-- 4 files changed, 62 insertions(+), 67 deletions(-) diff --git a/modulus/distributed/mappings.py b/modulus/distributed/mappings.py index ba99887cf7..f1b3f03d0c 100644 --- a/modulus/distributed/mappings.py +++ b/modulus/distributed/mappings.py @@ -15,7 +15,7 @@ import torch from modulus.distributed.manager import DistributedManager -from modulus.distributed.utils import _gather, _reduce, _split +from modulus.distributed.utils import compute_split_shapes, _gather, _reduce, _split class _CopyToParallelRegion(torch.autograd.Function): @@ -62,12 +62,13 @@ def symbolic(graph, input_, dim_, group_): # pragma: no cover def forward(ctx, input_, dim_, group_): # pragma: no cover ctx.dim = dim_ ctx.group = group_ + ctx.split_shapes = compute_split_shapes(input_.shape[dim_], DistributedManager().group_size(group_)) return _split(input_, dim_, group=DistributedManager().group(group_)) @staticmethod def backward(ctx, grad_output): # pragma: no cover return ( - _gather(grad_output, ctx.dim, group=DistributedManager().group(ctx.group_)), + _gather(grad_output, ctx.dim, shapes_=ctx.split_shapes, group=DistributedManager().group(ctx.group_)), None, None, ) @@ -77,14 +78,14 @@ class _GatherFromParallelRegion(torch.autograd.Function): """Gather the input from parallel region and concatenate.""" @staticmethod - def symbolic(graph, input_, dim_, group_): # pragma: no cover - return _gather(input_, dim_, group=DistributedManager().group(group_)) + def symbolic(graph, input_, dim_, group_, shapes_): # pragma: no cover + return _gather(input_, dim_, shapes_, group=DistributedManager().group(group_)) @staticmethod - def forward(ctx, input_, dim_, group_): # pragma: no cover + def forward(ctx, input_, dim_, shapes_, group_): # pragma: no cover ctx.dim = dim_ ctx.group = group_ - return _gather(input_, dim_, group=DistributedManager().group(group_)) + return _gather(input_, dim_, shapes_=shapes_, group=DistributedManager().group(group_)) @staticmethod def backward(ctx, grad_output): # pragma: no cover @@ -94,35 +95,6 @@ def backward(ctx, grad_output): # pragma: no cover None, ) - -class _GatherWithinParallelRegion(torch.autograd.Function): - """ - Gather the input within parallel region and concatenate. - The same forward method as _GatherFromParallelRegion, the difference is only in the - backward pass. This method performs a reduction of the gradients before the split in - the backward pass while the other version only performs a split - """ - - @staticmethod - def symbolic(graph, input_, dim_, group_): # pragma: no cover - return _gather(input_, dim_, group=DistributedManager().group(group_)) - - @staticmethod - def forward(ctx, input_, dim_, group_): # pragma: no cover - ctx.dim = dim_ - ctx.group = group_ - return _gather(input_, dim_, group=DistributedManager().group(group_)) - - @staticmethod - def backward(ctx, grad_output): # pragma: no cover - red = _reduce(grad_output, group=DistributedManager().group(ctx.group_)) - return ( - _split(red, ctx.dim, group=DistributedManager().group(ctx.group)), - None, - None, - ) - - # ----------------- # Helper functions. 
# ----------------- @@ -141,16 +113,6 @@ def scatter_to_parallel_region(input, dim, group): # pragma: no cover return _ScatterToParallelRegion.apply(input, dim, group) -def gather_from_parallel_region(input, dim, group): # pragma: no cover +def gather_from_parallel_region(input, dim, shapes, group): # pragma: no cover """Gather the input from matmul parallel region and concatenate.""" - return _GatherFromParallelRegion.apply(input, dim, group) - - -def gather_within_parallel_region(input, dim, group): # pragma: no cover - """ - Gather the input within parallel region and concatenate. - The same forward method as gather_from_parallel_region, the difference is only in - the backward pass. This method performs a reduction of the gradients before the - split in the backward pass while the other version only performs a split - """ - return _GatherWithinParallelRegion.apply(input, dim, group) + return _GatherFromParallelRegion.apply(input, dim, shapes, group) diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index dc6e6389fc..72ccf25432 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -22,6 +22,26 @@ from .manager import DistributedManager +def compute_split_shapes(size: int, num_chunks: int) -> List[int]: + + # treat trivial case first + if num_chunks == 1: + return [size] + + # first, check if we can split using div-up to balance the load: + chunk_size = (size + num_chunks - 1) // num_chunks + last_chunk_size = max(0, size - chunk_size * (num_chunks - 1)) + if last_chunk_size == 0: + # in this case, the last shard would be empty, split with floor instead: + chunk_size = size // num_chunks + last_chunk_size = size - chunk_size * (num_chunks-1) + + # generate sections list + sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size] + + return sections + + def get_memory_format(tensor): """Gets format for tensor""" if tensor.is_contiguous(memory_format=torch.channels_last): @@ -71,20 +91,14 @@ def truncate_helper(tensor, dim, new_size): def split_tensor_along_dim(tensor, dim, num_chunks): - """splits tensor along specific dim""" - if not (dim < tensor.dim()): - raise AssertionError( - f"Error, tensor dimension is {tensor.dim()} which cannot be" - f"split along {dim}" - ) - if not (tensor.shape[dim] % num_chunks == 0): - raise AssertionError( - f"Error, cannot split dim {dim} evenly. Dim size is \ - {tensor.shape[dim]} and requested numnber of splits is {num_chunks}" - ) - chunk_size = tensor.shape[dim] // num_chunks - tensor_list = torch.split(tensor, chunk_size, dim=dim) - + assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}" + assert (tensor.shape[dim] > num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \ + {num_chunks} chunks. Empty slices are currently not supported." + + # get split + sections = compute_split_shapes(tensor.shape[dim], num_chunks) + tensor_list = torch.split(tensor, sections, dim=dim) + return tensor_list @@ -198,7 +212,7 @@ def _split(input_, dim_, group=None): # pragma: no cover return output -def _gather(input_, dim_, group=None): # pragma: no cover +def _gather(input_, dim_, shapes_=None, group=None): # pragma: no cover """Gather tensors and concatenate along the specified dimension.""" # get input format input_format = get_memory_format(input_) @@ -218,7 +232,19 @@ def _gather(input_, dim_, group=None): # pragma: no cover # Size and dimension. 
comm_rank = dist.get_rank(group=group) - tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] + # make input contiguous + input_ = input_.contiguous(memory_format=input_format) + + if shapes_ is not None: + shape = list(input_.shape) + gather_shapes = [] + for i in range(comm_size): + shape[dim] = shapes_[i] + gather_shapes.append(shape) + tensor_list = [torch.empty(shape, device=input_.device, dtype=input_.dtype) for shape in gather_shapes] + else: + tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] + tensor_list[comm_rank] = input_ dist.all_gather(tensor_list, input_, group=group) diff --git a/modulus/models/afno/distributed/afno.py b/modulus/models/afno/distributed/afno.py index 456733c70b..aeea9987cf 100644 --- a/modulus/models/afno/distributed/afno.py +++ b/modulus/models/afno/distributed/afno.py @@ -38,6 +38,7 @@ DropPath, trunc_normal_, ) +from modulus.distributed.utils import compute_split_shapes logger = logging.getLogger(__name__) @@ -98,6 +99,7 @@ def __init__( def forward(self, x): if not self.input_is_matmul_parallel: + scatter_shapes = compute_split_shapes(x.shape[1], DistributedManager().group_size("model_parallel")) x = scatter_to_parallel_region(x, dim=1, group="model_parallel") residual = x @@ -114,7 +116,7 @@ def forward(self, x): x = x + residual if not self.output_is_matmul_parallel: - x = gather_from_parallel_region(x, dim=1, group="model_parallel") + x = gather_from_parallel_region(x, dim=1, shapes=scatter_shapes, group="model_parallel") return x diff --git a/modulus/models/afno/distributed/layers.py b/modulus/models/afno/distributed/layers.py index d48a12a18e..152faeb670 100644 --- a/modulus/models/afno/distributed/layers.py +++ b/modulus/models/afno/distributed/layers.py @@ -26,6 +26,7 @@ reduce_from_parallel_region, scatter_to_parallel_region, ) +from modulus.distributed.utils import compute_split_shapes def _no_grad_trunc_normal_(tensor, mean, std, a, b): @@ -175,7 +176,8 @@ def _init_weights(self): def forward(self, x): # gather if input is MP if self.input_is_matmul_parallel: - x = gather_from_parallel_region(x, dim=1, group="model_parallel") + shapes = compute_split_shapes(in_features, DistributedManager().group_size("model_parallel")) + x = gather_from_parallel_region(x, dim=1, shapes=shapes, group="model_parallel") x = copy_to_parallel_region(x, group="model_parallel") x = F.conv2d(x, self.w1, bias=self.b1) @@ -223,6 +225,7 @@ def __init__( raise ValueError( "Error, the in_chans needs to be divisible by matmul_parallel_size" ) + self.in_shapes = compute_split_shapes(in_chans, DistributedManager().group_size("model_parallel")) # get effective embedding size: if self.output_parallel: @@ -245,7 +248,7 @@ def __init__( def forward(self, x): if self.input_parallel: - x = gather_from_parallel_region(x, dim=1, group="model_parallel") + x = gather_from_parallel_region(x, dim=1, shapes=self.in_shapes, group="model_parallel") if self.output_parallel: x = copy_to_parallel_region(x, group="model_parallel") @@ -373,6 +376,7 @@ def __init__( def forward(self, x): if not self.input_is_matmul_parallel: # distribute data + num_chans = x.shape[1] x = scatter_to_parallel_region(x, dim=1, group="model_parallel") # bias @@ -418,6 +422,7 @@ def forward(self, x): # gather if not self.output_is_matmul_parallel: - x = gather_from_parallel_region(x, dim=1, group="model_parallel") + gather_shapes = compute_split_sizes(num_chans, DistributedManager().group_size("model_parallel")) + x = gather_from_parallel_region(x, dim=1, shapes=gather_shapes, 
group="model_parallel") return x From e68fe871393b77222327fa5ecf43159a59c0b4ec Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 6 Dec 2023 01:12:59 -0800 Subject: [PATCH 02/11] fixing bug in _gather Signed-off-by: Thorsten Kurth --- modulus/distributed/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index 72ccf25432..4e29e04a27 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -239,9 +239,9 @@ def _gather(input_, dim_, shapes_=None, group=None): # pragma: no cover shape = list(input_.shape) gather_shapes = [] for i in range(comm_size): - shape[dim] = shapes_[i] + shape[dim_] = shapes_[i] gather_shapes.append(shape) - tensor_list = [torch.empty(shape, device=input_.device, dtype=input_.dtype) for shape in gather_shapes] + tensor_list = [torch.empty(shape, device=input_.device, dtype=input_.dtype) for shape in gather_shapes] else: tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] From cb7723c5ae069bdb1c72ec0764d01c88c3beb86e Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 6 Dec 2023 01:42:40 -0800 Subject: [PATCH 03/11] adding missing None gradfient for input shape list to gather Signed-off-by: Thorsten Kurth --- modulus/distributed/mappings.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modulus/distributed/mappings.py b/modulus/distributed/mappings.py index f1b3f03d0c..f3a0c72f68 100644 --- a/modulus/distributed/mappings.py +++ b/modulus/distributed/mappings.py @@ -93,6 +93,7 @@ def backward(ctx, grad_output): # pragma: no cover _split(grad_output, ctx.dim, group=DistributedManager().group(ctx.group)), None, None, + None, ) # ----------------- From 409fd3e93086b846d82bc0dc68fc819b2a72032c Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 6 Dec 2023 05:01:53 -0800 Subject: [PATCH 04/11] debugging utils.py Signed-off-by: Thorsten Kurth --- modulus/distributed/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index 4e29e04a27..0ffa68e768 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -242,6 +242,8 @@ def _gather(input_, dim_, shapes_=None, group=None): # pragma: no cover shape[dim_] = shapes_[i] gather_shapes.append(shape) tensor_list = [torch.empty(shape, device=input_.device, dtype=input_.dtype) for shape in gather_shapes] + + print("rank {comm_rank}", [x.shape for x in tensor_list], input_.shape) else: tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] From bd4c50b1ec893b4e105d25d6901b2a35aabde975 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 6 Dec 2023 05:54:14 -0800 Subject: [PATCH 05/11] fixing small typo in split backward Signed-off-by: Thorsten Kurth --- modulus/distributed/mappings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modulus/distributed/mappings.py b/modulus/distributed/mappings.py index f3a0c72f68..3cd36062b2 100644 --- a/modulus/distributed/mappings.py +++ b/modulus/distributed/mappings.py @@ -68,7 +68,7 @@ def forward(ctx, input_, dim_, group_): # pragma: no cover @staticmethod def backward(ctx, grad_output): # pragma: no cover return ( - _gather(grad_output, ctx.dim, shapes_=ctx.split_shapes, group=DistributedManager().group(ctx.group_)), + _gather(grad_output, ctx.dim, shapes_=ctx.split_shapes, group=DistributedManager().group(ctx.group)), None, None, ) From b503f5451aff30f5fa2cc214de1389ddb8578dbe Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 
6 Dec 2023 05:59:20 -0800 Subject: [PATCH 06/11] debugging gather problem Signed-off-by: Thorsten Kurth --- modulus/distributed/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index 0ffa68e768..2c1edd1600 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -241,9 +241,9 @@ def _gather(input_, dim_, shapes_=None, group=None): # pragma: no cover for i in range(comm_size): shape[dim_] = shapes_[i] gather_shapes.append(shape) - tensor_list = [torch.empty(shape, device=input_.device, dtype=input_.dtype) for shape in gather_shapes] + tensor_list = [torch.empty(s, device=input_.device, dtype=input_.dtype) for s in gather_shapes] - print("rank {comm_rank}", [x.shape for x in tensor_list], input_.shape) + print(f"rank {comm_rank}", [x.shape for x in tensor_list], input_.shape, shapes_) else: tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] From db18d84dca1a266f3a1a3b06c8333acb6408acdd Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 6 Dec 2023 06:57:51 -0800 Subject: [PATCH 07/11] fixing some gather bugs Signed-off-by: Thorsten Kurth --- modulus/distributed/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index 2c1edd1600..eaf9d77fe9 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -237,11 +237,10 @@ def _gather(input_, dim_, shapes_=None, group=None): # pragma: no cover if shapes_ is not None: shape = list(input_.shape) - gather_shapes = [] + tensor_list = [None for _ in range(comm_size)] for i in range(comm_size): shape[dim_] = shapes_[i] - gather_shapes.append(shape) - tensor_list = [torch.empty(s, device=input_.device, dtype=input_.dtype) for s in gather_shapes] + tensor_list[i] = torch.empty(shape, device=input_.device, dtype=input_.dtype) print(f"rank {comm_rank}", [x.shape for x in tensor_list], input_.shape, shapes_) else: From 9f6384b7170423683a84a862e978f27a61ffcef5 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 6 Dec 2023 07:52:22 -0800 Subject: [PATCH 08/11] fixing assert error in split_tensor_along_dim Signed-off-by: Thorsten Kurth --- modulus/distributed/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index eaf9d77fe9..a6aea636fd 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -92,7 +92,7 @@ def truncate_helper(tensor, dim, new_size): def split_tensor_along_dim(tensor, dim, num_chunks): assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}" - assert (tensor.shape[dim] > num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \ + assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \ {num_chunks} chunks. Empty slices are currently not supported." 
# get split From c1e1859fb7fc7f0aee428749aa9a00319817b262 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Wed, 6 Dec 2023 07:56:13 -0800 Subject: [PATCH 09/11] removing debug printing Signed-off-by: Thorsten Kurth --- modulus/distributed/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index a6aea636fd..2b7bbb17e1 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -241,8 +241,6 @@ def _gather(input_, dim_, shapes_=None, group=None): # pragma: no cover for i in range(comm_size): shape[dim_] = shapes_[i] tensor_list[i] = torch.empty(shape, device=input_.device, dtype=input_.dtype) - - print(f"rank {comm_rank}", [x.shape for x in tensor_list], input_.shape, shapes_) else: tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] From f4fbd15bc18b364155b64c20d8a47ea060cf2162 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Thu, 7 Dec 2023 02:36:43 -0800 Subject: [PATCH 10/11] passing-pre-commit Signed-off-by: Thorsten Kurth --- modulus/distributed/mappings.py | 18 ++++++++++--- modulus/distributed/utils.py | 32 ++++++++++++++--------- modulus/models/afno/distributed/afno.py | 10 ++++--- modulus/models/afno/distributed/layers.py | 26 +++++++++++++----- 4 files changed, 61 insertions(+), 25 deletions(-) diff --git a/modulus/distributed/mappings.py b/modulus/distributed/mappings.py index 3cd36062b2..9e0b073d42 100644 --- a/modulus/distributed/mappings.py +++ b/modulus/distributed/mappings.py @@ -15,7 +15,7 @@ import torch from modulus.distributed.manager import DistributedManager -from modulus.distributed.utils import compute_split_shapes, _gather, _reduce, _split +from modulus.distributed.utils import _gather, _reduce, _split, compute_split_shapes class _CopyToParallelRegion(torch.autograd.Function): @@ -62,13 +62,20 @@ def symbolic(graph, input_, dim_, group_): # pragma: no cover def forward(ctx, input_, dim_, group_): # pragma: no cover ctx.dim = dim_ ctx.group = group_ - ctx.split_shapes = compute_split_shapes(input_.shape[dim_], DistributedManager().group_size(group_)) + ctx.split_shapes = compute_split_shapes( + input_.shape[dim_], DistributedManager().group_size(group_) + ) return _split(input_, dim_, group=DistributedManager().group(group_)) @staticmethod def backward(ctx, grad_output): # pragma: no cover return ( - _gather(grad_output, ctx.dim, shapes_=ctx.split_shapes, group=DistributedManager().group(ctx.group)), + _gather( + grad_output, + ctx.dim, + shapes_=ctx.split_shapes, + group=DistributedManager().group(ctx.group), + ), None, None, ) @@ -85,7 +92,9 @@ def symbolic(graph, input_, dim_, group_, shapes_): # pragma: no cover def forward(ctx, input_, dim_, shapes_, group_): # pragma: no cover ctx.dim = dim_ ctx.group = group_ - return _gather(input_, dim_, shapes_=shapes_, group=DistributedManager().group(group_)) + return _gather( + input_, dim_, shapes_=shapes_, group=DistributedManager().group(group_) + ) @staticmethod def backward(ctx, grad_output): # pragma: no cover @@ -96,6 +105,7 @@ def backward(ctx, grad_output): # pragma: no cover None, ) + # ----------------- # Helper functions. 
# ----------------- diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index 2b7bbb17e1..8c6bbfcd7f 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -23,18 +23,18 @@ def compute_split_shapes(size: int, num_chunks: int) -> List[int]: - + # treat trivial case first if num_chunks == 1: return [size] - - # first, check if we can split using div-up to balance the load: + + # first, check if we can split using div-up to balance the load: chunk_size = (size + num_chunks - 1) // num_chunks last_chunk_size = max(0, size - chunk_size * (num_chunks - 1)) if last_chunk_size == 0: # in this case, the last shard would be empty, split with floor instead: chunk_size = size // num_chunks - last_chunk_size = size - chunk_size * (num_chunks-1) + last_chunk_size = size - chunk_size * (num_chunks - 1) # generate sections list sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size] @@ -91,14 +91,20 @@ def truncate_helper(tensor, dim, new_size): def split_tensor_along_dim(tensor, dim, num_chunks): - assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}" - assert (tensor.shape[dim] >= num_chunks), f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \ - {num_chunks} chunks. Empty slices are currently not supported." - + if dim >= tensor.dim(): + raise ValueError( + f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}" + ) + if tensor.shape[dim] < num_chunks: + raise ValueError( + "Error, cannot split dim {dim} of size {tensor.shape[dim]} into \ + {num_chunks} chunks. Empty slices are currently not supported." + ) + # get split sections = compute_split_shapes(tensor.shape[dim], num_chunks) tensor_list = torch.split(tensor, sections, dim=dim) - + return tensor_list @@ -234,16 +240,18 @@ def _gather(input_, dim_, shapes_=None, group=None): # pragma: no cover # make input contiguous input_ = input_.contiguous(memory_format=input_format) - + if shapes_ is not None: shape = list(input_.shape) tensor_list = [None for _ in range(comm_size)] for i in range(comm_size): shape[dim_] = shapes_[i] - tensor_list[i] = torch.empty(shape, device=input_.device, dtype=input_.dtype) + tensor_list[i] = torch.empty( + shape, device=input_.device, dtype=input_.dtype + ) else: tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] - + tensor_list[comm_rank] = input_ dist.all_gather(tensor_list, input_, group=group) diff --git a/modulus/models/afno/distributed/afno.py b/modulus/models/afno/distributed/afno.py index aeea9987cf..f44d332817 100644 --- a/modulus/models/afno/distributed/afno.py +++ b/modulus/models/afno/distributed/afno.py @@ -31,6 +31,7 @@ gather_from_parallel_region, scatter_to_parallel_region, ) +from modulus.distributed.utils import compute_split_shapes from modulus.models.afno.distributed.layers import ( DistributedAFNO2D, DistributedMLP, @@ -38,7 +39,6 @@ DropPath, trunc_normal_, ) -from modulus.distributed.utils import compute_split_shapes logger = logging.getLogger(__name__) @@ -99,7 +99,9 @@ def __init__( def forward(self, x): if not self.input_is_matmul_parallel: - scatter_shapes = compute_split_shapes(x.shape[1], DistributedManager().group_size("model_parallel")) + scatter_shapes = compute_split_shapes( + x.shape[1], DistributedManager().group_size("model_parallel") + ) x = scatter_to_parallel_region(x, dim=1, group="model_parallel") residual = x @@ -116,7 +118,9 @@ def forward(self, x): x = x + residual if not self.output_is_matmul_parallel: - x 
= gather_from_parallel_region(x, dim=1, shapes=scatter_shapes, group="model_parallel") + x = gather_from_parallel_region( + x, dim=1, shapes=scatter_shapes, group="model_parallel" + ) return x diff --git a/modulus/models/afno/distributed/layers.py b/modulus/models/afno/distributed/layers.py index 152faeb670..e0ea587b7e 100644 --- a/modulus/models/afno/distributed/layers.py +++ b/modulus/models/afno/distributed/layers.py @@ -164,6 +164,11 @@ def __init__( self.act = act_layer() self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity() + if self.input_is_matmul_parallel: + self.gather_shapes = compute_split_shapes( + in_features, DistributedManager().group_size("model_parallel") + ) + # init weights self._init_weights() @@ -176,8 +181,9 @@ def _init_weights(self): def forward(self, x): # gather if input is MP if self.input_is_matmul_parallel: - shapes = compute_split_shapes(in_features, DistributedManager().group_size("model_parallel")) - x = gather_from_parallel_region(x, dim=1, shapes=shapes, group="model_parallel") + x = gather_from_parallel_region( + x, dim=1, shapes=self.gather_shapes, group="model_parallel" + ) x = copy_to_parallel_region(x, group="model_parallel") x = F.conv2d(x, self.w1, bias=self.b1) @@ -225,7 +231,9 @@ def __init__( raise ValueError( "Error, the in_chans needs to be divisible by matmul_parallel_size" ) - self.in_shapes = compute_split_shapes(in_chans, DistributedManager().group_size("model_parallel")) + self.in_shapes = compute_split_shapes( + in_chans, DistributedManager().group_size("model_parallel") + ) # get effective embedding size: if self.output_parallel: @@ -248,7 +256,9 @@ def __init__( def forward(self, x): if self.input_parallel: - x = gather_from_parallel_region(x, dim=1, shapes=self.in_shapes, group="model_parallel") + x = gather_from_parallel_region( + x, dim=1, shapes=self.in_shapes, group="model_parallel" + ) if self.output_parallel: x = copy_to_parallel_region(x, group="model_parallel") @@ -422,7 +432,11 @@ def forward(self, x): # gather if not self.output_is_matmul_parallel: - gather_shapes = compute_split_sizes(num_chans, DistributedManager().group_size("model_parallel")) - x = gather_from_parallel_region(x, dim=1, shapes=gather_shapes, group="model_parallel") + gather_shapes = compute_split_shapes( + num_chans, DistributedManager().group_size("model_parallel") + ) + x = gather_from_parallel_region( + x, dim=1, shapes=gather_shapes, group="model_parallel" + ) return x From 9f6276a057b2018ef9aa0900d1dd9c2e632c9213 Mon Sep 17 00:00:00 2001 From: Thorsten Kurth Date: Tue, 12 Dec 2023 01:05:01 -0800 Subject: [PATCH 11/11] merging _gather and all_gather_v Signed-off-by: Thorsten Kurth --- modulus/distributed/mappings.py | 19 ++++++--- modulus/distributed/utils.py | 76 +++++++++------------------------ 2 files changed, 33 insertions(+), 62 deletions(-) diff --git a/modulus/distributed/mappings.py b/modulus/distributed/mappings.py index 9e0b073d42..245103c54f 100644 --- a/modulus/distributed/mappings.py +++ b/modulus/distributed/mappings.py @@ -15,7 +15,12 @@ import torch from modulus.distributed.manager import DistributedManager -from modulus.distributed.utils import _gather, _reduce, _split, compute_split_shapes +from modulus.distributed.utils import ( + _reduce, + _split, + all_gather_v_wrapper, + compute_split_shapes, +) class _CopyToParallelRegion(torch.autograd.Function): @@ -70,10 +75,10 @@ def forward(ctx, input_, dim_, group_): # pragma: no cover @staticmethod def backward(ctx, grad_output): # pragma: no cover return ( - _gather( + 
all_gather_v_wrapper( grad_output, + ctx.split_shapes, ctx.dim, - shapes_=ctx.split_shapes, group=DistributedManager().group(ctx.group), ), None, @@ -86,14 +91,16 @@ class _GatherFromParallelRegion(torch.autograd.Function): @staticmethod def symbolic(graph, input_, dim_, group_, shapes_): # pragma: no cover - return _gather(input_, dim_, shapes_, group=DistributedManager().group(group_)) + return all_gather_v_wrapper( + input_, shapes_, dim_, group=DistributedManager().group(group_) + ) @staticmethod def forward(ctx, input_, dim_, shapes_, group_): # pragma: no cover ctx.dim = dim_ ctx.group = group_ - return _gather( - input_, dim_, shapes_=shapes_, group=DistributedManager().group(group_) + return all_gather_v_wrapper( + input_, shapes_, dim_, group=DistributedManager().group(group_) ) @staticmethod diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py index 8c6bbfcd7f..9c74351b3a 100644 --- a/modulus/distributed/utils.py +++ b/modulus/distributed/utils.py @@ -218,52 +218,9 @@ def _split(input_, dim_, group=None): # pragma: no cover return output -def _gather(input_, dim_, shapes_=None, group=None): # pragma: no cover - """Gather tensors and concatenate along the specified dimension.""" - # get input format - input_format = get_memory_format(input_) - - comm_size = dist.get_world_size(group=group) - # Bypass the function if we are using only 1 GPU. - if comm_size == 1: - return input_ - - # sanity checks - if not (dim_ < input_.dim()): - raise AssertionError( - f"Error, cannot gather along {dim_} for tensor with {input_.dim()} " - "dimensions." - ) - - # Size and dimension. - comm_rank = dist.get_rank(group=group) - - # make input contiguous - input_ = input_.contiguous(memory_format=input_format) - - if shapes_ is not None: - shape = list(input_.shape) - tensor_list = [None for _ in range(comm_size)] - for i in range(comm_size): - shape[dim_] = shapes_[i] - tensor_list[i] = torch.empty( - shape, device=input_.device, dtype=input_.dtype - ) - else: - tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] - - tensor_list[comm_rank] = input_ - dist.all_gather(tensor_list, input_, group=group) - - # Note: torch.cat already creates a contiguous tensor. 
- output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format) - - return output - - def all_gather_v_wrapper( tensor: torch.Tensor, - sizes: List[int], + sizes: Optional[List[int]] = None, dim: int = 0, group: Optional[dist.ProcessGroup] = None, ) -> torch.Tensor: # pragma: no cover @@ -278,9 +235,9 @@ def all_gather_v_wrapper( ---------- tensor : "torch.Tensor" local tensor on each rank - sizes : List[int] + sizes : List[int], optional list of the sizes of each chunk on each rank along distributed dimension, - valid and set on each rank + valid and set on each rank, by default None dim : int, optional dimension along which global tensor is distributed, by default 0 group : Optional[dist.ProcessGroup], optional @@ -293,7 +250,8 @@ def all_gather_v_wrapper( """ comm_size = dist.get_world_size(group=group) - if len(sizes) != comm_size: + + if (sizes is not None) and (len(sizes) != comm_size): raise ValueError() if dim >= tensor.dim(): raise ValueError() @@ -302,19 +260,25 @@ def all_gather_v_wrapper( return tensor tensor_shape = list(tensor.shape) - tensor_list = [None] * comm_size + tensor_format = get_memory_format(tensor) - for src in range(comm_size): - tensor_shape[dim] = sizes[src] - tensor_list[src] = torch.empty( - tensor_shape, - dtype=tensor.dtype, - device=tensor.device, - ) + if sizes is not None: + tensor_list = [None] * comm_size + + for src in range(comm_size): + tensor_shape[dim] = sizes[src] + tensor_list[src] = torch.empty( + tensor_shape, + dtype=tensor.dtype, + device=tensor.device, + ) + else: + # assume equal shape on all ranks + tensor_list = [torch.empty_like(tensor) for _ in range(comm_size)] dist.all_gather(tensor_list, tensor, group=group) - output = torch.cat(tensor_list, dim=dim) + output = torch.cat(tensor_list, dim=dim).contiguous(memory_format=tensor_format) return output
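The splitting rule behind this series is compute_split_shapes (added to modulus/distributed/utils.py in PATCH 01 and reformatted in PATCH 10): try a div-up split so the load stays balanced, and fall back to a floor split when div-up would leave the last shard empty. The sketch below mirrors that function locally so the behaviour can be checked without a distributed setup; the example sizes are arbitrary.

from typing import List

def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
    # mirrors modulus/distributed/utils.py as of PATCH 10
    if num_chunks == 1:
        return [size]
    # div-up (ceiling) split to balance the load
    chunk_size = (size + num_chunks - 1) // num_chunks
    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
    if last_chunk_size == 0:
        # div-up would leave the last shard empty, so split with floor instead
        chunk_size = size // num_chunks
        last_chunk_size = size - chunk_size * (num_chunks - 1)
    return [chunk_size] * (num_chunks - 1) + [last_chunk_size]

# channel counts no longer need to divide the model-parallel size evenly
assert compute_split_shapes(10, 4) == [3, 3, 3, 1]
assert compute_split_shapes(9, 4) == [2, 2, 2, 3]   # floor fallback
assert compute_split_shapes(12, 4) == [3, 3, 3, 3]  # even case unchanged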
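Because shards can now be uneven, the gather side has to know each rank's extent along the distributed dimension: scatter_to_parallel_region records the split shapes in the autograd context, and gather_from_parallel_region / all_gather_v_wrapper take them explicitly to size the receive buffers. The following single-process sketch (no process group; tensor shapes and sections are illustrative) shows the round trip the scatter/gather pair performs across model-parallel ranks.

import torch

x = torch.randn(2, 10, 8, 8)              # (batch, channels, height, width)
sections = [3, 3, 3, 1]                   # compute_split_shapes(10, 4)

# what scatter_to_parallel_region distributes: one shard per rank
shards = torch.split(x, sections, dim=1)
assert [s.shape[1] for s in shards] == sections

# the gather side sizes one receive buffer per rank from the recorded shapes
buffers = [torch.empty(2, c, 8, 8, dtype=x.dtype) for c in sections]
for buf, shard in zip(buffers, shards):
    buf.copy_(shard)                      # stand-in for dist.all_gather

# the final concatenation reproduces the original tensor
y = torch.cat(buffers, dim=1)
assert torch.equal(x, y)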
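PATCH 03 adds a trailing None to _GatherFromParallelRegion.backward because autograd requires one gradient slot per forward argument, and the new shapes_ argument (like dim_ and group_) is non-differentiable. The toy single-process Function below, with the same argument pattern, illustrates that rule; the class name and sizes are made up for the example.

import torch

class ToyGather(torch.autograd.Function):
    # same argument pattern as _GatherFromParallelRegion: (input_, dim_, shapes_, group_)
    @staticmethod
    def forward(ctx, input_, dim_, shapes_, group_):
        ctx.dim = dim_
        ctx.shapes = shapes_
        return input_.clone()             # stand-in for the all-gather

    @staticmethod
    def backward(ctx, grad_output):
        # one return value per forward argument: only the tensor gets a gradient,
        # while dim_, shapes_ and group_ each get None (the None PATCH 03 adds)
        return grad_output, None, None, None

x = torch.randn(4, 6, requires_grad=True)
ToyGather.apply(x, 1, [3, 3], None).sum().backward()
assert x.grad is not None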