Enable Domain Parallelism with ShardTensor #784
base: main
Conversation
…simple DDP sharding
…ieces are WIP but this has basic functionality supported for creation and forward usage.
…t of the ops have been validated, all that remains is to wrap the na2d function call to ensure it will dispatch properly.
…ng in unbind op rules.
….ops.aten.convolution.default.
…s also a minor bug in the backward pass that got more pronounced with smaller data: grad inputs were failing to properly collect haloed gradients and add them on the edges. Now fixed.
…gnificant overhead. I'm implementing here an option to switch to peer to peer message passing, since it might benefit from stream utilization in layers like natten.na2d. It's a developer choice currently, not a user choice.
…gnificant functionality changes in this commit.
Add `scatter_tensor` function to enable an easier transition to ShardTensor. This function allows users to maintain data pipelines (on one rank) and easily scatter that data to a domain mesh.
But also, this adjusts the shard tensor mechanism for tracking shard info to use a dict instead of a list of tuples.
No real code changes applied here.
Overall, I think this is looking good, nice work! I started with the documentation and then focused on unit tests to see overall functionality, as well as the code changes. Aside from the minor comments added, my main flag is to make the unit testing more complete, but I don't think that should necessarily block merging. In particular, for ops that we support (conv or nat, currently), we should add unit tests for correctness against a non-sharded baseline for forward and backward passes (within some numerical tolerance, especially in the context of the neighborhood attention numerics we discovered).
Thanks for the review @pzharrington! I agree with you on the testing. Here are my thoughts:
In scientific AI applications, the parallelization techniques to enable state of the art
models are different from those used in training large language models. Modulus
introduces a new parallelization primitive called a ShardTensor that is designed for
Codify ShardTensor to make it stand out better?
dist = DistributedManager()

# Create a 1D device mesh
mesh = DeviceMesh(device_type="cuda", mesh=[0, 1, 2, 3])
This is going straight to pytorch, right? Would be better to route that through DistributedManager, I think, so that mesh objects don't need to be explicitly passed around and modules can query the DistributedManager for them, like how they currently query it for the torch device, rank, process group handles, etc.
Sure, I can update this script. I've updated DistributedManager in this script to enable DeviceMesh creation too.
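As a rough sketch of that direction (the `initialize_mesh` name and signature below are assumptions about the updated DistributedManager, not a settled API):

```python
from modulus.distributed import DistributedManager

DistributedManager.initialize()
dm = DistributedManager()

# Hypothetical: ask DistributedManager for a mesh instead of building a
# DeviceMesh by hand, so modules can query it like dm.device or dm.rank.
mesh = dm.initialize_mesh(mesh_shape=(-1,), mesh_dim_names=("domain",))
```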
def create_sample_data(batch_size=32, height=32, width=64):
    # Create random data
    data = torch.randn(batch_size, 3, height, width, device=f"cuda:{dm.local_rank}")
Change device to dm.device here.
def setup_model():
    # Create base model
    model = SimpleCNN().to(f"cuda:{dm.local_rank}")
Also best to use dm.device here (thinking of future device-to-rank mapping optimizations that we might add).
    return generic_conv_nd_wrapper(wrapped, instance, args, kwargs)

@wrapt.patch_function_wrapper('torch.nn.functional', 'conv2d')
Not a big fan of this monkey patching for a few reasons:
- It took me a long time to figure out how the example in the tutorial worked and to get to this line here.
- This doesn't work if this module is not imported.
- It breaks the modulus.Module.save and modulus.Module.from_checkpoint functionality (in the general case).

My preference would be to explicitly expose this as a DistributedConv2d module and maybe use the partition_fn argument in distribute_module to swap in DistributedConv2d instead of nn.Conv2d.
Here's an example for ColwiseParallel in torch:
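For context, a minimal sketch of how torch applies ColwiseParallel through parallelize_module (the toy model and layer names here are just illustrative):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Assumes a torchrun launch with one process per GPU (4 here).
mesh = init_device_mesh("cuda", (4,))

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

# Shard the first Linear column-wise across the mesh; everything else stays replicated.
model = parallelize_module(model, mesh, {"0": ColwiseParallel()})
```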
This is the feedback I have been expecting! :)
I hear your point. I'll outline why I did it this way and not the other way; maybe we'll find a path that meets all needs. Sorry it took a while to figure out; I can (if we keep it) write up how exactly this is working.
I wanted to meet these user-facing requirements in the implementation:
- Easy to use interface. Ideally, the "model code" should remain unchanged or as close to unchanged as possible to enable domain parallelism.
  - We could, maybe, have a big registry of "nn.Conv2d -> modulus.distributed.nn.Conv2d" that gets used when distribute_module is applied as a partition_fn. Then, users would have to apply our partition function + any other partition functions needed.
  - The spirit of the partition function, though, appears to be to enable model parallelism (not domain parallelism as I am doing here). distribute_module is operating on the model parameters and not how those parameters are interacting with the input data.
- Don't break un-sharded codes.
  - I wanted the operation to get called correctly regardless of whether the input is sharded or not. Using a module replacement technique, it likely comes down to the same steps, but every module we built to do this would essentially be a custom layer.
  - I know you could do this with a module mapping/replacement too.
- Support the torch.nn.functional interface.
  - What happens if you have a custom layer that stores a convolutional weight and bias and then calls an operation with torch.nn.functional.conv2d? This would need to be specifically enabled.
In the end, I went with this approach since it was most flexible for the user. The monkey patching will pick the right operation based on the input types, and in fact I built the sharded ops on even lower-level operations (aten.convolution.default, for example), so it's one-op-fits-all to support convolutions.

The best of both worlds, after writing this up, is likely to use this functional-level mapping plus a registry for dispatch within ShardTensor itself. I think that's possible; it definitely won't be ready for this release.
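To make the interception concrete, here is a minimal sketch of the functional-level dispatch being described (the sharded path is stubbed out; in the PR that role is played by generic_conv_nd_wrapper):

```python
import wrapt
from torch.distributed.tensor import DTensor


def sharded_conv2d(wrapped, args, kwargs):
    """Stand-in for the PR's sharded path: halo exchange on the sharded
    input, then the local convolution on each rank."""
    raise NotImplementedError


@wrapt.patch_function_wrapper("torch.nn.functional", "conv2d")
def conv2d_wrapper(wrapped, instance, args, kwargs):
    # ShardTensor extends DTensor, so a type check on the inputs decides
    # whether to take the sharded path or fall through to the original op.
    if any(isinstance(a, DTensor) for a in args):
        return sharded_conv2d(wrapped, args, kwargs)
    return wrapped(*args, **kwargs)
```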
Also, can you elaborate on how this breaks the save / from_checkpoint interface? I think it might actually be distribute_module that breaks these, but in general this dispatch/dynamic interception of layers shouldn't break the way models are stored and loaded.
The one exception is if users need to shard their model's parameters for domain sharding. This can occur if you're adding position embeddings or have other trainable parameters with shapes dependent on data shape. Then, we'd need to update save / from_checkpoint to ensure (see the sketch after this list):
- Models coalesce full parameters before saving.
- Sharding / placements are (optionally) restored on loading.
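A minimal sketch of the first point, assuming DTensor-backed parameters (this is not the current modulus.Module.save behavior, just the shape of the change):

```python
import torch
from torch.distributed.tensor import DTensor


def coalesced_state_dict(model: torch.nn.Module) -> dict:
    # Gather each sharded parameter back to a full local tensor so the
    # checkpoint does not depend on the mesh/placements used in training.
    state = {}
    for name, value in model.state_dict().items():
        if isinstance(value, DTensor):
            state[name] = value.full_tensor().cpu()
        else:
            state[name] = value.cpu()
    return state
```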
# This will leave x as a Partial placement, meaning
# it isn't really sharded anymore but the results on the domain
# pieces haven't been computed yet.
x = self.pool(x)
How does this work with DTensor? I don't see any pooling ops in https://github.com/pytorch/pytorch/tree/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_ops
Under the hood in the pooling layer, it's doing an aten.mean operation. That gets intercepted and redispatched by DTensor, which is the fallback dispatch for ShardTensor's dispatch. The mean operation is handled in upstream pytorch here:
https://github.com/pytorch/pytorch/blob/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_ops/_math_ops.py#L306-L355
In general, most operations can fall back to DTensor successfully. There is a small hiccup in this operation, actually (it doesn't show up in this example, of course): if the tensor is sharded asymmetrically, the mean does local averaging followed by a global average when the output is needed. The global average needs to be weighted by the shape of the un-reduced tensor, but currently isn't. It's an expected-fail in the tests in this branch for exactly this reason, but I don't think it should block progress since it's a relatively rare path to go down. It's on my agenda to fix, but it likely requires subclassing the Partial placement in DTensor to catch and handle.
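To make the weighting issue concrete, a small sketch with plain tensors (no DTensor involved, just the arithmetic the Partial reduction would need):

```python
import torch

# Two shards of one logical tensor, deliberately uneven.
shard_a = torch.randn(6)  # e.g. rank 0 holds 6 elements
shard_b = torch.randn(2)  # e.g. rank 1 holds 2 elements

# Naive average of the local means: wrong when the shards differ in size.
naive = (shard_a.mean() + shard_b.mean()) / 2

# Correct: weight each local mean by its local element count.
total = shard_a.numel() + shard_b.numel()
weighted = (shard_a.numel() * shard_a.mean() + shard_b.numel() * shard_b.mean()) / total

assert torch.allclose(weighted, torch.cat([shard_a, shard_b]).mean())
```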
Got it, traced that down to this piece of code that tells you that it will be a Partial placement: https://github.com/pytorch/pytorch/blob/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_ops/_math_ops.py#L248-L251.
When does this Partial placement get resolved? In the fc, I assume? Is that basically calling a redistribute to Replicate, and that triggers an allreduce? If that fc was rowwise (or colwise, depending on convention) sharded, then it would redistribute to Shard(0) and that would trigger a reduce-scatter instead of an allreduce?
I'm not 100% sure when it resolves, actually. But what I think is happening is:
- Ops have a set of rules defined, for input placements, to determine output placements.
- If the placements aren't going to match after computation, a redistribute is triggered. https://github.com/pytorch/pytorch/blob/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_dispatch.py#L177-L185
I want to look into this more: the details of when and why are important when parallelism is being composed.
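For reference, here is roughly what resolving that Partial placement looks like when written out explicitly (a sketch assuming a torchrun launch with one process per GPU; not the example's code):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))

# A tensor sharded over dim 0; reducing over that dim leaves a Partial result.
x = distribute_tensor(torch.randn(8, 16), mesh, placements=[Shard(0)])
m = x.mean(dim=0)  # placement: Partial

m_rep = m.redistribute(placements=[Replicate()])  # resolves via all-reduce
m_shd = m.redistribute(placements=[Shard(0)])     # resolves via reduce-scatter
```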
# pieces haven't been computed yet.
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
This is automatically changed to all-local ops by distribute_module?
No, it changes all the underlying parameters in the layers from torch.Tensor to torch.distributed.tensor.DTensor. DTensor comes with its own dispatch, and when you call the thing returned by distribute_module with a DTensor or ShardTensor, it will make a decision about how to process based on a) the operation being dispatched and b) the input tensor(s).

When you do the pool operation in this example, the pooling actually takes place over the sharded dimension. So the output x on line 78 is a ShardTensor/DTensor with Replicate() placement instead of Shard(2) placement (which was the height dimension).

Subsequent operations will see that the weights in fc have Replicate placement, and the inputs to fc have Replicate placement, so it just has to use an all-local operation and wrap the output in a Replicate'd DTensor. ShardTensor then intercepts that and keeps the domain data as ShardTensor.

In the backwards pass, DTensor (and hence ShardTensor) should automatically un-replicate the gradients and shard them appropriately. So the input gradients to the upstream layers here (conv, relu) will be shard tensors of the correct size.
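A tiny sketch of that "replicated weights + replicated inputs = all-local op" behavior in isolation (assuming a torchrun launch with one process per GPU; the default distribute_module partition replicates every parameter):

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_module, distribute_tensor, Replicate

mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))

# Default partition_fn: every parameter becomes a Replicate-placed DTensor.
fc = distribute_module(nn.Linear(16, 4), mesh)

x = distribute_tensor(torch.randn(8, 16), mesh, placements=[Replicate()])
y = fc(x)
print(y.placements)  # (Replicate(),) -- computed with purely local matmuls
```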
Yeah, that makes sense that fc weights are replicated since that's the default with distribute_module.

This actually brings up a good point. Ideally, we'd want to parallelize the fc in this example as well, since it's a lot of redundant compute, right? How would one do that? Through a custom partition_fn in distribute_module? That would work in this case. How would it work in a larger model with many different Linear layers which all might need different handling in partition_fn?
This model is so small that the fc is probably irrelevant in terms of compute, but yes, I get your point. I think weight partitioning, in almost every case, should be handled with existing FSDP tools. That ought to allow most parallelizations to work out of the box, though it needs to be fully checked and confirmed.

This is on my to-do list for the next release: 3D parallelism over batching, domain sharding, and model parallelism. It's not ready in this release, though.
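A rough sketch of the intended composition (the 2D mesh layout, the dim names, and using FSDP's device_mesh argument this way are assumptions about the future setup, not something this PR ships):

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# One mesh axis for parameter sharding (FSDP), one for domain sharding.
# Assumes 4 GPUs launched with torchrun.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("data", "spatial"))

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3)).cuda()

# FSDP shards parameters over the "data" axis; ShardTensor inputs would carry
# the "spatial" axis through the forward/backward passes.
model = FSDP(model, device_mesh=mesh["data"])
```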
inputs, targets = create_sample_data()

# Forward pass
outputs = model(inputs)
Presumably the outputs are replicated across the spatial mesh right? It's not clear to me where/how that happens.
Here, yes, they are replicated. In this example it happens specifically in the pool layer.

In general, they need not be replicated - depending on your model and loss function you may have different needs. For this classification example, where the labels are small and replicated, it makes sense to replicate the outputs. If we didn't have the pool layer, the loss computation would call outputs = outputs.full_tensor(), which explicitly (and differentiably) gathers the outputs.

On the other hand, for something like a segmentation loss where the output size == input size == label size, the computation of the loss may actually make more sense to perform as a distributed operation. In that case, a per-pixel softmax cross entropy between the output and the label would itself produce a sharded tensor. When you sum (or average) across the image shape, you'll be left with scalars on each rank in the Partial placement - meaning they are reduced on the local device but not across the global mesh.

Most of this is transparently handled by the DTensor dispatch + operation rules. Combined with the generic way to redistribute tensors (from DTensor) + the reimplementation for asymmetrically sharded tensors here, most single-tensor operations and straightforward operations (like a + b for two DTensors) work out of the box and under the hood.
Reminder to change the target to
Modulus Pull Request
Description
This PR adds new capabilities to Modulus:
- ShardTensor is an extension to pytorch DTensor that enables uneven sharding of tensors across DeviceMesh objects. While some logical sharding constraints remain, this allows more dynamic and flexible operation on distributed input data, especially in cases where the input data shape and output data shape differ.
- ShardTensor also enables an ecosystem of operation extensions. Two major ones are included in this PR: convolutions (1D/2D/3D) and neighborhood attention. When the right components of modulus are imported, these operations (when performed on sharded tensors) will automatically compute halo regions and perform data transfers to produce results consistent with single-device outputs.
- Tutorials for ShardTensor, as well as an example of integrating multiple levels of parallelism by combining ShardTensor and pytorch FSDP.
.Checklist
Dependencies
Adds a dependency on wrapt for monkey-patching operations on sharded inputs.