Enable Domain Parallelism with ShardTensor #784

Open · wants to merge 48 commits into base: main

Commits (changes shown from 40 of 48 commits)
6c6a15e
Enable mesh-based parallelism as the configuration backend, even for …
coreyjadams Dec 17, 2024
91f2398
Fix small typo in docstring
coreyjadams Dec 17, 2024
e56b0b8
Remove unnecessary functions with new interface
coreyjadams Dec 18, 2024
897a464
Adding first implementation of ShardTensor prototype. Still several …
coreyjadams Jan 7, 2025
f28f537
Working implementation of ShardTensor, though still somewhat incompl…
coreyjadams Jan 10, 2025
ce99df6
Adding work-in-progress examples. Be careful of sharp edges!
coreyjadams Jan 13, 2025
025d8d2
A few more example pieces before natten will work out of the box. Mo…
coreyjadams Jan 14, 2025
b57050a
Merge branch 'NVIDIA:main' into distributed
coreyjadams Jan 16, 2025
b0e335c
Fix naming scheme
coreyjadams Jan 16, 2025
70b8ce5
Minor name change
coreyjadams Jan 16, 2025
9f19e36
Add monkey patching for na2d operation with shard tensors
coreyjadams Jan 16, 2025
2f06c07
Fix bug in shard tensor inference of global size. Check against shard…
coreyjadams Jan 16, 2025
71168a2
Enable backwards gradients for halo sharding and natten patch
coreyjadams Jan 20, 2025
72af11e
Convolution 2d backwards works, though would be better to catch torc…
coreyjadams Jan 27, 2025
305bad3
Fix missing import and ensure tensors are contiguous before allgather_v
coreyjadams Jan 27, 2025
d79c975
Clean up and remove unnecessary noise and printouts for debugging
coreyjadams Jan 28, 2025
1a4d886
Merge branch 'NVIDIA:main' into distributed
coreyjadams Jan 29, 2025
f7d063a
Unify (and correct!) the sharded convolution implementation. There w…
coreyjadams Jan 30, 2025
03615c5
Merge branch 'NVIDIA:main' into distributed
coreyjadams Jan 30, 2025
350ec41
Remove noise from sharding utils.
coreyjadams Jan 31, 2025
8342905
Merge branch 'distributed' of github.com:coreyjadams/modulus into dis…
coreyjadams Jan 31, 2025
a9f5484
For smaller tensors, the alltoall step of halo reductions might be si…
coreyjadams Feb 3, 2025
726ba92
Remove shard_utils file, it is a subfolder.
coreyjadams Feb 5, 2025
05ce224
Add modulus ShardTensor api documentation
coreyjadams Feb 5, 2025
f84cfd2
Clean up doc strings, type annotations and mesh implementation. No s…
coreyjadams Feb 5, 2025
73464e0
Add significant docstring / type annotation cleanup to ShardTensor.
coreyjadams Feb 5, 2025
e4ac9eb
Remove neighborhood attention prototypes
coreyjadams Feb 6, 2025
8a96e7e
Remove the rest of these examples since they are outdated and unneces…
coreyjadams Feb 6, 2025
584857d
Mostly, this commit is adding type annotations and doc strings.
coreyjadams Feb 6, 2025
a8b0592
Clean up and document conv patches.
coreyjadams Feb 6, 2025
5f848a5
clean up and improve documentation and type hints for shard utils wor…
coreyjadams Feb 6, 2025
a0b8f6a
Adding basic tests for shard tensor initialization and redistribution.
coreyjadams Feb 6, 2025
631bc9f
Add full working example of multilevel parallelism with pytorch
coreyjadams Feb 6, 2025
3f4a943
Merge branch 'NVIDIA:main' into distributed
coreyjadams Feb 6, 2025
0512728
Add missing type annotations
coreyjadams Feb 6, 2025
04cc895
Merge branch 'distributed' of github.com:coreyjadams/modulus into dis…
coreyjadams Feb 6, 2025
6384f4c
Merge branch 'main' into distributed
coreyjadams Feb 6, 2025
dc2b729
Ensure scatter_tensor is available to import from modulus.distributed
coreyjadams Feb 6, 2025
abb18f0
Merge branch 'distributed' of github.com:coreyjadams/modulus into dis…
coreyjadams Feb 6, 2025
7841c2b
Update changelog and ensure wrapt is an optional dependency
coreyjadams Feb 6, 2025
fad06e3
Update fsdp_and_shard_tensor.rst
coreyjadams Feb 7, 2025
b767159
Update __init__.py
coreyjadams Feb 7, 2025
f8aa96e
Update shard_tensor.py
coreyjadams Feb 7, 2025
25115a8
This is an essential bug fix for a missing import
coreyjadams Feb 11, 2025
1ec40a5
Merge branch 'main' into distributed
coreyjadams Feb 12, 2025
a624d17
Merge branch 'main' into distributed
coreyjadams Feb 13, 2025
67666f0
Update branch to pass CI tests.
coreyjadams Feb 13, 2025
0021891
Merge branch 'distributed' of github.com:coreyjadams/modulus into dis…
coreyjadams Feb 13, 2025
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added matrix decomposition scheme to improve graph partitioning
- DrivAerML dataset support in FIGConvNet example.
- Retraining recipe for DoMINO from a pretrained model checkpoint
- Prototype support for domain parallelism using ShardTensor (new).
- Enable DeviceMesh initialization via DistributedManager.

### Changed

@@ -26,6 +28,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

- ProcessGroupConfig is tagged for future deprecation in favor of DeviceMesh.

### Removed

### Fixed
@@ -41,6 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Moved pytz and nvtx to optional
- Update the base image for the Dockerfile
- Introduce Multi-Storage Client (MSC) as an optional dependency.
- Introduce `wrapt` as an optional dependency, needed when using ShardTensor's automatic domain parallelism

## [0.9.0] - 2024-12-04

117 changes: 117 additions & 0 deletions docs/api/modulus.distributed.shardtensor.rst
@@ -0,0 +1,117 @@

Modulus Shard Tensor
====================

In scientific AI applications, the parallelization techniques needed to enable state-of-the-art
models are different from those used in training large language models. Modulus
introduces a new parallelization primitive called a ``ShardTensor`` that is designed for
large-input AI applications to enable domain parallelism.

Collaborator: Codify ShardTensor to make it stand out better?

ShardTensor provides a distributed tensor implementation that supports uneven sharding across devices.
It builds on PyTorch's DTensor while adding flexibility for cases where different ranks may have
different local tensor sizes.

The example below shows how to create and work with ShardTensor:

.. code:: python

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Shard
from modulus.distributed import DistributedManager
from modulus.distributed.shard_tensor import ShardTensor, scatter_tensor

def main():
    # Initialize distributed environment
    DistributedManager.initialize()
    dist = DistributedManager()

    # Create a 1D device mesh
    mesh = DeviceMesh(device_type="cuda", mesh=[0, 1, 2, 3])
Collaborator: This is going straight to pytorch right? Would be better to route that through DistributedManager I think so that mesh objects don't need to be explicitly passed around and modules can query the DistributedManager for them like how they currently query it for the torch device, rank, process group handles, etc.

Collaborator (Author): Sure, I can update this script. I've updated DistributedManager in this script to enable DeviceMesh creation too.

    # Create a tensor on rank 0
    if dist.rank == 0:
        tensor = torch.randn(100, 64)
    else:
        tensor = None

    # Scatter the tensor across devices with uneven sharding
    # This will automatically determine appropriate local sizes
    sharded = scatter_tensor(
        tensor,
        global_src=0,
        mesh=mesh,
        placements=(Shard(0),)  # Shard along first dimension
    )

    # Work with local portions
    local_tensor = sharded.to_local()

    # Redistribute to different sharding scheme
    new_sharded = sharded.redistribute(
        placements=(Shard(1),)  # Change to shard along second dimension
    )

How does this work?
""""""""""""""""""

ShardTensor extends PyTorch's DTensor to support uneven sharding where different ranks can have different
local tensor sizes. It tracks shard size information and handles redistribution between different
sharding schemes while maintaining gradient flow.

Key differences from DTensor include:
- Support for uneven sharding where ranks have different local sizes
- Tracking and propagation of shard size information
- Custom collective operations optimized for uneven sharding
- Flexible redistribution between different sharding schemes

Operations work by:
1. Converting inputs to local tensors
2. Performing operations locally
3. Constructing new ShardTensor with appropriate sharding
4. Handling any needed communication between ranks
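
As a rough illustration of those steps, here is a minimal sketch (the helper name ``sharded_relu`` is hypothetical, not part of the Modulus API) that applies an elementwise operation to a ShardTensor by working on the local shard and re-wrapping the result. ShardTensor inherits the ``device_mesh`` and ``placements`` properties from DTensor, and an elementwise op needs no inter-rank communication, so step 4 is a no-op here:

.. code:: python

import torch
from modulus.distributed.shard_tensor import ShardTensor

def sharded_relu(x: ShardTensor) -> ShardTensor:
    # 1. Convert the input to its local shard
    local = x.to_local()
    # 2. Perform the operation locally
    local_out = torch.relu(local)
    # 3. Construct a new ShardTensor with the same mesh and placements
    # 4. (No communication is needed for a purely elementwise op)
    return ShardTensor.from_local(
        local_out,
        device_mesh=x.device_mesh,
        placements=x.placements,
    )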

.. autosummary::
    :toctree: generated

ShardTensor
-----------

.. autoclass:: modulus.distributed.shard_tensor.ShardTensor
    :members:
    :show-inheritance:

Utility Functions
-----------------

.. autofunction:: modulus.distributed.shard_tensor.scatter_tensor


Why do we need this?
""""""""""""""""""""

During deep learning training, memory usage can grow significantly when working with large input data, even if the model itself is relatively small. This is because many operations create intermediate tensors that temporarily consume memory.

For example, consider a 2D convolution operation on a high-resolution image. If we have a batch of 1024x1024 images, even a simple 3x3 convolution needs to save the entire input image in memory for computing the gradients in the backward pass.

For high-resolution images, this can easily lead to out-of-memory errors as model depth grows, even if the number of parameters is small. This is in sharp contrast to LLM training, where memory usage is dominated by the number of parameters and the corresponding optimizer states. Software solutions like DeepSpeed and ZeRO handle that case by partitioning the model across GPUs, but that is not a solution for large-input applications.
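
As a rough back-of-the-envelope illustration (the batch size, channel count, and depth below are assumptions chosen for the example, not measurements):

.. code:: python

# Rough activation-memory estimate for a stack of 3x3 convolutions that
# preserve the 1024x1024 resolution (assumed sizes, fp32 = 4 bytes/element).
batch, channels, height, width = 8, 64, 1024, 1024
bytes_per_element = 4
num_conv_layers = 20

# Each convolution keeps its input activation for the backward pass:
activation_bytes = batch * channels * height * width * bytes_per_element
print(f"Per layer: {activation_bytes / 2**30:.1f} GiB")            # 2.0 GiB
print(f"For {num_conv_layers} layers: "
      f"{num_conv_layers * activation_bytes / 2**30:.0f} GiB")     # 40 GiB

# The 3x3 convolution weights themselves are tiny by comparison:
weight_bytes = num_conv_layers * channels * channels * 3 * 3 * bytes_per_element
print(f"All conv weights: {weight_bytes / 2**20:.1f} MiB")         # ~2.8 MiB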

ShardTensor helps address this by:
- Distributing the input data across multiple devices
- Performing operations on smaller local portions
- Coordinating the necessary communication between devices in the forward and backward passes

ShardTensor is built as an extension of PyTorch's DTensor, and gains substantial functionality by leveraging the utilities already implemented in the PyTorch distributed package. However, some operations on sharded input data are not trivial to implement correctly, nor relevant to the model sharding problem. In Modulus, we have implemented parallelized versions of several key operations, including (so far):

- Convolution (1D, 2D, 3D)
- Neighborhood Attention (2D)

These operations are implemented in the ``modulus.distributed.shard_utils`` module, and are enabled by dynamically intercepting calls to (for example) ``torch.nn.functional.conv2d``. When the function is called with ShardTensor inputs, the operation is automatically parallelized across the mesh associated with the input. When the function is called with non-ShardTensor inputs, the operation is executed in a non-parallelized manner, exactly as expected.

To enable these operations, you must import ``patch_operations`` from ``modulus.distributed.shard_utils``. This will patch the relevant functions (such as ``torch.nn.functional.conv2d``) to support ShardTensor inputs.
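
For example, a minimal end-to-end sketch of enabling the patches and running a sharded convolution (mirroring the tutorial; the tensor shapes and mesh layout are arbitrary choices for illustration):

.. code:: python

import torch
import torch.nn as nn

# Importing patch_operations registers the ShardTensor-aware wrappers
# around functions such as torch.nn.functional.conv2d.
from modulus.distributed.shard_utils import patch_operations  # noqa: F401

from modulus.distributed import DistributedManager
from modulus.distributed.shard_tensor import ShardTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.placement_types import Shard

DistributedManager.initialize()
dist = DistributedManager()
mesh = DeviceMesh(device_type="cuda", mesh=[0, 1, 2, 3])

# Each rank holds one slice of the image, sharded along the height dimension.
local_slice = torch.randn(1, 3, 256, 1024, device=dist.device)
image = ShardTensor.from_local(local_slice, device_mesh=mesh, placements=(Shard(2),))

# Replicate the convolution weights over the mesh. nn.Conv2d calls the
# (now patched) torch.nn.functional.conv2d internally, so the forward pass
# runs in parallel over the sharded height dimension, halo exchange included.
conv = distribute_module(
    nn.Conv2d(3, 16, kernel_size=3, padding=1).to(dist.device),
    device_mesh=mesh,
)

out = conv(image)  # `out` is again a ShardTensor, sharded along the height dimension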

We are continuing to add more operations, and contributions are welcome!




230 changes: 230 additions & 0 deletions docs/tutorials/fsdp_and_shard_tensor.rst
@@ -0,0 +1,230 @@
ShardTensor and FSDP Tutorial
=============================

This tutorial demonstrates how to use Modulus's ShardTensor functionality alongside PyTorch's FSDP (Fully Sharded Data Parallel) to train a simple convolutional neural network. We'll show how to:

1. Create a simple CNN model
2. Set up input data sharding across multiple GPUs
3. Combine FSDP with domain decomposition
4. Train the model

Simple CNN Model
----------------

The preamble to the training script has an important patch to make sure that the conv2d operation works with ShardTensor:

.. code-block:: python

import torch

# This is necessary to patch Conv2d to work with ShardTensor
from modulus.distributed.shard_utils import patch_operations

import torch.nn as nn

from modulus.distributed import DistributedManager
from modulus.distributed.shard_tensor import ShardTensor
from torch.distributed.tensor import distribute_module, distribute_tensor
from torch.distributed.tensor.placement_types import Shard, Replicate
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

Setting Up the Environment
--------------------------

Next, set up the distributed environment, including the device mesh. Here we do it globally,
but you can do it locally as well and pass device_mesh objects around.

.. code-block:: python

# Initialize distributed environment
DistributedManager.initialize()
dm = DistributedManager()

# Create a 2D mesh for hybrid parallelism
# First dimension for data parallel, second for spatial decomposition
mesh = dm.initialize_mesh((-1, 2), mesh_dim_names=["data", "spatial"])

# Get submeshes for different parallel strategies
data_mesh = mesh["data"] # For FSDP
spatial_mesh = mesh["spatial"] # For spatial decomposition

Now, let's create a simple CNN model:

.. code-block:: python

import torch
import torch.nn as nn
from modulus.distributed import DistributedManager
from modulus.distributed.shard_tensor import ShardTensor
from torch.distributed.tensor.placement_types import Shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        # This is automatically parallel:
        x = self.conv(x)
        x = self.relu(x)
        # This operation reduces on the parallel dimension.
        # This will leave x as a Partial placement, meaning
        # it isn't really sharded anymore but the results on the domain
        # pieces haven't been computed yet.
        x = self.pool(x)
Collaborator (Author): Under the hood in the pooling layer, it's doing an aten.mean operation. That gets intercepted and redispatched by DTensor, which is the fallback dispatch for ShardTensor's dispatch. The mean operation is handled in upstream pytorch here:
https://github.com/pytorch/pytorch/blob/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_ops/_math_ops.py#L306-L355

In general, most operations can fall back to DTensor successfully. There is a small hiccup here in this operation (doesn't show up in this example, of course) actually: If this tensor was sharded asymmetrically, the mean would do local averaging followed by a global average when the output is needed. The global average needs to be weighted by the shape of the un-reduced tensor, but currently isn't. It's actually an expected-fail in the tests in this branch, for exactly this reason, but I don't think it should block progress since it's a relatively rare path to go down. It's on my agenda to fix but likely requires sub-classing the Partial placement in DTensor to catch and handle.

Collaborator: Got it, traced that down to this piece of code that tells you that it will be a Partial placement: https://github.com/pytorch/pytorch/blob/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_ops/_math_ops.py#L248-L251.

When does this Partial placement get resolved? In the fc I assume? Is that basically calling a redistribute to Replicated and that triggers an allreduce? If that fc was rowwise (or colwise depending on convention) sharded, then it would redistribute to Shard(0) and that would trigger a reduce-scatter instead of an allreduce?

Collaborator (Author): I'm not 100% sure when it resolves, actually. But what I think is happening is:

I want to look into this more: the details of when and why are important when parallelism is being composed.

        x = torch.flatten(x, 1)
        x = self.fc(x)
Collaborator: This is automatically changed to all local ops by distribute_model?

Collaborator (Author): No, it changes all the underlying parameters in the layers from torch.Tensor to torch.distributed.tensor.DTensor. DTensor comes with its own dispatch, and when you call the thing returned by distribute_module with a DTensor or ShardTensor it will make a decision of how to process based on a) the operation being dispatched and b) the input tensor(s).

When you do the pool operation in this example, the pooling actually takes place over the sharded dimension. So, the output x on line 78 is a ShardTensor/DTensor with Replicate() placement instead of Shard(2) placement (which was the height dimension).

Subsequent operations will see that the weights in fc are Replicate sharding, and the inputs to fc are Replicate sharding, so it just has to use an all-local operation and wrap the output in a Replicate'd DTensor. ShardTensor then intercepts that and keeps the domain data as ShardTensor.

In the backwards pass, DTensor (and hence ShardTensor) should automatically un-replicate the gradients and shard them appropriately. So the input gradients to the upstream layers here (conv, relu) will receive shard tensors of the correct size.

Collaborator: Yeah, that makes sense that fc weights are replicated since that's the default with distribute_model.

This actually brings up a good point. Ideally, we'd want to parallelize the fc in this example as well since it's a lot of redundant compute right? How would one do that? Through a custom partition_fn in distribute_model? That would work in this case. How would it work in a larger model with many different Linear layers which all might need different handling in partition_fn?

Collaborator (Author): This model is so small that the fc is probably irrelevant in terms of compute, but yes, I get your point. I think weight partitioning, in almost every case, should be handled with existing FSDP tools. That ought to allow most parallelizations to work out of the box, though it needs to be fully checked and confirmed.

This is on my to-do list for the next release, 3D parallelism over batching, domain sharding, and model parallelism. It's not ready in this release though.
        return x
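
As discussed in the review thread above, the pooled result is left in a Partial placement until something forces a reduction. If you ever need to resolve it explicitly, here is a short sketch using the DTensor APIs that ShardTensor inherits (``x`` here stands for the output of ``self.pool(x)`` above):

.. code-block:: python

from torch.distributed.tensor.placement_types import Replicate

# Option 1: redistribute to Replicate. This triggers the pending all-reduce,
# after which every rank holds the same fully reduced ShardTensor.
x_replicated = x.redistribute(placements=(Replicate(),))

# Option 2: materialize a plain, fully reduced torch.Tensor on every rank.
# This is differentiable, so it can feed directly into a loss computation.
x_full = x.full_tensor()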


Preparing Data with ShardTensor
-------------------------------

Create a simple dataset and shard it across devices:

.. code-block:: python

def create_sample_data(batch_size=32, height=32, width=64):
    # Create random data
    data = torch.randn(batch_size, 3, height, width, device=f"cuda:{dm.local_rank}")
Collaborator: Change device to dm.device here
    labels = torch.randint(0, 10, (batch_size,), device=f"cuda:{dm.local_rank}")

    # Convert to ShardTensor for spatial decomposition
    placements = (Shard(2),)  # Shard along the H dimension
    data = ShardTensor.from_local(
        data,
        device_mesh=spatial_mesh,
        placements=placements
    )

    # For the labels, we can leverage DTensor to distribute them:
    labels = ShardTensor.from_dtensor(
        distribute_tensor(
            labels,
            device_mesh=spatial_mesh,
            placements=(Replicate(),)
        )
    )

    return data, labels

Combining FSDP with Domain Decomposition
----------------------------------------

Set up the model with both FSDP and spatial decomposition:

.. code-block:: python

def setup_model():
    # Create base model
    model = SimpleCNN().to(f"cuda:{dm.local_rank}")
Collaborator: Also best to use dm.device here (thinking of future device-to-rank mapping optimizations that we might add)

    # Take the module and distribute it over the spatial mesh.
    # This will replicate the model over the spatial mesh.
    # You can, if you want FSDP, get more fancy than this.
    model = distribute_module(
        model,
        device_mesh=spatial_mesh,
    )

    # Wrap with FSDP
    # Since the model is replicated, this will mimic DDP behavior.
    model = FSDP(
        model,
        device_mesh=data_mesh,
        use_orig_params=True
    )

    return model

Note that, above, we manually distribute the model over the spatial mesh, then set up FSDP over the data-parallel mesh.
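
As raised in the review discussion above, the fully connected layer stays replicated with this default. If you wanted to shard ``nn.Linear`` weights as well, ``distribute_module`` accepts a custom ``partition_fn``. The sketch below (reusing the ``SimpleCNN``, ``dm``, and ``spatial_mesh`` defined earlier) is illustrative only: it is not part of this PR, and composing it with ShardTensor's domain sharding is not validated in this release.

.. code-block:: python

from torch.distributed.tensor import distribute_module, distribute_tensor
from torch.distributed.tensor.placement_types import Shard

def shard_linear_params(name, module, device_mesh):
    # Illustrative partition_fn: shard each Linear layer's parameters along
    # the output dimension; all other modules keep the default replication.
    if isinstance(module, nn.Linear):
        for param_name, param in list(module.named_parameters(recurse=False)):
            sharded = nn.Parameter(
                distribute_tensor(param, device_mesh, placements=(Shard(0),))
            )
            module.register_parameter(param_name, sharded)

model = distribute_module(
    SimpleCNN().to(dm.device),
    device_mesh=spatial_mesh,
    partition_fn=shard_linear_params,
)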


Training Loop
-------------

Implement a basic training loop:

.. code-block:: python

def train_epoch(model, optimizer, criterion):
    model.train()

    for i in range(10):  # 10 training steps
        # Get sharded data
        inputs, targets = create_sample_data()

        # Forward pass
        outputs = model(inputs)
Collaborator: Presumably the outputs are replicated across the spatial mesh right? It's not clear to me where/how that happens.

Collaborator (Author): Here, yes, they are replicated. In this example it happens specifically in the pool layer.

In general, it need not be replicated - depending on your model and loss function you may have different needs. For this classification example, where the labels are small and replicated, it makes sense to replicate them. If we didn't have the pool layer, the loss computation would call outputs = outputs.full_tensor(), which explicitly (and differentiably) summons the outputs.

On the other hand, for something like a segmentation loss where the output size == input size == label size, the computation of the loss may actually make more sense to perform as a distributed operation. In that case, a per-pixel softmax cross entropy between the output and the label would itself produce a sharded tensor. When you sum (or average) across the image shape, you'll be left with scalars on each rank in the Partial placement - meaning they are reduced on the local device but not the global device.

Most of this is transparently handled by the DTensor dispatch + operation rules. Combined with the generic way to redistribute tensors (from dtensor) + the reimplementation for asymmetrically sharded tensors here, most single-tensor operations and straightforward operations (like a + b for two DTensors) can work out of the box and under the hood.

        loss = criterion(outputs, targets)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if dm.rank == 0 and i % 2 == 0:
            print(f"Step {i}, Loss: {loss.item():.4f}")

Main Training Script
--------------------

Put it all together:

.. code-block:: python


def main():
    # Create model and optimizer
    model = setup_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    # Train for 5 epochs
    for epoch in range(5):
        if dm.rank == 0:
            print(f"Epoch {epoch+1}")
        train_epoch(model, optimizer, criterion)

    # Cleanup
    DistributedManager.cleanup()

if __name__ == "__main__":
    main()


Running the Code
----------------

To run this example with 4 GPUs (2x2 mesh):

.. code-block:: bash

torchrun --nproc_per_node=4 train_cnn.py

This will train the model using both data parallelism (FSDP) and spatial decomposition (ShardTensor) across 4 GPUs in a 2x2 configuration.

Key Points
----------

1. The device mesh is split into two dimensions: one for data parallelism (FSDP) and one for spatial decomposition (ShardTensor)
2. Input data is sharded across the spatial dimension using ShardTensor
3. FSDP handles parameter sharding and optimization across the data parallel dimension
4. The model can process larger spatial dimensions efficiently by distributing the computation

This example demonstrates basic usage. For production use cases, you'll want to add:

- Proper data loading and preprocessing
- Model checkpointing
- Validation loop
- Learning rate scheduling
- Error handling
- Logging and metrics

For more advanced usage and configuration options, refer to the Modulus documentation on ShardTensor and the PyTorch FSDP documentation.
5 changes: 5 additions & 0 deletions modulus/distributed/__init__.py
@@ -27,3 +27,8 @@
reduce_loss,
unmark_module_as_shared,
)

from .shard_tensor import ShardTensor, scatter_tensor

# Load and register custom ops:
from .custom_ops import *