Support FSDP in PyTorch #796

Open
priyakasimbeg opened this issue Oct 17, 2024 · 6 comments
Labels
👷 In Progress Issue is being worked on

Comments

@priyakasimbeg
Contributor

priyakasimbeg commented Oct 17, 2024

It is useful to shard optimizer state across devices (to save significant memory). This reflects current practice. We want to support it.

  • We want to switch from no sharding to naive model parameter sharding in both frameworks.
  • We will forbid (in the rules) any hacks that change the model parallelization strategy, and each workload will have a default sharding.
  • Allow submitters to opt out of it on a per-workload basis.
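
To make the memory argument concrete, here is a back-of-the-envelope sketch. It assumes an Adam-style optimizer that keeps two state values per parameter in 4-byte floats; the function name and the example numbers are made up for illustration and do not come from any workload.

```python
def optimizer_state_bytes_per_device(num_params, num_devices, sharded,
                                     states_per_param=2, bytes_per_value=4):
    """Rough optimizer-state footprint on one device.

    Assumption: Adam-like optimizer with `states_per_param` values per
    parameter, each stored in `bytes_per_value` bytes.
    """
    if sharded:
        # Each device holds roughly 1/num_devices of the parameters
        # (ceiling division covers uneven splits).
        params_held = -(-num_params // num_devices)
    else:
        # Without sharding, every device keeps the full optimizer state.
        params_held = num_params
    return params_held * states_per_param * bytes_per_value


# Hypothetical example: 1B parameters on 8 devices.
replicated = optimizer_state_bytes_per_device(1_000_000_000, 8, sharded=False)
sharded = optimizer_state_bytes_per_device(1_000_000_000, 8, sharded=True)
print(f"replicated: {replicated / 1e9:.1f} GB, sharded: {sharded / 1e9:.1f} GB")
```

Under these assumptions the per-device optimizer state shrinks by about the number of devices.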
@priyakasimbeg
Contributor Author

priyakasimbeg commented Oct 17, 2024

From meeting minutes from Michael Shi: Challenge is ensuring that JAX and PyTorch are equivalent. PyTorch should be doable by changing the DDP wrapper to the FSDP wrapper.
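
A minimal sketch of what swapping the wrapper could look like, assuming a process group is already initialized and one CUDA device per process. The helper name `wrap_for_data_parallel` and its arguments are hypothetical and are not taken from the repository.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_data_parallel(model, local_rank, use_fsdp=True):
    """Hypothetical helper: wrap `model` in FSDP instead of DDP.

    Assumes torch.distributed.init_process_group(...) has already been
    called and that `local_rank` indexes a CUDA device on this host.
    """
    device = torch.device("cuda", local_rank)
    if use_fsdp:
        # FSDP shards parameters, gradients and optimizer state across ranks.
        return FSDP(model, device_id=device)
    # DDP keeps a full replica of the model on every rank.
    return DDP(model.to(device), device_ids=[local_rank])
```

The optimizer would need to be constructed after wrapping, since FSDP replaces the module's parameters with sharded ones.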

@IFFranciscoME

...For the sake of self-referencing notes and approaches.

For the PyTorch case, there are two ways of doing this:

For the JAX case, which is the one I am less familiar with:

  • Manual parallelism: A single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data.
  • Scalax package: An external package that lets you write model and training code for a single GPU/TPU and rely on scalax to automatically scale it up to hundreds of GPUs/TPUs.

@davidtweedle

davidtweedle commented Nov 5, 2024

Hi,
I have been working on this.

  • The code I have is running on cifar (on kaggle), and it seems to be fine.
  • The wrapping I used is by size, but we would need to make it equivalent to the JAX code.
  • I am getting some errors when I try to save the model checkpoints. This may have to do with the torch version, I am not sure.
    You can see the branch I am working from here:
    https://github.com/davidtweedle/algorithmic-efficiency/tree/fsdp_cifar

Edited to add: Also, I turned off torch.compile for this workload. I think that is also due to the pytorch version.
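
For readers following along, wrapping by size is usually done with PyTorch's `size_based_auto_wrap_policy`. The sketch below is an assumption about what that looks like, not a copy of the branch; the threshold `MIN_NUM_PARAMS` and the helper `wrap_by_size` are illustrative only.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Illustrative threshold (an assumption): submodules with at least this many
# not-yet-wrapped parameters get their own FSDP unit.
MIN_NUM_PARAMS = 100_000

auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=MIN_NUM_PARAMS)


def count_params(module: nn.Module) -> int:
    """Total number of parameter elements in `module`."""
    return sum(p.numel() for p in module.parameters())


def wrap_by_size(model: nn.Module) -> FSDP:
    """Hypothetical helper; requires an initialized process group."""
    return FSDP(model, auto_wrap_policy=auto_wrap_policy)
```

Because the policy only looks at parameter counts, the resulting FSDP units depend on how the model is split into submodules, which is one reason it may not line up with the JAX sharding.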

@priyakasimbeg
Contributor Author

priyakasimbeg commented Nov 7, 2024

Thanks for the update!
The model checkpoints are expected to break because, if I remember correctly, the checkpointing code makes specific assumptions about the model. Some submitters also ran into checkpointing issues because that code makes assumptions about the optimizer state. We probably want to fix model checkpointing as part of this FSDP migration, but I would handle that at a later stage and just disable it for now if it is blocking.

Regarding torch.compile, that seems a little more problematic. When you have time, could you paste a traceback of the torch.compile issue, either in a gist (https://gist.github.com/) or in this GH issue thread? If the fix requires updating PyTorch, we should probably bump the priority on that.

@davidtweedle

Hi,
OK for now I will disable the model checkpoints.
Here is a gist of the logs for this run.
https://gist.github.com/davidtweedle/a870a7dd0d409e920604565a2e08b638

I am not sure what to make of this error, yet.

Also, there is this related blog post: https://dev-discuss.pytorch.org/t/torchdynamo-update-11-making-fsdp-and-dynamo-work-together/1037

@davidtweedle

Hi,
I hope it is appropriate to give a quick update on what could be going on here.
When the batch norm is updated during the training step, "module.apply" is called to update the batch norm.
This is called from the FSDP wrapper of the module which asserts that the training state must be "IDLE".
But calling apply from the FSDP wrapper means that the FSDP wrapper wants to all gather the different parameters, which is not necessary because all we want to do is tell the batch norm layers to keep track of the running stats.
So hopefully it is possible to apply "update_batch_norm_fn" without calling module.apply from the FSDP wrapper.
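
One possible way to do that, as a sketch only: walk the submodules with `nn.Module.modules()`, which is a plain traversal, instead of going through the wrapper's `apply`. The function `set_bn_tracking` below is a hypothetical stand-in for "update_batch_norm_fn", and this has not been tested under FSDP; it assumes that only module attributes, not sharded parameters, are touched.

```python
import functools

import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


def set_bn_tracking(module, update_batch_norm):
    """Hypothetical stand-in for the per-module batch norm update."""
    if isinstance(module, _BatchNorm):
        module.track_running_stats = update_batch_norm


def apply_without_gather(model, fn):
    """Call `fn` on every submodule without using `model.apply`.

    `modules()` only iterates over the module tree, so no parameter
    all-gather should be triggered (assumption, untested with FSDP).
    """
    for submodule in model.modules():
        fn(submodule)


# Small CPU-only demonstration on an unwrapped model.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
apply_without_gather(
    model, functools.partial(set_bn_tracking, update_batch_norm=False))
```

With a wrapped model, the same loop would visit the FSDP wrappers as well as the inner layers, so the `isinstance` check is what keeps it limited to batch norm modules.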

@priyakasimbeg priyakasimbeg added the 👷 In Progress Issue is being worked on label Jan 7, 2025

3 participants