-
Notifications
You must be signed in to change notification settings - Fork 149
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SparseML Compression Pt 2: Load compressed weights #2184
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
…into tensor_compression
Satrat
changed the title
[Draft] SparseML Compression Pt 1: Load compressed weights
[Draft] SparseML Compression Pt 2: Load compressed weights
Mar 15, 2024
Satrat
changed the title
[Draft] SparseML Compression Pt 2: Load compressed weights
SparseML Compression Pt 2: Load compressed weights
Mar 15, 2024
Satrat
requested review from
mgoin,
bfineran,
dsikka,
horheynm,
dbogunowicz and
rahul-tuli
March 15, 2024 15:59
Satrat
dismissed stale reviews from bfineran and dbogunowicz
March 20, 2024 17:13
The base branch was changed.
bfineran
previously approved these changes
Mar 20, 2024
dbogunowicz
previously approved these changes
Mar 20, 2024
mgoin
reviewed
Mar 20, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM thanks, just one line I think was missed
src/sparseml/transformers/compression/compressors/sparse_bitmask.py
Outdated
Show resolved
Hide resolved
mgoin
approved these changes
Mar 20, 2024
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR implements ModelCompressor.decompress(), which will decompress the weights in the safetensors file one by one. Also includes a bunch of helper functions for reading safetensors files and dealing with the compressed format. See the corresponding internal docs PR for design details
Note: #2177 needs to be merged first
To be implemented in follow-up PR
Example
Sample code for compressing a model with 50% sparsity(See PR #2177), then reloading the compressed weights as a dense model
Load dense model peak GPU 25.2276 GB
Compressing model: 100%|████████████████████████████████████████████████████████████████████████████████████████| 291/291 [01:28<00:00, 3.29it/s]
Save compressed model peak GPU 25.2276 GB
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.27it/s]
Decompressing model: 291it [01:11, 4.08it/s]
Load compressed model peak GPU 25.7159 GB