Tkurth/extended distributed primitives #273

Merged
70 changes: 25 additions & 45 deletions modulus/distributed/mappings.py
@@ -15,7 +15,12 @@
import torch

from modulus.distributed.manager import DistributedManager
from modulus.distributed.utils import _gather, _reduce, _split
from modulus.distributed.utils import (
_reduce,
_split,
all_gather_v_wrapper,
compute_split_shapes,
)


class _CopyToParallelRegion(torch.autograd.Function):
@@ -62,12 +67,20 @@ def symbolic(graph, input_, dim_, group_): # pragma: no cover
def forward(ctx, input_, dim_, group_): # pragma: no cover
ctx.dim = dim_
ctx.group = group_
ctx.split_shapes = compute_split_shapes(
input_.shape[dim_], DistributedManager().group_size(group_)
)
return _split(input_, dim_, group=DistributedManager().group(group_))

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
return (
_gather(grad_output, ctx.dim, group=DistributedManager().group(ctx.group_)),
all_gather_v_wrapper(
grad_output,
ctx.split_shapes,
ctx.dim,
group=DistributedManager().group(ctx.group),
),
None,
None,
)
@@ -77,48 +90,25 @@ class _GatherFromParallelRegion(torch.autograd.Function):
"""Gather the input from parallel region and concatenate."""

@staticmethod
def symbolic(graph, input_, dim_, group_): # pragma: no cover
return _gather(input_, dim_, group=DistributedManager().group(group_))
def symbolic(graph, input_, dim_, group_, shapes_): # pragma: no cover
return all_gather_v_wrapper(
input_, shapes_, dim_, group=DistributedManager().group(group_)
)

@staticmethod
def forward(ctx, input_, dim_, group_): # pragma: no cover
def forward(ctx, input_, dim_, shapes_, group_): # pragma: no cover
ctx.dim = dim_
ctx.group = group_
return _gather(input_, dim_, group=DistributedManager().group(group_))
return all_gather_v_wrapper(
input_, shapes_, dim_, group=DistributedManager().group(group_)
)

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
        return (
            _split(grad_output, ctx.dim, group=DistributedManager().group(ctx.group)),
            None,
            None,
            None,
        )


class _GatherWithinParallelRegion(torch.autograd.Function):
"""
Gather the input within parallel region and concatenate.
The same forward method as _GatherFromParallelRegion, the difference is only in the
backward pass. This method performs a reduction of the gradients before the split in
the backward pass while the other version only performs a split
"""

@staticmethod
def symbolic(graph, input_, dim_, group_): # pragma: no cover
return _gather(input_, dim_, group=DistributedManager().group(group_))

@staticmethod
def forward(ctx, input_, dim_, group_): # pragma: no cover
ctx.dim = dim_
ctx.group = group_
return _gather(input_, dim_, group=DistributedManager().group(group_))

@staticmethod
def backward(ctx, grad_output): # pragma: no cover
red = _reduce(grad_output, group=DistributedManager().group(ctx.group_))
return (
_split(red, ctx.dim, group=DistributedManager().group(ctx.group)),
None,
None,
)

@@ -141,16 +131,6 @@ def scatter_to_parallel_region(input, dim, group): # pragma: no cover
return _ScatterToParallelRegion.apply(input, dim, group)


def gather_from_parallel_region(input, dim, group): # pragma: no cover
def gather_from_parallel_region(input, dim, shapes, group): # pragma: no cover
"""Gather the input from matmul parallel region and concatenate."""
return _GatherFromParallelRegion.apply(input, dim, group)


def gather_within_parallel_region(input, dim, group): # pragma: no cover
"""
Gather the input within parallel region and concatenate.
The same forward method as gather_from_parallel_region, the difference is only in
the backward pass. This method performs a reduction of the gradients before the
split in the backward pass while the other version only performs a split
"""
return _GatherWithinParallelRegion.apply(input, dim, group)
return _GatherFromParallelRegion.apply(input, dim, shapes, group)
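
For orientation, a minimal sketch of how the reworked primitives compose end to end. It assumes a DistributedManager that has already been initialized and a registered "model_parallel" process group; the tensor sizes and group setup are illustrative, not part of this diff:

import torch

from modulus.distributed.manager import DistributedManager
from modulus.distributed.mappings import (
    gather_from_parallel_region,
    scatter_to_parallel_region,
)
from modulus.distributed.utils import compute_split_shapes

dm = DistributedManager()

# the channel count (10) no longer has to be divisible by the group size
x = torch.randn(4, 10, device=dm.device)
shapes = compute_split_shapes(x.shape[1], dm.group_size("model_parallel"))

# forward: local split; backward: variable-size all-gather of gradients
x_local = scatter_to_parallel_region(x, dim=1, group="model_parallel")

# forward: variable-size all-gather; backward: split of gradients
x_full = gather_from_parallel_region(
    x_local, dim=1, shapes=shapes, group="model_parallel"
)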
107 changes: 52 additions & 55 deletions modulus/distributed/utils.py
@@ -22,6 +22,26 @@
from .manager import DistributedManager


def compute_split_shapes(size: int, num_chunks: int) -> List[int]:
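    """Return the sizes of each chunk when `size` elements are split into `num_chunks` near-even chunks."""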

# treat trivial case first
if num_chunks == 1:
return [size]

# first, check if we can split using div-up to balance the load:
chunk_size = (size + num_chunks - 1) // num_chunks
last_chunk_size = max(0, size - chunk_size * (num_chunks - 1))
if last_chunk_size == 0:
# in this case, the last shard would be empty, split with floor instead:
chunk_size = size // num_chunks
last_chunk_size = size - chunk_size * (num_chunks - 1)

# generate sections list
sections = [chunk_size for _ in range(num_chunks - 1)] + [last_chunk_size]

return sections
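
For illustration, the policy first tries ceil-division and falls back to floor-division when that would leave the last chunk empty:

>>> compute_split_shapes(10, 4)  # ceil split: 3 + 3 + 3 + 1
[3, 3, 3, 1]
>>> compute_split_shapes(9, 4)  # a ceil split (3, 3, 3) leaves nothing for the last chunk, so floor is used
[2, 2, 2, 3]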


def get_memory_format(tensor):
"""Gets format for tensor"""
if tensor.is_contiguous(memory_format=torch.channels_last):
@@ -71,19 +91,19 @@ def truncate_helper(tensor, dim, new_size):


def split_tensor_along_dim(tensor, dim, num_chunks):
"""splits tensor along specific dim"""
if not (dim < tensor.dim()):
raise AssertionError(
f"Error, tensor dimension is {tensor.dim()} which cannot be"
f"split along {dim}"
if dim >= tensor.dim():
raise ValueError(
f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}"
)
if not (tensor.shape[dim] % num_chunks == 0):
raise AssertionError(
f"Error, cannot split dim {dim} evenly. Dim size is \
            {tensor.shape[dim]} and requested number of splits is {num_chunks}"
if tensor.shape[dim] < num_chunks:
raise ValueError(
"Error, cannot split dim {dim} of size {tensor.shape[dim]} into \
{num_chunks} chunks. Empty slices are currently not supported."
)
chunk_size = tensor.shape[dim] // num_chunks
tensor_list = torch.split(tensor, chunk_size, dim=dim)

# get split
sections = compute_split_shapes(tensor.shape[dim], num_chunks)
tensor_list = torch.split(tensor, sections, dim=dim)

return tensor_list
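
A toy, single-process illustration of the uneven split this enables (tensor values are arbitrary):

import torch

from modulus.distributed.utils import split_tensor_along_dim

t = torch.arange(10).reshape(1, 10)
chunks = split_tensor_along_dim(t, dim=1, num_chunks=4)
print([c.shape[1] for c in chunks])  # [3, 3, 3, 1]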

@@ -198,39 +218,9 @@ def _split(input_, dim_, group=None): # pragma: no cover
return output


def _gather(input_, dim_, group=None): # pragma: no cover
"""Gather tensors and concatenate along the specified dimension."""
# get input format
input_format = get_memory_format(input_)

comm_size = dist.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if comm_size == 1:
return input_

# sanity checks
if not (dim_ < input_.dim()):
raise AssertionError(
f"Error, cannot gather along {dim_} for tensor with {input_.dim()} "
"dimensions."
)

# Size and dimension.
comm_rank = dist.get_rank(group=group)

tensor_list = [torch.empty_like(input_) for _ in range(comm_size)]
tensor_list[comm_rank] = input_
dist.all_gather(tensor_list, input_, group=group)

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format)

return output


def all_gather_v_wrapper(
tensor: torch.Tensor,
sizes: List[int],
sizes: Optional[List[int]] = None,
dim: int = 0,
group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor: # pragma: no cover
@@ -245,9 +235,9 @@ def all_gather_v_wrapper(
----------
tensor : "torch.Tensor"
local tensor on each rank
sizes : List[int]
sizes : List[int], optional
list of the sizes of each chunk on each rank along distributed dimension,
valid and set on each rank
valid and set on each rank, by default None
dim : int, optional
dimension along which global tensor is distributed, by default 0
group : Optional[dist.ProcessGroup], optional
@@ -260,7 +250,8 @@
"""

comm_size = dist.get_world_size(group=group)
if len(sizes) != comm_size:

if (sizes is not None) and (len(sizes) != comm_size):
raise ValueError()
if dim >= tensor.dim():
raise ValueError()
@@ -269,19 +260,25 @@
return tensor

tensor_shape = list(tensor.shape)
tensor_list = [None] * comm_size

for src in range(comm_size):
tensor_shape[dim] = sizes[src]
tensor_list[src] = torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
tensor_format = get_memory_format(tensor)

if sizes is not None:
tensor_list = [None] * comm_size

for src in range(comm_size):
tensor_shape[dim] = sizes[src]
tensor_list[src] = torch.empty(
tensor_shape,
dtype=tensor.dtype,
device=tensor.device,
)
else:
# assume equal shape on all ranks
tensor_list = [torch.empty_like(tensor) for _ in range(comm_size)]

dist.all_gather(tensor_list, tensor, group=group)

output = torch.cat(tensor_list, dim=dim)
output = torch.cat(tensor_list, dim=dim).contiguous(memory_format=tensor_format)

return output
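
A hedged usage sketch: it assumes torch.distributed has already been initialized, and the shard sizes come from compute_split_shapes so that every rank agrees on them:

import torch
import torch.distributed as dist

from modulus.distributed.utils import all_gather_v_wrapper, compute_split_shapes

rank = dist.get_rank()
world_size = dist.get_world_size()

sizes = compute_split_shapes(10, world_size)  # e.g. [3, 3, 3, 1] for world_size == 4
local = torch.randn(sizes[rank], 16)          # each rank holds its own shard
full = all_gather_v_wrapper(local, sizes, dim=0)  # shape (10, 16) on every rank

# with sizes=None, equal shard shapes are assumed on all ranks
equal = all_gather_v_wrapper(torch.randn(4, 16), dim=0)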

8 changes: 7 additions & 1 deletion modulus/models/afno/distributed/afno.py
@@ -31,6 +31,7 @@
gather_from_parallel_region,
scatter_to_parallel_region,
)
from modulus.distributed.utils import compute_split_shapes
from modulus.models.afno.distributed.layers import (
DistributedAFNO2D,
DistributedMLP,
@@ -97,6 +98,9 @@ def __init__(

def forward(self, x):
if not self.input_is_matmul_parallel:
scatter_shapes = compute_split_shapes(
x.shape[1], DistributedManager().group_size("model_parallel")
)
x = scatter_to_parallel_region(x, dim=1, group="model_parallel")

residual = x
@@ -113,7 +117,9 @@ def forward(self, x):
x = x + residual

if not self.output_is_matmul_parallel:
x = gather_from_parallel_region(x, dim=1, group="model_parallel")
x = gather_from_parallel_region(
x, dim=1, shapes=scatter_shapes, group="model_parallel"
)

return x
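
Note: scatter_shapes is computed from the channel count before the scatter and handed back to the gather, which is what lets channel dimensions that are not evenly divisible by the model-parallel group size round-trip through the block.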

25 changes: 22 additions & 3 deletions modulus/models/afno/distributed/layers.py
@@ -26,6 +26,7 @@
reduce_from_parallel_region,
scatter_to_parallel_region,
)
from modulus.distributed.utils import compute_split_shapes


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
Expand Down Expand Up @@ -163,6 +164,11 @@ def __init__(
self.act = act_layer()
self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity()

if self.input_is_matmul_parallel:
self.gather_shapes = compute_split_shapes(
in_features, DistributedManager().group_size("model_parallel")
)

# init weights
self._init_weights()

@@ -175,7 +181,9 @@ def _init_weights(self):
def forward(self, x):
# gather if input is MP
if self.input_is_matmul_parallel:
x = gather_from_parallel_region(x, dim=1, group="model_parallel")
x = gather_from_parallel_region(
x, dim=1, shapes=self.gather_shapes, group="model_parallel"
)

x = copy_to_parallel_region(x, group="model_parallel")
x = F.conv2d(x, self.w1, bias=self.b1)
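
Since in_features is known at construction time, self.gather_shapes is computed once in __init__ rather than on every forward call.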
@@ -223,6 +231,9 @@ def __init__(
raise ValueError(
"Error, the in_chans needs to be divisible by matmul_parallel_size"
)
self.in_shapes = compute_split_shapes(
in_chans, DistributedManager().group_size("model_parallel")
)

# get effective embedding size:
if self.output_parallel:
@@ -245,7 +256,9 @@

def forward(self, x):
if self.input_parallel:
x = gather_from_parallel_region(x, dim=1, group="model_parallel")
x = gather_from_parallel_region(
x, dim=1, shapes=self.in_shapes, group="model_parallel"
)

if self.output_parallel:
x = copy_to_parallel_region(x, group="model_parallel")
@@ -373,6 +386,7 @@ def __init__(
def forward(self, x):
if not self.input_is_matmul_parallel:
# distribute data
num_chans = x.shape[1]
x = scatter_to_parallel_region(x, dim=1, group="model_parallel")

# bias
Expand Down Expand Up @@ -418,6 +432,11 @@ def forward(self, x):

# gather
if not self.output_is_matmul_parallel:
x = gather_from_parallel_region(x, dim=1, group="model_parallel")
gather_shapes = compute_split_shapes(
num_chans, DistributedManager().group_size("model_parallel")
)
x = gather_from_parallel_region(
x, dim=1, shapes=gather_shapes, group="model_parallel"
)

return x
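
As in the AFNO block forward above, num_chans is recorded before the scatter so that matching gather shapes can be recomputed for the output gather when output_is_matmul_parallel is False.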