
Enable Domain Parallelism with ShardTensor #784

Open · wants to merge 48 commits into base: main
Conversation

@coreyjadams (Collaborator) commented Feb 6, 2025

Modulus Pull Request

Description

This PR adds new capabilities to Modulus:

  • ShardTensor is an extension to PyTorch's DTensor that enables uneven sharding of tensors across DeviceMesh objects. While some logical sharding constraints remain, this allows more dynamic and flexible operation on distributed input data, especially in cases where the input and output data shapes differ. (A short usage sketch follows this list.)
  • ShardTensor also enables an ecosystem of operation extensions. Two major ones are included in this PR: convolutions (1D/2D/3D) and neighborhood attention. When the right components of modulus are imported, these operations, when performed on sharded tensors, will automatically compute halo regions and perform the data transfers needed to produce results consistent with single-device outputs.
    • For small inputs this is not useful, but for extremely large inputs it is a powerful way to scale training.
  • The documentation for Modulus now includes an API reference for ShardTensor, as well as an example of integrating multiple levels of parallelism by combining ShardTensor and PyTorch FSDP.
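For orientation, here is a rough sketch of the intended workflow. The import paths, the exact `scatter_tensor` signature, and the `DistributedManager` attributes used below are assumptions based on this description and the commit notes, not a confirmed API:

```python
# Sketch only: imports, scatter_tensor signature, and DistributedManager
# attributes are assumptions, not a confirmed API.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard

from modulus.distributed import DistributedManager, scatter_tensor  # assumed imports

DistributedManager.initialize()
dm = DistributedManager()

# One mesh dimension spanning all ranks, used for domain (spatial) sharding.
mesh = init_device_mesh("cuda", (dm.world_size,), mesh_dim_names=("spatial",))

# Rank 0 owns the full input (e.g. straight from an existing data pipeline);
# scatter it across the mesh, sharded along the height dimension (dim 2).
# ShardTensor allows the per-rank shard sizes to be uneven.
full_input = torch.randn(1, 3, 1024, 1024, device=dm.device) if dm.rank == 0 else None
sharded_input = scatter_tensor(full_input, 0, mesh, placements=(Shard(2),))
```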

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Adds a dependency on wrapt for monkey-patching operations on sharded inputs.

coreyjadams and others added 30 commits December 17, 2024 08:52

  • …ieces are WIP but this has basic functionality supported for creation and forward usage.
  • …t of the ops have been validated, all that remains is to wrap the na2d function call to ensure it will dispatch properly.
  • …s also a minor bug in the backward pass that got more pronounced with smaller data: grad inputs were failing to properly collect haloed gradients and add them on the edges. Now fixed.
  • …gnificant overhead. I'm implementing here an option to switch to peer-to-peer message passing, since it might benefit from stream utilization in layers like natten.na2d. It's a developer choice currently, not a user choice.
  • …gnificant functionality changes in this commit.
  • Add `scatter_tensor` function to enable an easier transition to shard tensor. This function allows users to maintain data pipelines (on one rank) and easily scatter that data to a domain mesh. This also adjusts the shard tensor mechanism for tracking shard info to use a dict instead of a list of tuples.
  • No real code changes applied here.
@pzharrington (Collaborator)

/blossom-ci

@pzharrington (Collaborator)

Overall, I think this is looking good, nice work! I started with the documentation and then focused on unit tests to see overall functionality, as well as changes to the DistributedManager to see device mesh functionality and the main changes to what existed previously. Also looked at the ShardTensor definition, halo collectives and conv/natten patches, but didn’t spend much time on the other backend stuff.

Aside from the minor comments added, my main flag is to make the unit testing more complete, but I don't think that should necessarily block merging. In particular, for the ops we support (conv or natten, currently), we should add unit tests that check correctness against a non-sharded baseline for forward and backward passes, within some numerical tolerance (especially in the context of the neighborhood attention numerics we discovered).
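For illustration, a hypothetical shape such a test could take (forward pass only, intended to run under torchrun across a few GPUs). The modulus import paths, the `scatter_tensor` signature, the patch-registering module, and the `initialize_mesh` helper are all assumptions rather than this PR's confirmed API:

```python
# Hypothetical correctness test: sharded conv2d vs. single-device baseline.
import torch
import torch.distributed as dist
from torch.distributed.tensor import Shard

from modulus.distributed import DistributedManager, scatter_tensor  # assumed imports
import modulus.distributed.shard_utils  # noqa: F401  (assumed: importing registers the conv patches)


def test_sharded_conv2d_matches_single_device():
    DistributedManager.initialize()
    dm = DistributedManager()
    mesh = dm.initialize_mesh(mesh_shape=(-1,), mesh_dim_names=("spatial",))  # hypothetical helper

    torch.manual_seed(0)                     # identical weights on every rank
    conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to(dm.device)

    full = torch.randn(1, 3, 256, 256, device=dm.device)
    dist.broadcast(full, src=0)              # same input everywhere for the baseline

    baseline = conv(full)                    # non-sharded reference

    sharded_in = scatter_tensor(full, 0, mesh, placements=(Shard(2),))
    sharded_out = conv(sharded_in).full_tensor()   # gather the sharded result

    assert torch.allclose(baseline, sharded_out, atol=1e-5, rtol=1e-4)
```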

Update tutorial based on feedback from @pzharrington
Remove wildcard import.
@coreyjadams (Collaborator, Author)

Thanks for the review @pzharrington! I agree with you on the testing. Here are my thoughts:

  • Modulus should support unit tests of the basic ShardTensor functionality, "baked in". Most of those are there, but I also have some tests in development locally regarding gradient propagation through sharded tensors. I would like to get them in, but didn't want to hold up the review.
  • Tests on numerical accuracy are probably too much for CI/CD and unit testing. I am working on exactly the tools you highlighted, but it's currently manually run and analyzed. It aims to support numerical checking (and performance benchmarking!) of all the operations we patch like this in modulus, as well as extending to more comprehensive layers and even full models. I'll get the repo up on gitlab to iron out the kinks and keep the numerical checking untied from the modulus release. FYI, apart from the issues we found in na2d for long sequence lengths, all patched operations are passing numerical checks.

@ktangsali (Collaborator)

/multi-gpu-ci

@coreyjadams coreyjadams added the ! - Release (PRs or Issues relating to a release) label Feb 11, 2025
@NickGeneva NickGeneva added the 4 - In Review (Currently Under Review) label Feb 11, 2025

In scientific AI applications, the parallelization techniques to enable state of the art
models are different from those used in training large language models. Modulus
introduces a new parallelization primitive called a ShardTensor that is designed for
Collaborator:
Codify ShardTensor to make it stand out better?

dist = DistributedManager()

# Create a 1D device mesh
mesh = DeviceMesh(device_type="cuda", mesh=[0, 1, 2, 3])
Collaborator:
This is going straight to pytorch right? Would be better to route that through DistributedManager I think so that mesh objects don't need to be explicitly passed around and modules can query the DistributedManager for them like how they currently query it for the torch device, rank, process group handles, etc.

Collaborator Author:
Sure, I can update this script. I've updated DistributedManager in this script to enable DeviceMesh creation too.
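A sketch of what routing mesh creation through DistributedManager could look like; the `initialize_mesh` name and signature here are hypothetical placeholders, not the API in this PR:

```python
# Hypothetical sketch: initialize_mesh and its signature are placeholders for
# whatever DistributedManager ends up exposing.
from modulus.distributed import DistributedManager

DistributedManager.initialize()
dm = DistributedManager()

# One mesh axis spanning all available ranks (-1), cached on the manager so
# downstream modules can query it rather than passing mesh objects around.
mesh = dm.initialize_mesh(mesh_shape=(-1,), mesh_dim_names=("spatial",))
```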


def create_sample_data(batch_size=32, height=32, width=64):
    # Create random data
    data = torch.randn(batch_size, 3, height, width, device=f"cuda:{dm.local_rank}")
Collaborator:
Change device to dm.device here


def setup_model():
    # Create base model
    model = SimpleCNN().to(f"cuda:{dm.local_rank}")
Collaborator:
Also best to use dm.device here (thinking of future device to rank mapping optimizations that we might add)


    return generic_conv_nd_wrapper(wrapped, instance, args, kwargs)

@wrapt.patch_function_wrapper('torch.nn.functional', 'conv2d')
Collaborator:
Not a big fan of this monkey patching for a few reasons:

  1. It took me a long time to figure out how the example in the tutorial worked and to get to this line here.
  2. This doesn't work if this module is not imported.
  3. It breaks the modulus.Module.save and modulus.Module.from_checkpoint functionality (in a general case).

My preference would be to explicitly expose this as a DistributedConv2D module and maybe use the partition_fn argument in distribute_module to swap in DistributedConv2D instead of nn.Conv2D.

Collaborator:
Here's an example for ColwiseParallel in torch

Collaborator Author:
This is the feedback I have been expecting! :)

I hear your point. I'll outline why I did it this way and not the other way; maybe we'll find a path that meets all needs. Sorry it took a while to figure out; if we keep this approach, I can write up how exactly it works.

I wanted to meet these user-facing requirements in the implementation:

  • Easy to use interface. Ideally, the "model code" should remain unchanged or as close to unchanged as possible to enable domain parallelism.
    • We could, maybe, have a big registry of "nn.Conv2d -> modulus.distributed.nn.Conv2d" that gets used when distribute_module is applied as a partition_fn. Then, users would have to apply our partition function + any other partition functions needed.
    • The spirit of the partition function, though, appears to be to enable model parallelism (not domain parallelism as I am doing here). distribute_module is operating on the model parameters and not how those parameters are interacting with the input data.
  • Don't break un-sharded code
    • I wanted the operation to get called correctly regardless of whether the input is sharded or not. Using a module replacement technique, it likely comes down to the same steps, but every module we built to do this would essentially be a custom layer.
    • I know you could do this with a module mapping/replacement too.
  • Support torch.nn.functional interface
    • What happens if you have a custom layer that stores a convolution weight and bias and then calls torch.nn.functional.conv2d directly? This would need to be specifically enabled.

In the end, I went with this approach since it was the most flexible for the user. The monkey patching picks the right operation based on the input types, and in fact I built the sharded ops on even lower-level operations (aten.convolution.default, for example), so it's a one-op-fits-all way to support convolutions.

The best of both worlds, after writing this up, is likely to use this functional-level mapping and a registry for dispatch within ShardTensor itself. I think that's possible; it definitely won't be ready for this release.
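For readers following along, a simplified sketch of that interception pattern. The wrapt mechanics (wrapped, instance, args, kwargs) follow the library's documented decorator API; the ShardTensor import path and the sharded_conv2d helper are placeholders for what this PR actually implements:

```python
import wrapt

from modulus.distributed import ShardTensor  # assumed import path


def _has_sharded_input(args, kwargs):
    # The dispatch decision is based purely on input types.
    return any(isinstance(a, ShardTensor) for a in list(args) + list(kwargs.values()))


@wrapt.patch_function_wrapper('torch.nn.functional', 'conv2d')
def conv2d_wrapper(wrapped, instance, args, kwargs):
    if _has_sharded_input(args, kwargs):
        # Placeholder for the halo-exchange + local-convolution path.
        return sharded_conv2d(*args, **kwargs)  # hypothetical helper
    # Un-sharded inputs fall straight through to the original function,
    # so existing single-device code is unaffected.
    return wrapped(*args, **kwargs)
```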

Collaborator Author:
Also, can you elaborate on how this breaks the save / from_checkpoint interface? I think it might actually be distribute_module that breaks these; in general, this dispatch/dynamic interception of layers shouldn't break the way models are stored and loaded.

The one exception is if users need to shard their model's parameters for domain sharding. This can occur if you're adding position embeddings or have other trainable parameters whose shapes depend on the data shape. Then, we'd need to update save / from_checkpoint to ensure:

  • Models coalesce full parameters before saving.
  • Sharding / placements are (optionally) restored on loading.

# This will leave x as a Partial placement, meaning
# it isn't really sharded anymore but the results on the domain
# pieces haven't been computed yet.
x = self.pool(x)
Collaborator Author:

Under the hood in the pooling layer, it's doing an aten.mean operation. That gets intercepted and redispatched by DTensor, which is the fallback dispatch for ShardTensor's dispatch. The mean operation is handled in upstream pytorch here:
https://github.com/pytorch/pytorch/blob/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_ops/_math_ops.py#L306-L355

In general, most operations can fall back to DTensor successfully. There is actually a small hiccup in this operation (it doesn't show up in this example, of course): if the tensor is sharded asymmetrically, the mean does local averaging followed by a global average when the output is needed. The global average needs to be weighted by the shape of each un-reduced shard, but currently isn't (a small numeric illustration is below). It's an expected-fail in the tests on this branch for exactly this reason, but I don't think it should block progress, since it's a relatively rare path to go down. It's on my agenda to fix, but it likely requires sub-classing the Partial placement in DTensor to catch and handle this case.
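A tiny numeric illustration of that weighting issue (plain PyTorch, no distributed setup needed):

```python
import torch

full = torch.tensor([1.0, 2.0, 3.0, 10.0])
shard_a, shard_b = full[:3], full[3:]            # uneven shards: sizes 3 and 1

true_mean = full.mean()                                      # 4.0
naive = (shard_a.mean() + shard_b.mean()) / 2                # (2.0 + 10.0) / 2 = 6.0 (wrong)
weighted = (shard_a.numel() * shard_a.mean()
            + shard_b.numel() * shard_b.mean()) / full.numel()  # 4.0 (correct)
```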

Collaborator:
Got it, traced that down to this piece of code that tells you that it will be a Partial placement: https://github.com/pytorch/pytorch/blob/937b41e3b55e9dc682ac2027b5af17b1d7d75489/torch/distributed/tensor/_ops/_math_ops.py#L248-L251.

When does this Partial placement get resolved? In the fc I assume? Is that basically calling a redistribute to Replicated and that triggers an allreduce? If that fc was rowwise (or colwise depending on convention) sharded, then it would redistribute to Shard(0) and that would trigger a reduce-scatter instead of an allreduce?

Collaborator Author:
I'm not 100% sure when it resolves, actually. I want to look into this more: the details of when and why are important when parallelism is being composed.

# pieces haven't been computed yet.
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
Collaborator:
This is automatically changed to all local ops by distribute_model?

Collaborator Author:
No, it changes all the underlying parameters in the layers from torch.Tensor to torch.distributed.tensor.DTensor. DTensor comes with its own dispatch, and when you call the module returned by distribute_module with a DTensor or ShardTensor, it decides how to process based on (a) the operation being dispatched and (b) the input tensor(s).

When you do the pool operation in this example, the pooling actually takes place over the sharded dimension. So the output x on line 78 is a ShardTensor/DTensor with Replicate() placement instead of Shard(2) placement (the height dimension).

Subsequent operations will see that the weights in fc have Replicate placement and the inputs to fc have Replicate placement, so it just performs an all-local operation and wraps the output in a replicated DTensor. ShardTensor then intercepts that and keeps the domain data as a ShardTensor.

In the backward pass, DTensor (and hence ShardTensor) should automatically un-replicate the gradients and shard them appropriately, so the gradients flowing to the upstream layers here (conv, relu) will be shard tensors of the correct size.

Collaborator:
Yeah, that makes sense that fc weights are replicated since that's the default with distribute_model.

This actually brings up a good point. Ideally, we'd want to parallelize the fc in this example as well since it's a lot of redundant compute right? How would one do that? Through a custom partition_fn in distribute_model? That would work in this case. How would it work in a larger model with many different Linear layers which all might need different handling in partition_fn?

Collaborator Author:
This model is so small that the fc is probably irrelevant in terms of compute, but yes, I get your point. I think weight partitioning, in almost every case, should be handled with existing FSDP tools. That ought to allow most parallelizations to work out of the box, though it needs to be fully checked and confirmed.

This is on my to-do list for the next release: 3D parallelism over the batch, domain, and model dimensions (a rough sketch of the idea is below). It's not ready in this release, though.
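As a rough illustration of that direction (not the API shipped in this PR), batch parallelism and domain sharding could live on different axes of a single 2D device mesh. The sketch below uses only public torch primitives, assumes 4 ranks under torchrun, and borrows SimpleCNN from the tutorial above:

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# 2 x 2 mesh: one axis for data (batch) parallelism, one for domain sharding.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("data", "spatial"))

model = SimpleCNN().cuda()  # model from the tutorial above

# FSDP manages parameters along the "data" axis only; the "spatial" axis is
# left for ShardTensor domain decomposition of the inputs.
model = FSDP(model, device_mesh=mesh_2d["data"])
```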

inputs, targets = create_sample_data()

# Forward pass
outputs = model(inputs)
Collaborator:
Presumably the outputs are replicated across the spatial mesh right? It's not clear to me where/how that happens.

Collaborator Author:
Here, yes, they are replicated. In this example that happens specifically in the pool layer.

In general, the outputs need not be replicated; depending on your model and loss function you may have different needs. For this classification example, where the labels are small and replicated, it makes sense to replicate the outputs. If we didn't have the pool layer, the loss computation would call outputs = outputs.full_tensor(), which explicitly (and differentiably) gathers the full outputs.

On the other hand, for something like a segmentation loss, where the output size == input size == label size, it may make more sense to compute the loss as a distributed operation. In that case, a per-pixel softmax cross entropy between the output and the label would itself produce a sharded tensor. When you sum (or average) over the image dimensions, you'll be left with scalars on each rank in the Partial placement, meaning they are reduced on the local device but not yet across devices (see the sketch below).

Most of this is handled transparently by the DTensor dispatch + operation rules. Combined with the generic way to redistribute tensors (from DTensor) plus the reimplementation for asymmetrically sharded tensors here, most single-tensor operations and straightforward operations (like a + b for two DTensors) work out of the box and under the hood.
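A minimal runnable sketch of that segmentation-style pattern (launched with torchrun), using plain DTensor; the asymmetric ShardTensor case is assumed to behave analogously, and torch >= 2.4 is assumed for the public torch.distributed.tensor namespace:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Pretend these are model outputs and labels, sharded along the height dim.
outputs = distribute_tensor(torch.randn(1, 2, 64, 64, device="cuda"), mesh, [Shard(2)])
labels = distribute_tensor(torch.randn(1, 2, 64, 64, device="cuda"), mesh, [Shard(2)])

per_pixel = (outputs - labels) ** 2   # elementwise: result stays sharded on dim 2
loss = per_pixel.mean()               # reduction: leaves a Partial placement
print(loss.placements)                # not yet reduced across ranks
print(loss.full_tensor())             # resolving triggers the cross-rank reduction
```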

@ktangsali (Collaborator)

/blossom-ci

4 similar comments

@ktangsali (Collaborator)

/blossom-ci

@coreyjadams (Collaborator, Author)

/blossom-ci

1 similar comment

@ktangsali (Collaborator)

Reminder to change the target to 0.10.0-rc branch before merging.

Labels: 4 - In Review (Currently Under Review), ! - Release (PRs or Issues relating to a release)