
[Low-bit optim] Support for dcp.save() and dcp.load() #1217

Merged
merged 15 commits into from
Nov 9, 2024

Conversation

gau-nernst
Collaborator

@gau-nernst gau-nernst commented Nov 3, 2024

Fixes #1189

  • To support dcp.save(), aten.detach and aten.is_pinned are required
  • To support dcp.load():
    • When the world size does not change, no additional ops are needed
    • When the world size changes, aten.slice is required

This PR therefore adds implementations of the above three ops for all low-bit optim state subclasses in torchao, along with appropriate tests. It also does some minor housekeeping (e.g. formatting the code and removing the torch>=2.3 guard, since we only test against torch>=2.3 now).

Note: Low-bit optims are still not compatible with dcp.state_dict.get_optimizer_state_dict() due to pytorch/pytorch#139575
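To make the mechanism above concrete, here is a hypothetical, simplified sketch (plain Python, no torch dependency) of the dispatch-table pattern: each supported aten op is registered in a table that the subclass would route calls through. The names here (`LowBitState`, `ATEN_OPS`, `implements`) are illustrative only, not the actual torchao API.

```python
# Table of op handlers; a real tensor subclass would consult this
# from __torch_dispatch__. All names here are illustrative.
ATEN_OPS = {}

def implements(name):
    """Register a handler for an aten op by name."""
    def decorator(fn):
        ATEN_OPS[name] = fn
        return fn
    return decorator

class LowBitState:
    """Toy stand-in for a block-wise quantized optimizer state tensor."""
    def __init__(self, codes, scales, block_size):
        self.codes = codes          # quantized values, flattened
        self.scales = scales        # one scale per quantization block
        self.block_size = block_size

@implements("aten.detach")
def _(state):
    # detach is a no-op for this toy container: return a shallow copy
    return LowBitState(list(state.codes), list(state.scales), state.block_size)

@implements("aten.is_pinned")
def _(state):
    # plain host memory is never pinned in this sketch
    return False

@implements("aten.slice")
def _(state, dim, start, end):
    # Slicing is only supported along dim 0 and at block boundaries,
    # because each scale covers a whole quantization block.
    if dim != 0:
        raise ValueError("Only support aten.slice along the first dim")
    if start % state.block_size or end % state.block_size:
        raise ValueError("Slice must align with quantization block boundaries")
    return LowBitState(
        state.codes[start:end],
        state.scales[start // state.block_size : end // state.block_size],
        state.block_size,
    )

state = LowBitState(codes=list(range(8)), scales=[1.0, 2.0], block_size=4)
sliced = ATEN_OPS["aten.slice"](state, 0, 0, 4)
print(sliced.codes, sliced.scales)  # [0, 1, 2, 3] [1.0]
```

The slice handler is what dcp.load() would exercise when the world size changes and each rank needs a different shard of the saved state.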


pytorch-bot bot commented Nov 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1217

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 782a6b1 with merge base 0e854ec (image):

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 3, 2024
@gau-nernst gau-nernst marked this pull request as ready for review November 4, 2024 05:35
@gau-nernst gau-nernst requested review from msaroufim and awgu November 4, 2024 05:50
@msaroufim msaroufim requested a review from vkuzo November 5, 2024 03:58
Contributor

@vkuzo vkuzo left a comment


nice! for now stamping to unblock (assuming we fix CI). Lmk if you'd like a proper review - happy to take more time, would just need to catch up on low bit optimizers first.

@gau-nernst
Collaborator Author

@vkuzo Sure, we can wait until CI is fixed. (How do I know when CI is fixed? There seems to be no tracking issue at the moment.) Not urgent, unless @nighting0le01 needs this patch merged to main soon?

@msaroufim
Member

You can ignore the CPU nightly failure. It hasn't been root-caused yet but is likely a runner-specific issue.

@gau-nernst
Collaborator Author

New CI errors (seems like a runner issue too?): https://github.com/pytorch/ao/actions/runs/11734156452/job/32689712737?pr=1217 - Will try rerunning later to see if it persists.

Also, a lot of the build-wheels CI jobs are failing, e.g.

@msaroufim
Member

msaroufim commented Nov 8, 2024

The glibc error was fixed. The conda arm error is new but should be fixed; @malfet, is this related? pytorch/test-infra@709824e

@malfet

malfet commented Nov 8, 2024

> glibc error was fixed the conda arm error is new but should be fixed, @malfet is this related? pytorch/test-infra@709824e

@msaroufim yes, aarch64 build failures should have been fixed by that commit

@gau-nernst gau-nernst added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Nov 9, 2024
@gau-nernst gau-nernst merged commit 75f52ae into pytorch:main Nov 9, 2024
17 of 18 checks passed
@gau-nernst gau-nernst deleted the optim_fsdp_save_load branch November 9, 2024 07:11
jainapurva pushed a commit that referenced this pull request Nov 11, 2024
* support dcp.save

* add test for dcp.load()

* fix test

* typo

* implement aten.slice

* skip test

* fix checks

* run ruff

* fix formatting

* remove add safe globals in test

* sort some imports

---------

Co-authored-by: Mark Saroufim <[email protected]>
jainapurva pushed a commit that referenced this pull request Nov 12, 2024
sunjiweiswift pushed a commit to sunjiweiswift/ao that referenced this pull request Nov 25, 2024
nighting0le01 pushed a commit to nighting0le01/ao that referenced this pull request Dec 5, 2024

    # input validation
    if dim != 0:
        raise ValueError("Only support aten.slice along the first dim")

@nighting0le01 nighting0le01 Dec 5, 2024


@gau-nernst this raises a ValueError when switching from TP=1, DP=8 to DP=1, TP=8. Why is it required?

Collaborator Author


Because of block-wise quantization, slicing in any dim > 0 is messy. It's doable, but messy, and it does not always work (i.e. if you slice in the middle of a quantization block, it is not possible).
I haven't used TP before, so I don't know how it does the sharding. Can you try printing out x.shape, dim, start, end? And possibly open a new issue so we can discuss there.
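The alignment constraint described above can be sketched in a few lines. This is a hypothetical illustration (not torchao code): each scale covers a fixed block of elements, so a slice that starts or ends mid-block has no whole (code, scale) blocks to map onto.

```python
# Illustrative block size; real low-bit optimizers use their own value.
BLOCK_SIZE = 4

def check_slice(start, end, block_size=BLOCK_SIZE):
    """Return the block range covered by [start, end), or raise if the
    slice cuts through a quantization block."""
    if start % block_size or end % block_size:
        raise ValueError(
            f"slice [{start}:{end}) cuts through a {block_size}-element "
            "quantization block; cannot map it to whole (code, scale) blocks"
        )
    return start // block_size, end // block_size

# aligned: elements [4, 12) map exactly to blocks 1 and 2
blocks = check_slice(4, 12)
print(blocks)  # (1, 3)
```

This is why resharding that only slices along dim 0 at block-aligned offsets works, while an arbitrary TP sharding may not.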

Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories)
Development

Successfully merging this pull request may close these issues.

Cannot run FSDP2 with low bit optim from AO
7 participants