Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SparseML Compression Pt 1: saving w/compression configs #2177

Merged
merged 22 commits into from
Mar 20, 2024
Merged

Conversation

Satrat
Copy link

@Satrat Satrat commented Mar 12, 2024

This initial PR for safetensors compression sets up the CompressionConfig and ModelCompressor registries and implements bitmask compression on save. See the corresponding internal docs PR for design details

To be implemented in follow-up PRs

  • bitmask decompression
  • saving/loading/inferring sparsity config from model
  • SparseAutoModel load/save interface

Example

from sparseml.transformers import SparseAutoModelForCausalLM
from sparseml.transformers.compression import BitmaskConfig, BitmaskCompressor
from safetensors import safe_open
import os

MODEL_PATH = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned50.oneshot"
OUTPUT_PATH = "./test_compress_output"

model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH)

sparsity_config = BitmaskConfig()
compressor = BitmaskCompressor(config=sparsity_config)

model_state_dict = model.state_dict()
sparse_state_dict = compressor.compress(model_state_dict)


model.save_pretrained(OUTPUT_PATH, safe_serialization=True, state_dict=sparse_state_dict)

safetensors_path = os.path.join(OUTPUT_PATH, "model-00001-of-00002.safetensors")
with safe_open(safetensors_path, framework="pt", device=0) as f:
    test_name = "model.layers.4.self_attn.k_proj.weight"
    bitmask = f.get_tensor(test_name + ".bitmask")
    shape = f.get_tensor(test_name + ".shape")
    values = f.get_tensor(test_name + ".compressed")
    row_offsets = f.get_tensor(test_name + ".row_offsets")
    print(f"bitmask: {bitmask}")
    print(f"shape: {shape}")
    print(f"values: {values}")
    print(f"row offsets: {row_offsets}")

dbogunowicz
dbogunowicz previously approved these changes Mar 13, 2024
Copy link
Contributor

@dbogunowicz dbogunowicz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couple of nitpicks. All in all great job, definitely well done on a bite-sized scope of this PR

setup.py Show resolved Hide resolved
bfineran
bfineran previously approved these changes Mar 18, 2024
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice job Sara, this looks good to me for saving. Just some comments on state dict conflicts and tensor devices

@Satrat Satrat requested review from bfineran and mgoin March 19, 2024 14:50
dbogunowicz
dbogunowicz previously approved these changes Mar 20, 2024
Copy link
Contributor

@dbogunowicz dbogunowicz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎅

Sara Adkins and others added 2 commits March 20, 2024 11:31
@Satrat Satrat requested a review from dbogunowicz March 20, 2024 15:32
@Satrat Satrat merged commit dead8b5 into main Mar 20, 2024
13 of 14 checks passed
@Satrat Satrat deleted the tensor_compression branch March 20, 2024 17:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants