diff --git a/modulus/distributed/mappings.py b/modulus/distributed/mappings.py
index ba99887cf7..245103c54f 100644
--- a/modulus/distributed/mappings.py
+++ b/modulus/distributed/mappings.py
@@ -15,7 +15,12 @@
 import torch
 
 from modulus.distributed.manager import DistributedManager
-from modulus.distributed.utils import _gather, _reduce, _split
+from modulus.distributed.utils import (
+    _reduce,
+    _split,
+    all_gather_v_wrapper,
+    compute_split_shapes,
+)
 
 
 class _CopyToParallelRegion(torch.autograd.Function):
@@ -62,12 +67,20 @@ def symbolic(graph, input_, dim_, group_):  # pragma: no cover
     def forward(ctx, input_, dim_, group_):  # pragma: no cover
         ctx.dim = dim_
         ctx.group = group_
+        ctx.split_shapes = compute_split_shapes(
+            input_.shape[dim_], DistributedManager().group_size(group_)
+        )
         return _split(input_, dim_, group=DistributedManager().group(group_))
 
     @staticmethod
     def backward(ctx, grad_output):  # pragma: no cover
         return (
-            _gather(grad_output, ctx.dim, group=DistributedManager().group(ctx.group_)),
+            all_gather_v_wrapper(
+                grad_output,
+                ctx.split_shapes,
+                ctx.dim,
+                group=DistributedManager().group(ctx.group),
+            ),
             None,
             None,
         )
@@ -77,14 +90,18 @@ class _GatherFromParallelRegion(torch.autograd.Function):
     """Gather the input from parallel region and concatenate."""
 
     @staticmethod
-    def symbolic(graph, input_, dim_, group_):  # pragma: no cover
-        return _gather(input_, dim_, group=DistributedManager().group(group_))
+    def symbolic(graph, input_, dim_, group_, shapes_):  # pragma: no cover
+        return all_gather_v_wrapper(
+            input_, shapes_, dim_, group=DistributedManager().group(group_)
+        )
 
     @staticmethod
-    def forward(ctx, input_, dim_, group_):  # pragma: no cover
+    def forward(ctx, input_, dim_, shapes_, group_):  # pragma: no cover
         ctx.dim = dim_
         ctx.group = group_
-        return _gather(input_, dim_, group=DistributedManager().group(group_))
+        return all_gather_v_wrapper(
+            input_, shapes_, dim_, group=DistributedManager().group(group_)
+        )
 
     @staticmethod
     def backward(ctx, grad_output):  # pragma: no cover
@@ -92,33 +109,6 @@ def backward(ctx, grad_output):  # pragma: no cover
             _split(grad_output, ctx.dim, group=DistributedManager().group(ctx.group)),
             None,
             None,
-        )
-
-
-class _GatherWithinParallelRegion(torch.autograd.Function):
-    """
-    Gather the input within parallel region and concatenate.
-    The same forward method as _GatherFromParallelRegion, the difference is only in the
-    backward pass. This method performs a reduction of the gradients before the split in
-    the backward pass while the other version only performs a split
-    """
-
-    @staticmethod
-    def symbolic(graph, input_, dim_, group_):  # pragma: no cover
-        return _gather(input_, dim_, group=DistributedManager().group(group_))
-
-    @staticmethod
-    def forward(ctx, input_, dim_, group_):  # pragma: no cover
-        ctx.dim = dim_
-        ctx.group = group_
-        return _gather(input_, dim_, group=DistributedManager().group(group_))
-
-    @staticmethod
-    def backward(ctx, grad_output):  # pragma: no cover
-        red = _reduce(grad_output, group=DistributedManager().group(ctx.group_))
-        return (
-            _split(red, ctx.dim, group=DistributedManager().group(ctx.group)),
-            None,
             None,
         )
@@ -141,16 +131,6 @@ def scatter_to_parallel_region(input, dim, group):  # pragma: no cover
     return _ScatterToParallelRegion.apply(input, dim, group)
 
 
-def gather_from_parallel_region(input, dim, group):  # pragma: no cover
+def gather_from_parallel_region(input, dim, shapes, group):  # pragma: no cover
     """Gather the input from matmul parallel region and concatenate."""
-    return _GatherFromParallelRegion.apply(input, dim, group)
-
-
-def gather_within_parallel_region(input, dim, group):  # pragma: no cover
-    """
-    Gather the input within parallel region and concatenate.
-    The same forward method as gather_from_parallel_region, the difference is only in
-    the backward pass. This method performs a reduction of the gradients before the
-    split in the backward pass while the other version only performs a split
-    """
-    return _GatherWithinParallelRegion.apply(input, dim, group)
+    return _GatherFromParallelRegion.apply(input, dim, shapes, group)
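Usage note (illustrative, not part of the patch): gather_from_parallel_region now takes the per-rank chunk sizes explicitly, so ranks may hold unevenly sized shards. A minimal sketch, assuming DistributedManager.initialize() has already been called and a "model_parallel" process group exists on every rank; all tensor sizes here are made up:

    import torch
    import torch.distributed as dist

    from modulus.distributed.manager import DistributedManager
    from modulus.distributed.mappings import gather_from_parallel_region
    from modulus.distributed.utils import compute_split_shapes

    dm = DistributedManager()
    rank = dist.get_rank(group=dm.group("model_parallel"))

    full_width = 10  # deliberately not divisible by the group size
    shapes = compute_split_shapes(full_width, dm.group_size("model_parallel"))

    # each rank holds a shard whose width is taken from `shapes`
    x_local = torch.randn(4, shapes[rank], device=dm.device)
    x_full = gather_from_parallel_region(
        x_local, dim=1, shapes=shapes, group="model_parallel"
    )
    assert x_full.shape[1] == full_width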
diff --git a/modulus/distributed/utils.py b/modulus/distributed/utils.py
index dc6e6389fc..9c74351b3a 100644
--- a/modulus/distributed/utils.py
+++ b/modulus/distributed/utils.py
@@ -22,6 +22,26 @@
 from .manager import DistributedManager
 
 
+def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
+    """Computes chunk sizes for splitting `size` into `num_chunks` non-empty pieces"""
+
+    # treat trivial case first
+    if num_chunks == 1:
+        return [size]
+
+    # first, check if we can split using div-up to balance the load:
+    chunk_size = (size + num_chunks - 1) // num_chunks
+    last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
+    if last_chunk_size == 0:
+        # in this case, the last shard would be empty, split with floor instead:
+        chunk_size = size // num_chunks
+        last_chunk_size = size - chunk_size * (num_chunks - 1)
+
+    # generate sections list
+    sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]
+
+    return sections
+
+
 def get_memory_format(tensor):
     """Gets format for tensor"""
     if tensor.is_contiguous(memory_format=torch.channels_last):
@@ -71,19 +91,19 @@ def truncate_helper(tensor, dim, new_size):
 
 
 def split_tensor_along_dim(tensor, dim, num_chunks):
     """splits tensor along specific dim"""
-    if not (dim < tensor.dim()):
-        raise AssertionError(
-            f"Error, tensor dimension is {tensor.dim()} which cannot be"
-            f"split along {dim}"
+    if dim >= tensor.dim():
+        raise ValueError(
+            f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
         )
-    if not (tensor.shape[dim] % num_chunks == 0):
-        raise AssertionError(
-            f"Error, cannot split dim {dim} evenly. Dim size is \
-            {tensor.shape[dim]} and requested numnber of splits is {num_chunks}"
+    if tensor.shape[dim] < num_chunks:
+        raise ValueError(
+            f"Error, cannot split dim {dim} of size {tensor.shape[dim]} into "
+            f"{num_chunks} chunks. Empty slices are currently not supported."
         )
-    chunk_size = tensor.shape[dim] // num_chunks
-    tensor_list = torch.split(tensor, chunk_size, dim=dim)
+
+    # get split
+    sections = compute_split_shapes(tensor.shape[dim], num_chunks)
+    tensor_list = torch.split(tensor, sections, dim=dim)
 
     return tensor_list
@@ -198,39 +218,9 @@ def _split(input_, dim_, group=None):  # pragma: no cover
     return output
 
 
-def _gather(input_, dim_, group=None):  # pragma: no cover
-    """Gather tensors and concatenate along the specified dimension."""
-    # get input format
-    input_format = get_memory_format(input_)
-
-    comm_size = dist.get_world_size(group=group)
-    # Bypass the function if we are using only 1 GPU.
-    if comm_size == 1:
-        return input_
-
-    # sanity checks
-    if not (dim_ < input_.dim()):
-        raise AssertionError(
-            f"Error, cannot gather along {dim_} for tensor with {input_.dim()} "
-            "dimensions."
-        )
-
-    # Size and dimension.
-    comm_rank = dist.get_rank(group=group)
-
-    tensor_list = [torch.empty_like(input_) for _ in range(comm_size)]
-    tensor_list[comm_rank] = input_
-    dist.all_gather(tensor_list, input_, group=group)
-
-    # Note: torch.cat already creates a contiguous tensor.
-    output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format)
-
-    return output
-
-
 def all_gather_v_wrapper(
     tensor: torch.Tensor,
-    sizes: List[int],
+    sizes: Optional[List[int]] = None,
     dim: int = 0,
     group: Optional[dist.ProcessGroup] = None,
 ) -> torch.Tensor:  # pragma: no cover
@@ -245,9 +235,9 @@
     ----------
     tensor : "torch.Tensor"
         local tensor on each rank
-    sizes : List[int]
+    sizes : List[int], optional
         list of the sizes of each chunk on each rank along distributed dimension,
-        valid and set on each rank
+        valid and set on each rank, by default None
     dim : int, optional
         dimension along which global tensor is distributed, by default 0
     group : Optional[dist.ProcessGroup], optional
@@ -260,7 +250,8 @@
     """
 
     comm_size = dist.get_world_size(group=group)
-    if len(sizes) != comm_size:
+
+    if (sizes is not None) and (len(sizes) != comm_size):
         raise ValueError()
     if dim >= tensor.dim():
         raise ValueError()
@@ -269,19 +260,25 @@
         return tensor
 
     tensor_shape = list(tensor.shape)
-    tensor_list = [None] * comm_size
-
-    for src in range(comm_size):
-        tensor_shape[dim] = sizes[src]
-        tensor_list[src] = torch.empty(
-            tensor_shape,
-            dtype=tensor.dtype,
-            device=tensor.device,
-        )
+    tensor_format = get_memory_format(tensor)
+
+    if sizes is not None:
+        tensor_list = [None] * comm_size
+
+        for src in range(comm_size):
+            tensor_shape[dim] = sizes[src]
+            tensor_list[src] = torch.empty(
+                tensor_shape,
+                dtype=tensor.dtype,
+                device=tensor.device,
+            )
+    else:
+        # assume equal shape on all ranks
+        tensor_list = [torch.empty_like(tensor) for _ in range(comm_size)]
 
     dist.all_gather(tensor_list, tensor, group=group)
 
-    output = torch.cat(tensor_list, dim=dim)
+    output = torch.cat(tensor_list, dim=dim).contiguous(memory_format=tensor_format)
 
     return output
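compute_split_shapes can be exercised locally without any process group. The worked examples below (sizes are arbitrary) show its two branches, div-up when the last chunk stays non-empty and floor otherwise, and how torch.split consumes the resulting sections:

    import torch

    from modulus.distributed.utils import compute_split_shapes

    assert compute_split_shapes(10, 4) == [3, 3, 3, 1]  # div-up keeps chunks balanced
    assert compute_split_shapes(8, 4) == [2, 2, 2, 2]   # evenly divisible case
    assert compute_split_shapes(6, 4) == [1, 1, 1, 3]   # floor fallback avoids an empty chunk

    sections = compute_split_shapes(10, 4)
    chunks = torch.split(torch.arange(10), sections, dim=0)
    assert [len(c) for c in chunks] == sections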
diff --git a/modulus/models/afno/distributed/afno.py b/modulus/models/afno/distributed/afno.py
index 1391c59dc6..59daa4098b 100644
--- a/modulus/models/afno/distributed/afno.py
+++ b/modulus/models/afno/distributed/afno.py
@@ -31,6 +31,7 @@
     gather_from_parallel_region,
     scatter_to_parallel_region,
 )
+from modulus.distributed.utils import compute_split_shapes
 from modulus.models.afno.distributed.layers import (
     DistributedAFNO2D,
     DistributedMLP,
@@ -97,6 +98,9 @@ def __init__(
 
     def forward(self, x):
         if not self.input_is_matmul_parallel:
+            scatter_shapes = compute_split_shapes(
+                x.shape[1], DistributedManager().group_size("model_parallel")
+            )
             x = scatter_to_parallel_region(x, dim=1, group="model_parallel")
 
         residual = x
@@ -113,7 +117,9 @@ def forward(self, x):
         x = x + residual
 
         if not self.output_is_matmul_parallel:
-            x = gather_from_parallel_region(x, dim=1, group="model_parallel")
+            x = gather_from_parallel_region(
+                x, dim=1, shapes=scatter_shapes, group="model_parallel"
+            )
 
         return x
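The forward-pass bookkeeping above reduces to one rule: the split shapes must be computed from the tensor before it is scattered, because afterwards each rank only sees its own shard. A condensed, hypothetical module illustrating that pattern (the class and its placeholder activation are made up; the real blocks are the distributed AFNO layers):

    import torch.nn as nn

    from modulus.distributed.manager import DistributedManager
    from modulus.distributed.mappings import (
        gather_from_parallel_region,
        scatter_to_parallel_region,
    )
    from modulus.distributed.utils import compute_split_shapes

    class _ToyParallelBlock(nn.Module):
        """Illustrative only: scatter -> compute -> gather with uneven channel shards."""

        def __init__(self):
            super().__init__()
            self.act = nn.GELU()  # stand-in for the real distributed AFNO blocks

        def forward(self, x):
            # record shapes from the full channel dim *before* the scatter
            shapes = compute_split_shapes(
                x.shape[1], DistributedManager().group_size("model_parallel")
            )
            x = scatter_to_parallel_region(x, dim=1, group="model_parallel")
            x = self.act(x)
            # the same shapes undo the (possibly uneven) split in rank order
            return gather_from_parallel_region(
                x, dim=1, shapes=shapes, group="model_parallel"
            )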
diff --git a/modulus/models/afno/distributed/layers.py b/modulus/models/afno/distributed/layers.py
index d48a12a18e..e0ea587b7e 100644
--- a/modulus/models/afno/distributed/layers.py
+++ b/modulus/models/afno/distributed/layers.py
@@ -26,6 +26,7 @@
     reduce_from_parallel_region,
     scatter_to_parallel_region,
 )
+from modulus.distributed.utils import compute_split_shapes
 
 
 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
@@ -163,6 +164,11 @@ def __init__(
         self.act = act_layer()
         self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity()
 
+        if self.input_is_matmul_parallel:
+            self.gather_shapes = compute_split_shapes(
+                in_features, DistributedManager().group_size("model_parallel")
+            )
+
         # init weights
         self._init_weights()
@@ -175,7 +181,9 @@ def _init_weights(self):
     def forward(self, x):
         # gather if input is MP
         if self.input_is_matmul_parallel:
-            x = gather_from_parallel_region(x, dim=1, group="model_parallel")
+            x = gather_from_parallel_region(
+                x, dim=1, shapes=self.gather_shapes, group="model_parallel"
+            )
 
         x = copy_to_parallel_region(x, group="model_parallel")
         x = F.conv2d(x, self.w1, bias=self.b1)
@@ -223,6 +231,9 @@ def __init__(
             raise ValueError(
                 "Error, the in_chans needs to be divisible by matmul_parallel_size"
             )
+        self.in_shapes = compute_split_shapes(
+            in_chans, DistributedManager().group_size("model_parallel")
+        )
 
         # get effective embedding size:
         if self.output_parallel:
@@ -245,7 +256,9 @@ def __init__(
 
     def forward(self, x):
         if self.input_parallel:
-            x = gather_from_parallel_region(x, dim=1, group="model_parallel")
+            x = gather_from_parallel_region(
+                x, dim=1, shapes=self.in_shapes, group="model_parallel"
+            )
 
         if self.output_parallel:
             x = copy_to_parallel_region(x, group="model_parallel")
@@ -373,6 +386,7 @@ def __init__(
     def forward(self, x):
         if not self.input_is_matmul_parallel:
             # distribute data
+            num_chans = x.shape[1]
             x = scatter_to_parallel_region(x, dim=1, group="model_parallel")
 
         # bias
@@ -418,6 +432,11 @@ def forward(self, x):
 
         # gather
         if not self.output_is_matmul_parallel:
-            x = gather_from_parallel_region(x, dim=1, group="model_parallel")
+            gather_shapes = compute_split_shapes(
+                num_chans, DistributedManager().group_size("model_parallel")
+            )
+            x = gather_from_parallel_region(
+                x, dim=1, shapes=gather_shapes, group="model_parallel"
+            )
 
         return x
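In contrast to DistributedAFNO2D, which only learns the channel count from the incoming tensor, DistributedMLP and DistributedPatchEmbed know their feature count at construction time, so the gather shapes can be computed once in __init__ and reused on every call. A stripped-down, hypothetical layer showing that variant (class and attribute names are illustrative only):

    import torch.nn as nn

    from modulus.distributed.manager import DistributedManager
    from modulus.distributed.mappings import gather_from_parallel_region
    from modulus.distributed.utils import compute_split_shapes

    class _ToyGatherLayer(nn.Module):
        """Illustrative only: gathers a channel-sharded input back to full width."""

        def __init__(self, in_features: int):
            super().__init__()
            # shapes depend only on in_features and the group size, so compute them once
            self.gather_shapes = compute_split_shapes(
                in_features, DistributedManager().group_size("model_parallel")
            )

        def forward(self, x):
            # x arrives with only a shard of the channels on each rank
            return gather_from_parallel_region(
                x, dim=1, shapes=self.gather_shapes, group="model_parallel"
            )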