
SparseML Compression Pt 2: Load compressed weights #2184

Merged 41 commits from tensor_decompression into main on Mar 20, 2024

Conversation


@Satrat Satrat commented Mar 15, 2024

This PR implements ModelCompressor.decompress(), which decompresses the weights in a safetensors file one by one. It also includes several helper functions for reading safetensors files and working with the compressed format. See the corresponding internal docs PR for design details.
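The one-by-one decompression can be pictured as a generator. The sketch below is hypothetical (invented names, plain Python lists standing in for torch tensors, a toy `values`/`bitmask` entry format), not the actual SparseML implementation; the point it illustrates is that yielding one dense weight at a time keeps peak memory near the size of a single layer rather than the whole model:

```python
# Hypothetical sketch of streaming decompression (names invented; plain Python
# stands in for torch tensors and safetensors I/O).

def decompress(checkpoint):
    """Yield (name, dense_weight) pairs, expanding one entry at a time."""
    for name, entry in checkpoint.items():
        values = iter(entry["values"])          # packed nonzero values
        dense = [next(values) if bit else 0.0   # one value per set mask bit
                 for bit in entry["bitmask"]]
        yield name, dense

# toy checkpoint with a single compressed weight
ckpt = {"layer0.weight": {"values": [1.5, 2.0], "bitmask": [0, 1, 0, 1]}}
for name, weight in decompress(ckpt):
    print(name, weight)   # layer0.weight [0.0, 1.5, 0.0, 2.0]
```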

Note: #2177 needs to be merged first

To be implemented in a follow-up PR:

  • inferring sparsity config from model
  • SparseAutoModel save/load interface

Example

Sample code for compressing a model with 50% sparsity (see PR #2177), then reloading the compressed weights as a dense model:

from sparseml.transformers import SparseAutoModelForCausalLM
from sparseml.transformers.compression import BitmaskConfig, BitmaskCompressor
from sparseml.utils.pytorch.utils import measure_cuda_memory
from tqdm import tqdm
import torch

MODEL_PATH = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned50.oneshot"
OUTPUT_PATH = "./test_compress_output"

torch.cuda.set_device(0)
with measure_cuda_memory() as m:
    model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0")
print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

sparsity_config = BitmaskConfig()
compressor = BitmaskCompressor(config=sparsity_config)

# compresses the model using Bitmask compression
with measure_cuda_memory() as m:
    model_state_dict = model.state_dict()
    sparse_state_dict = compressor.compress(model_state_dict)

    # save the compressed model
    model.save_pretrained(
        OUTPUT_PATH, 
        safe_serialization=True, 
        state_dict=sparse_state_dict
    )

print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

# reload the compressed model, then decompress its weights in place
torch.cuda.set_device(1)
with measure_cuda_memory() as m:
    model_again = SparseAutoModelForCausalLM.from_pretrained(
        OUTPUT_PATH, 
        device_map="cuda:1"
    )

    # decompress() returns an iterator of (name, tensor) pairs
    dense_state_dict = compressor.decompress(OUTPUT_PATH)
    for name, data in tqdm(dense_state_dict, desc="Decompressing model"):
        BitmaskCompressor.replace_layer(name, data, model_again)

print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

Load dense model peak GPU 25.2276 GB
Compressing model: 100%|████████████████████████████████████████████████████████████████████████████████████████| 291/291 [01:28<00:00, 3.29it/s]
Save compressed model peak GPU 25.2276 GB
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 6.27it/s]
Decompressing model: 291it [01:11, 4.08it/s]
Load compressed model peak GPU 25.7159 GB
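The PR excerpt doesn't spell out the bitmask storage format itself. As a rough illustration (hypothetical, plain Python in place of torch tensors, not SparseML's actual layout), bitmask compression stores only the nonzero values plus a one-bit-per-element mask, so at 50% sparsity roughly half the values are stored at the cost of one extra bit per element:

```python
# Hypothetical sketch of the bitmask idea (plain Python, not SparseML's
# torch implementation).

def bitmask_compress(values):
    """Split a flat list of floats into packed nonzero values and a bitmask."""
    bitmask = [1 if v != 0.0 else 0 for v in values]
    packed = [v for v in values if v != 0.0]
    return packed, bitmask

def bitmask_decompress(packed, bitmask):
    """Rebuild the dense list, consuming one packed value per set mask bit."""
    packed_iter = iter(packed)
    return [next(packed_iter) if bit else 0.0 for bit in bitmask]

dense = [0.0, 1.5, 0.0, -2.0, 0.0, 0.0]
packed, mask = bitmask_compress(dense)
assert bitmask_decompress(packed, mask) == dense
assert len(packed) == 2   # only the 2 nonzero values of 6 are stored
```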

@Satrat Satrat changed the title [Draft] SparseML Compression Pt 1: Load compressed weights [Draft] SparseML Compression Pt 2: Load compressed weights Mar 15, 2024
@Satrat Satrat changed the title [Draft] SparseML Compression Pt 2: Load compressed weights SparseML Compression Pt 2: Load compressed weights Mar 15, 2024
@Satrat Satrat marked this pull request as ready for review March 15, 2024 15:59
Base automatically changed from tensor_compression to main March 20, 2024 17:13
@Satrat Satrat dismissed stale reviews from bfineran and dbogunowicz March 20, 2024 17:13

The base branch was changed.

@Satrat Satrat requested review from dbogunowicz and bfineran March 20, 2024 17:13
bfineran previously approved these changes Mar 20, 2024
dbogunowicz previously approved these changes Mar 20, 2024
@mgoin mgoin (Member) left a comment

LGTM thanks, just one line I think was missed

@Satrat Satrat dismissed stale reviews from dbogunowicz and bfineran via 0335095 March 20, 2024 20:35
@mgoin mgoin (Member) left a comment

thanks!

@mgoin mgoin merged commit 121d7fe into main Mar 20, 2024
13 of 14 checks passed
@mgoin mgoin deleted the tensor_decompression branch March 20, 2024 21:09
4 participants