[Low-bit optim] Support for dcp.save() and dcp.load() #1217
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1217
Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)
As of commit 782a6b1 with merge base 0e854ec:

BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.
nice! for now stamping to unblock (assuming we fix CI). Lmk if you'd like a proper review - happy to take more time, would just need to catch up on low bit optimizers first.
@vkuzo Sure, we can wait until CI is fixed (how do I know when CI is fixed? There seems to be no tracking issue atm). Not urgent, unless @nighting0le01 needs this patch merged to main soon?
You can ignore the CPU nightly failure. That hasn't been root-caused yet but is likely a runner-specific issue.
New CI errors (seems like a runner issue too?) https://github.com/pytorch/ao/actions/runs/11734156452/job/32689712737?pr=1217 - Will try rerunning this later to see if it still persists. Also a lot of build-wheels CI jobs are failing, e.g.
The glibc error was fixed.
@msaroufim yes, aarch64 build failures should have been fixed by that commit |
* support dcp.save
* add test for dcp.load()
* fix test
* typo
* implement aten.slice
* skip test
* fix checks
* run ruff
* fix formatting
* remove add safe globals in test
* sort some imports

---------

Co-authored-by: Mark Saroufim <[email protected]>
```python
# input validation
if dim != 0:
    raise ValueError("Only support aten.slice along the first dim")
```
@gau-nernst this raises ValueError when switching from TP=1, DP=8 to DP=1, TP=8. Why is it required?
Because of block-wise quantization, slicing in any dim > 0 is messy. It's doable, but just messy, and it does not always work (i.e. if you slice in the middle of a quantization block, it's not possible). I haven't used TP before, so I don't know how it does the sharding. Can you try printing out x.shape, dim, start, and end? And possibly open a new issue so we can discuss there.
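To make the block-alignment constraint concrete, here is a minimal sketch. The shapes and block size are hypothetical for illustration, not the actual torchao internals:

```python
import torch

# Suppose an optimizer state of shape (8, 256) is quantized block-wise:
# flattened and split into blocks of 256 values, each block sharing one scale.
block_size = 256
x = torch.randn(8, 256)
scales = x.view(-1, block_size).abs().amax(dim=1)  # shape (8,): one scale per block

# Slicing along dim 0 keeps whole rows, and each row here is exactly one
# block, so the sliced data still lines up with scales[2:5].
row_slice = x[2:5]

# Slicing along dim 1 (what TP sharding may do) cuts through the middle of
# every block: x[:, 64:192] contains fragments of blocks whose scales were
# computed over all 256 values, so the result is no longer a valid
# block-wise quantized tensor without requantizing. This is the case the
# ValueError above rejects.
col_slice = x[:, 64:192]
```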
Fixes #1189

- For dcp.save(), aten.detach and aten.is_pinned are required.
- For dcp.load(), aten.slice is required.

Thus this PR adds implementations for the above 3 ops for all low-bit optim state subclasses in torchao, as well as appropriate tests. Also did some minor housekeeping (e.g. format code, remove the torch>=2.3 guard since we only test against torch>=2.3 now...).
Note: Low-bit optims are still not compatible with dcp.state_dict.get_optimizer_state_dict() due to pytorch/pytorch#139575.
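For reference, a minimal end-to-end sketch of the flow this PR enables. This is illustrative only: single rank (dcp can run without an initialized process group in recent PyTorch), the checkpoint path is hypothetical, and the import path is as of this PR and may change in later torchao releases:

```python
import torch
import torch.distributed.checkpoint as dcp
from torchao.prototype.low_bit_optim import AdamW8bit

model = torch.nn.Linear(256, 256, device="cuda")
optim = AdamW8bit(model.parameters())

# Take one step so the quantized optimizer state subclasses exist.
model(torch.randn(4, 256, device="cuda")).sum().backward()
optim.step()

# dcp.save() exercises aten.detach / aten.is_pinned on the subclass states.
state = {"model": model.state_dict(), "optim": optim.state_dict()}
dcp.save(state, checkpoint_id="/tmp/ckpt")

# dcp.load() may exercise aten.slice when resharding the saved states.
dcp.load(state, checkpoint_id="/tmp/ckpt")
```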