Enable Domain Parallelism with ShardTensor #784

Modulus Shard Tensor
====================

In scientific AI applications, the parallelization techniques needed to enable state-of-the-art
models differ from those used to train large language models. Modulus introduces a new
parallelization primitive called a ``ShardTensor`` that is designed for large-input AI
applications to enable domain parallelism.

``ShardTensor`` provides a distributed tensor implementation that supports uneven sharding
across devices. It builds on PyTorch's ``DTensor`` while adding flexibility for cases where
different ranks may have different local tensor sizes.

The example below shows how to create and work with ``ShardTensor``:

.. code:: python

    import torch
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor.placement_types import Shard
    from modulus.distributed import DistributedManager
    from modulus.distributed.shard_tensor import ShardTensor, scatter_tensor

    def main():
        # Initialize distributed environment
        DistributedManager.initialize()
        dist = DistributedManager()

        # Create a 1D device mesh
        mesh = DeviceMesh(device_type="cuda", mesh=[0, 1, 2, 3])

        # Create a tensor on rank 0
        if dist.rank == 0:
            tensor = torch.randn(100, 64)
        else:
            tensor = None

        # Scatter the tensor across devices with uneven sharding
        # This will automatically determine appropriate local sizes
        sharded = scatter_tensor(
            tensor,
            global_src=0,
            mesh=mesh,
            placements=(Shard(0),)  # Shard along first dimension
        )

        # Work with local portions
        local_tensor = sharded.to_local()

        # Redistribute to different sharding scheme
        new_sharded = sharded.redistribute(
            placements=(Shard(1),)  # Change to shard along second dimension
        )

Review discussion (on the ``DeviceMesh`` construction):

- Reviewer: This is going straight to pytorch right? Would be better to route that through …
- Author: Sure, I can update this script. I've updated …

How does this work?
"""""""""""""""""""

``ShardTensor`` extends PyTorch's ``DTensor`` to support uneven sharding, where different ranks
can have different local tensor sizes. It tracks shard size information and handles
redistribution between different sharding schemes while maintaining gradient flow.

Key differences from ``DTensor`` include:

- Support for uneven sharding where ranks have different local sizes
- Tracking and propagation of shard size information
- Custom collective operations optimized for uneven sharding
- Flexible redistribution between different sharding schemes

Operations work by:

1. Converting inputs to local tensors
2. Performing operations locally
3. Constructing a new ``ShardTensor`` with appropriate sharding
4. Handling any needed communication between ranks
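
To make the uneven-sharding behavior concrete, here is a minimal sketch. It assumes a 4-GPU job
launched with ``torchrun`` and assumes that ``ShardTensor.from_local`` infers the per-rank shard
sizes from the differing local shapes (check the API reference for the exact behavior); the
per-rank row counts are illustrative only.

.. code:: python

    import torch
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor.placement_types import Shard
    from modulus.distributed import DistributedManager
    from modulus.distributed.shard_tensor import ShardTensor

    DistributedManager.initialize()
    dist = DistributedManager()

    # 1D mesh over 4 GPUs
    mesh = DeviceMesh(device_type="cuda", mesh=[0, 1, 2, 3])

    # Each rank holds a different number of rows: 10, 20, 30, 40
    local_rows = 10 * (dist.rank + 1)
    local_chunk = torch.randn(local_rows, 64, device=f"cuda:{dist.local_rank}")

    # Build a ShardTensor from the uneven local pieces, sharded along dim 0
    sharded = ShardTensor.from_local(
        local_chunk,
        device_mesh=mesh,
        placements=(Shard(0),),
    )

    # Each rank keeps its own local size; the global shape is the concatenation
    print(f"rank {dist.rank}: local {tuple(sharded.to_local().shape)}, "
          f"global {tuple(sharded.shape)}")

Run with ``torchrun --nproc_per_node=4``, each rank reports its own local shape while all ranks
agree on the 100 x 64 global shape.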

.. autosummary::
   :toctree: generated

ShardTensor
-----------

.. autoclass:: modulus.distributed.shard_tensor.ShardTensor
    :members:
    :show-inheritance:

Utility Functions
-----------------

.. autofunction:: modulus.distributed.shard_tensor.scatter_tensor

Why do we need this?
""""""""""""""""""""

During deep learning training, memory usage can grow significantly when working with large input
data, even if the model itself is relatively small. This is because many operations create
intermediate tensors that temporarily consume memory.

For example, consider a 2D convolution on a high-resolution image. With a batch of 1024x1024
images, even a simple 3x3 convolution must keep the entire input image in memory to compute
gradients in the backward pass.

For high-resolution inputs, this can easily lead to out-of-memory errors as model depth grows,
even when the number of parameters is small. This is a significant contrast to LLM training,
where memory usage is dominated by the parameters and the corresponding optimizer states.
Software solutions such as DeepSpeed and ZeRO handle that case by partitioning the model across
GPUs, but partitioning the model does not help large-input applications.
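
As a rough, illustrative back-of-the-envelope estimate (the layer widths below are assumptions,
not from the original text), the sketch computes how much activation memory a small stack of
stride-1 convolutions must keep for the backward pass:

.. code:: python

    # Rough activation-memory estimate for a stack of stride-1, padded convolutions.
    # Each convolution saves its input activations for the backward pass.
    def conv_activation_memory_gb(batch, height, width, channels, bytes_per_el=4):
        total_elements = 0
        in_channels = channels[0]
        for out_channels in channels[1:]:
            # The input to this layer is saved for the backward pass
            total_elements += batch * in_channels * height * width
            in_channels = out_channels
        return total_elements * bytes_per_el / 1e9

    # A batch of 8 three-channel 1024x1024 images through four modest conv layers:
    print(conv_activation_memory_gb(8, 1024, 1024, [3, 64, 64, 128, 128]))
    # ~8.7 GB of saved activations, before counting outputs, gradients,
    # or optimizer state.

Sharding the spatial dimensions across N GPUs divides this activation footprint roughly by N,
which is exactly what ``ShardTensor`` targets.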

``ShardTensor`` helps address this by:

- Distributing the input data across multiple devices
- Performing operations on smaller local portions
- Coordinating the necessary communication between devices in the forward and backward passes

``ShardTensor`` is built as an extension of PyTorch's ``DTensor`` and gains substantial
functionality by leveraging the utilities already implemented in the PyTorch distributed
package. However, some operations on sharded input data are neither trivial to implement
correctly nor relevant to the model-sharding problem. In Modulus, we have implemented
parallelized versions of several key operations, including (so far):

- Convolution (1D, 2D, 3D)
- Neighborhood Attention (2D)

These operations are implemented in the ``modulus.distributed.shard_utils`` module and are
enabled by dynamically intercepting calls to (for example) ``torch.nn.functional.conv2d``. When
such a function is called with ``ShardTensor`` inputs, the operation is automatically
parallelized across the mesh associated with the input. When it is called with
non-``ShardTensor`` inputs, the operation executes in a non-parallelized manner, exactly as
expected.

To enable these operations, you must import ``patch_operations`` from
``modulus.distributed.shard_utils``. This patches the relevant functions to support
``ShardTensor`` inputs.
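
For concreteness, a minimal sketch of the intended usage might look like the following; the
2-GPU mesh and tensor sizes are illustrative assumptions, and the building blocks
(``patch_operations``, ``ShardTensor.from_local``, ``distribute_module``) are the ones shown
elsewhere in these docs:

.. code:: python

    import torch
    import torch.nn as nn

    # Importing patch_operations installs the ShardTensor-aware wrappers
    from modulus.distributed.shard_utils import patch_operations  # noqa: F401

    from modulus.distributed import DistributedManager
    from modulus.distributed.shard_tensor import ShardTensor
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor import distribute_module
    from torch.distributed.tensor.placement_types import Shard

    DistributedManager.initialize()
    dist = DistributedManager()
    mesh = DeviceMesh(device_type="cuda", mesh=[0, 1])

    # A plain Conv2d whose parameters get replicated over the mesh
    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1).to(f"cuda:{dist.local_rank}")
    conv = distribute_module(conv, device_mesh=mesh)

    # Each rank holds the local half of an image sharded along the height dimension
    local_image = torch.randn(1, 3, 512, 1024, device=f"cuda:{dist.local_rank}")
    sharded_image = ShardTensor.from_local(
        local_image, device_mesh=mesh, placements=(Shard(2),)
    )

    # Because the input is a ShardTensor, the underlying conv2d call is
    # intercepted and parallelized across the mesh; with a plain torch.Tensor
    # it would run exactly as usual.
    out = conv(sharded_image)
    print(type(out), tuple(out.to_local().shape))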

We are continuing to add more operations, and contributions are welcome!

ShardTensor and FSDP Tutorial
=============================

This tutorial demonstrates how to use Modulus's ``ShardTensor`` functionality alongside
PyTorch's FSDP (Fully Sharded Data Parallel) to train a simple convolutional neural network.
We'll show how to:

1. Create a simple CNN model
2. Set up input data sharding across multiple GPUs
3. Combine FSDP with domain decomposition
4. Train the model

Simple CNN Model
----------------

The preamble to the training script includes an important patch to make sure that the conv2d
operation works with ``ShardTensor``:

.. code-block:: python

    import torch

    # This is necessary to patch Conv2d to work with ShardTensor
    from modulus.distributed.shard_utils import patch_operations

    import torch.nn as nn

    from modulus.distributed import DistributedManager
    from modulus.distributed.shard_tensor import ShardTensor
    from torch.distributed.tensor import distribute_module, distribute_tensor
    from torch.distributed.tensor.placement_types import Shard, Replicate
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

Setting Up the Environment
--------------------------

Next, set up the distributed environment, including the device mesh. Here we do it globally,
but you can do it locally as well and pass ``DeviceMesh`` objects around.

.. code-block:: python

    # Initialize distributed environment
    DistributedManager.initialize()
    dm = DistributedManager()

    # Create a 2D mesh for hybrid parallelism
    # First dimension for data parallel, second for spatial decomposition
    mesh = dm.initialize_mesh((-1, 2), mesh_dim_names=["data", "spatial"])

    # Get submeshes for different parallel strategies
    data_mesh = mesh["data"]        # For FSDP
    spatial_mesh = mesh["spatial"]  # For spatial decomposition
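
If it helps to see how the two mesh dimensions are laid out, a quick, optional check (not part
of the original tutorial) is to print the submesh sizes on each rank:

.. code-block:: python

    # With 4 processes and a (-1, 2) mesh shape, this reports a data-parallel
    # group of size 2 and a spatial group of size 2 on every rank.
    print(
        f"global rank {dm.rank}: "
        f"data mesh size {data_mesh.size()}, "
        f"spatial mesh size {spatial_mesh.size()}"
    )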

First, let's create a simple CNN model with a single convolution layer:

.. code-block:: python

    import torch
    import torch.nn as nn
    from modulus.distributed import DistributedManager
    from modulus.distributed.shard_tensor import ShardTensor
    from torch.distributed.tensor.placement_types import Shard
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    class SimpleCNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
            self.relu = nn.ReLU()
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(16, 10)

        def forward(self, x):
            # This is automatically parallel:
            x = self.conv(x)
            x = self.relu(x)
            # This operation reduces on the parallel dimension.
            # This will leave x as a Partial placement, meaning
            # it isn't really sharded anymore but the results on the domain
            # pieces haven't been computed yet.
            x = self.pool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            return x

Review discussion (on the pooling layer):

- Reviewer: How does this work with DTensor? I don't see any pooling ops in
  https://github.com/pytorch/pytorch/tree/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_ops
- Author: In general, most operations can fall back to DTensor successfully. There is a small
  hiccup in this particular operation (it doesn't show up in this example): if the tensor were
  sharded asymmetrically, the mean would do local averaging followed by a global average when
  the output is needed. The global average needs to be weighted by the shape of the un-reduced
  tensor, but currently isn't. This is an expected failure in the tests on this branch for
  exactly that reason, but it shouldn't block progress since it's a relatively rare path. It's
  on my agenda to fix, though it likely requires sub-classing.
- Reviewer: Got it, I traced that down. When does this resolve?
- Author: I'm not 100% sure when it resolves, actually. I want to look into this more: the
  details of when and why are important when parallelism is being composed.

Review discussion (on the flatten and fully connected layers):

- Reviewer: Is this automatically changed to all-local ops?
- Author: No; distributing the module changes the underlying parameters of the layers into
  replicated tensors over the spatial mesh, so subsequent operations see replicated weights.
  In the backward pass, DTensor (and hence ShardTensor) should automatically un-replicate the
  gradients and shard them appropriately, so the input gradients to the upstream layers here
  (conv, relu) will be shard tensors of the correct size.
- Reviewer: That makes sense. It brings up a good point, though: ideally we'd also want to
  parallelize this part of the model.
- Author: This model is quite small, though; 3D parallelism over batching, domain sharding,
  and model parallelism is on my to-do list for the next release, but it's not ready in this
  release.
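
To make the ``Partial`` behavior from that discussion concrete, here is a small, hedged sketch.
It assumes an already-constructed ``ShardTensor`` called ``sharded`` that is sharded along the
height dimension (as in the data-preparation code below), and note that the global-mean
weighting for *unevenly* sharded tensors is the known open issue mentioned above.

.. code-block:: python

    from torch.distributed.tensor.placement_types import Replicate

    # A reduction over the sharded dimension leaves a pending reduction:
    # each rank holds only its local contribution to the mean.
    pooled = sharded.mean(dim=(2, 3), keepdim=True)
    print(pooled.placements)  # expected to show a Partial placement

    # Redistributing to Replicate forces the reduction, so every rank ends
    # up with the same fully-reduced values.
    resolved = pooled.redistribute(placements=(Replicate(),))
    print(tuple(resolved.to_local().shape))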

Preparing Data with ShardTensor
-------------------------------

Create a simple dataset and shard it across devices:

.. code-block:: python

    def create_sample_data(batch_size=32, height=32, width=64):
        # Create random data
        data = torch.randn(batch_size, 3, height, width, device=f"cuda:{dm.local_rank}")
        labels = torch.randint(0, 10, (batch_size,), device=f"cuda:{dm.local_rank}")

        # Convert to ShardTensor for spatial decomposition
        placements = (Shard(2),)  # Shard along the H dimension
        data = ShardTensor.from_local(
            data,
            device_mesh=spatial_mesh,
            placements=placements
        )

        # For the labels, we can leverage DTensor to distribute them:
        labels = ShardTensor.from_dtensor(
            distribute_tensor(
                labels,
                device_mesh=spatial_mesh,
                placements=(Replicate(),)
            )
        )

        return data, labels
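
As a quick sanity check (not part of the original tutorial), you can compare the global and
local shapes of the sharded batch; with the default arguments and a spatial mesh of size 2,
each rank should hold half of the height dimension:

.. code-block:: python

    data, labels = create_sample_data()

    # Global (logical) shape vs. the local slab owned by this rank.
    print(
        f"rank {dm.rank}: global {tuple(data.shape)}, "
        f"local {tuple(data.to_local().shape)}"
    )
    # e.g. global (32, 3, 32, 64) and local (32, 3, 16, 64) on each rank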

Combining FSDP with Domain Decomposition
----------------------------------------

Set up the model with both FSDP and spatial decomposition:

.. code-block:: python

    def setup_model():
        # Create base model
        model = SimpleCNN().to(f"cuda:{dm.local_rank}")

        # Take the module and distribute it over the spatial mesh.
        # This will replicate the model over the spatial mesh.
        # You can, if you want FSDP, get more fancy than this.
        model = distribute_module(
            model,
            device_mesh=spatial_mesh,
        )

        # Wrap with FSDP
        # Since the model is replicated, this will mimic DDP behavior.
        model = FSDP(
            model,
            device_mesh=data_mesh,
            use_orig_params=True
        )

        return model

Note that, above, we manually distribute the model over the spatial mesh, then set up FSDP over
the data-parallel mesh.
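
A hedged way to see the effect of ``distribute_module`` (assuming ``use_orig_params=True`` keeps
the original parameters visible, as set above) is to peek at one parameter after setup:

.. code-block:: python

    model = setup_model()

    # After distribute_module, parameters should be replicated DTensors over
    # the spatial mesh; FSDP then manages them along the data-parallel mesh.
    name, param = next(iter(model.named_parameters()))
    print(name, type(param), getattr(param, "placements", None))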

Training Loop
-------------

Implement a basic training loop:

.. code-block:: python

    def train_epoch(model, optimizer, criterion):
        model.train()

        for i in range(10):  # 10 training steps
            # Get sharded data
            inputs, targets = create_sample_data()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if dm.rank == 0 and i % 2 == 0:
                print(f"Step {i}, Loss: {loss.item():.4f}")

Review discussion (on the forward pass):

- Reviewer: Presumably the outputs are replicated across the spatial mesh, right? It's not
  clear to me where or how that happens.
- Author: Here, yes, they are replicated. In general they need not be; depending on your model
  and loss function you may have different needs. For this classification example, where the
  labels are small and replicated, it makes sense to replicate the outputs. On the other hand,
  for something like a segmentation loss, where the output size equals the input and label
  sizes, it may make more sense to compute the loss as a distributed operation: a per-pixel
  softmax cross entropy between the output and the label would itself be a sharded tensor, and
  summing (or averaging) over the image dimensions leaves a scalar contribution on each rank.
  Most of this is handled transparently by the DTensor dispatch and operation rules, combined
  with the generic redistribution machinery from DTensor and the reimplementation for
  asymmetrically sharded tensors here, so most single-tensor and otherwise straightforward
  operations just work.

Main Training Script
--------------------

Put it all together:

.. code-block:: python

    def main():
        # Create model and optimizer
        model = setup_model()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()

        # Train for 5 epochs
        for epoch in range(5):
            if dm.rank == 0:
                print(f"Epoch {epoch+1}")
            train_epoch(model, optimizer, criterion)

        # Cleanup
        DistributedManager.cleanup()

    if __name__ == "__main__":
        main()

Running the Code
----------------

To run this example with 4 GPUs (a 2x2 mesh):

.. code-block:: bash

    torchrun --nproc_per_node=4 train_cnn.py

This will train the model using both data parallelism (FSDP) and spatial decomposition
(ShardTensor) across 4 GPUs in a 2x2 configuration.

Key Points
----------

1. The device mesh is split into two dimensions: one for data parallelism (FSDP) and one for
   spatial decomposition (ShardTensor)
2. Input data is sharded across the spatial dimension using ShardTensor
3. FSDP handles parameter sharding and optimization across the data-parallel dimension
4. The model can process larger spatial dimensions efficiently by distributing the computation

This example demonstrates basic usage; for production use cases, you'll want to add:

- Proper data loading and preprocessing
- Model checkpointing
- Validation loop
- Learning rate scheduling
- Error handling
- Logging and metrics

For more advanced usage and configuration options, refer to the Modulus documentation on
ShardTensor and the PyTorch FSDP documentation.

Review comment: Codify ``ShardTensor`` to make it stand out better?