SparseML Compression Pt 2: Load compressed weights #2184

Merged Mar 20, 2024 (41 commits)
Changes from all commits:
45a16ed  initial classes (Mar 12, 2024)
a7cee23  WIP (Mar 12, 2024)
e1549e8  compression working (Mar 12, 2024)
92ba386  unit tests and README (Mar 12, 2024)
f061d78  docstrings (Mar 12, 2024)
40a75a9  README and fix test (Mar 12, 2024)
522813c  add bitmask source (Mar 12, 2024)
c6d0b4d  Merge branch 'main' into tensor_compression (Mar 12, 2024)
118c223  initial commit (Mar 14, 2024)
be2223f  compression working (Mar 14, 2024)
c07d36a  formatting (Mar 14, 2024)
d2a8a78  cleanup (Mar 15, 2024)
1096700  dtype tests (Mar 15, 2024)
1749b28  Merge branch 'main' into tensor_compression (Mar 15, 2024)
013d17b  oops fix test (Mar 15, 2024)
813c8e7  Merge branch 'tensor_compression' of github.com:neuralmagic/sparseml … (Mar 15, 2024)
41223bb  tests (Mar 15, 2024)
2c6eeba  add bfloat16 (Mar 15, 2024)
2d515fa  Merge branch 'tensor_compression' into tensor_decompression (Mar 15, 2024)
35e1dba  unit tests (Mar 15, 2024)
dd9d82f  docstrings (Mar 15, 2024)
e473ef3  update README (Mar 15, 2024)
08e039e  move statements to debug (Mar 18, 2024)
e369710  warn on conflicts, store device (Mar 19, 2024)
e7fb048  Merge branch 'main' into tensor_compression (Mar 19, 2024)
52a916a  Merge branch 'tensor_compression' into tensor_decompression (Mar 19, 2024)
1ecf58b  merge conflict (Mar 19, 2024)
59ff306  fix typing (Mar 19, 2024)
bb348d2  helper fn for setting layers (Mar 19, 2024)
d81dd5c  expand cuda memory helper (Mar 19, 2024)
bb973c2  update example and tests (Mar 19, 2024)
a6b83da  update docstrings (Mar 19, 2024)
aee4575  remove unneeded file (Mar 20, 2024)
e0fe017  Merge branch 'tensor_compression' into tensor_decompression (Mar 20, 2024)
1ba94ed  Update src/sparseml/pytorch/model_load/helpers.py (Mar 20, 2024)
e5e1215  Update README.md (Mar 20, 2024)
3958525  Merge branch 'main' into tensor_compression (Mar 20, 2024)
38e3c6d  Merge branch 'tensor_compression' into tensor_decompression (Mar 20, 2024)
830a9f5  Merge branch 'main' into tensor_decompression (Mar 20, 2024)
7fcd5c3  Merge branch 'main' into tensor_decompression (Mar 20, 2024)
0335095  fix merge conflict error (Mar 20, 2024)
55 changes: 35 additions & 20 deletions src/sparseml/transformers/compression/README.md
````diff
@@ -35,32 +35,47 @@ needed for decompression in the compressed state_dict:
 ```python
 from sparseml.transformers import SparseAutoModelForCausalLM
 from sparseml.transformers.compression import BitmaskConfig, BitmaskCompressor
-from safetensors import safe_open
-import os
+from sparseml.utils.pytorch.utils import measure_cuda_memory
+from tqdm import tqdm
+import torch
 
 MODEL_PATH = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned50.oneshot"
 OUTPUT_PATH = "./test_compress_output"
 
-model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH)
+torch.cuda.set_device(0)
+with measure_cuda_memory() as m:
+    model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0")
+print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
 
 sparsity_config = BitmaskConfig()
 compressor = BitmaskCompressor(config=sparsity_config)
 
-model_state_dict = model.state_dict()
-sparse_state_dict = compressor.compress(model_state_dict)
-
-
-model.save_pretrained(OUTPUT_PATH, safe_serialization=True, state_dict=sparse_state_dict)
-
-safetensors_path = os.path.join(OUTPUT_PATH, "model-00001-of-00002.safetensors")
-with safe_open(safetensors_path, framework="pt", device=0) as f:
-    test_name = "model.layers.4.self_attn.k_proj.weight"
-    bitmask = f.get_tensor(test_name + ".bitmask")
-    shape = f.get_tensor(test_name + ".shape")
-    values = f.get_tensor(test_name + ".compressed")
-    row_offsets = f.get_tensor(test_name + ".row_offsets")
-    print(f"bitmask: {bitmask}")
-    print(f"shape: {shape}")
-    print(f"values: {values}")
-    print(f"row offsets: {row_offsets}")
+# compress the model using bitmask compression
+with measure_cuda_memory() as m:
+    model_state_dict = model.state_dict()
+    sparse_state_dict = compressor.compress(model_state_dict)
+
+    # save the compressed model
+    model.save_pretrained(
+        OUTPUT_PATH,
+        safe_serialization=True,
+        state_dict=sparse_state_dict
+    )
+
+print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
+
+# reload the model from the compressed checkpoint, then decompress it
+torch.cuda.set_device(1)
+with measure_cuda_memory() as m:
+    model_again = SparseAutoModelForCausalLM.from_pretrained(
+        OUTPUT_PATH,
+        device_map="cuda:1"
+    )
+
+    # decompress() returns an iterator over (name, dense_tensor) pairs
+    dense_state_dict = compressor.decompress(OUTPUT_PATH)
+    for name, data in tqdm(dense_state_dict, desc="Decompressing model"):
+        BitmaskCompressor.replace_layer(name, data, model_again)
+
+print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
 ```
````
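For reference, each compressed weight in the saved checkpoint expands into the four tensors listed in `COMPRESSION_PARAM_NAMES` in `sparse_bitmask.py` below (`shape`, `compressed`, `bitmask`, `row_offsets`). A quick way to inspect them with `safetensors`, adapted from the snippet this diff removes; the shard file name and layer name are examples and may differ for your model:

```python
import os

from safetensors import safe_open

OUTPUT_PATH = "./test_compress_output"

# example shard and layer names; check model.safetensors.index.json in
# OUTPUT_PATH to find the shard that actually holds a given weight
safetensors_path = os.path.join(OUTPUT_PATH, "model-00001-of-00002.safetensors")
test_name = "model.layers.4.self_attn.k_proj.weight"

with safe_open(safetensors_path, framework="pt", device="cpu") as f:
    shape = f.get_tensor(test_name + ".shape")              # dense shape of the weight
    values = f.get_tensor(test_name + ".compressed")        # flat non-zero values
    bitmask = f.get_tensor(test_name + ".bitmask")          # packed non-zero locations
    row_offsets = f.get_tensor(test_name + ".row_offsets")  # row starts into values
    print(f"shape: {shape}, values: {values.shape}, bitmask: {bitmask.shape}")
```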
25 changes: 21 additions & 4 deletions src/sparseml/transformers/compression/compressors/base.py
```diff
@@ -12,11 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict
+import operator
+from typing import Dict, Generator
 
 from torch import Tensor
+from torch.nn import Module, Parameter
 
 from sparseml.transformers.compression.config import CompressionConfig
+from sparseml.utils.pytorch.module import set_layer
 from sparsezoo.utils.registry import RegistryMixin
 
 
@@ -42,11 +45,25 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """
         raise NotImplementedError()
 
-    def decompress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def decompress(self, model_path: str) -> Generator:
         """
-        Uncompresses a compressed state dict back to dense
+        Reads a compressed state dict located at model_path and returns a
+        generator for sequentially decompressing back to a dense state dict
 
-        :param model_state: state dict of uncompressed model
-        :return: compressed state dict
+        :param model_path: path to compressed safetensors model
+        :return: generator for the decompressed weights
         """
         raise NotImplementedError()
+
+    @staticmethod
+    def replace_layer(param_name: str, data: Tensor, model: Module):
+        """
+        Overwrites a parameterized layer with a new tensor, maintaining the device
+        of the original parameter
+
+        :param param_name: name of parameterized layer to replace
+        :param data: tensor to insert into model
+        :param model: pytorch model to insert data into
+        """
+        model_device = operator.attrgetter(param_name)(model).device
+        set_layer(param_name, Parameter(data.to(model_device)), model)
```
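`replace_layer` leans on `set_layer` from `sparseml.utils.pytorch.module`, which is not shown in this diff. As a rough mental model (an assumption, not the library's actual implementation), it walks the module tree to the parent of the named parameter and swaps the attribute in place:

```python
import operator

from torch.nn import Module, Parameter


def set_layer_sketch(param_name: str, param: Parameter, model: Module):
    # split "model.layers.0.self_attn.k_proj.weight" into parent path + leaf name
    parent_name, leaf_name = param_name.rsplit(".", 1)
    parent_module = operator.attrgetter(parent_name)(model)
    setattr(parent_module, leaf_name, param)
```

Because `replace_layer` first reads the existing parameter's device, each decompressed tensor lands on whatever device the checkpoint was loaded to, e.g. `cuda:1` in the README example.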
91 changes: 62 additions & 29 deletions src/sparseml/transformers/compression/compressors/sparse_bitmask.py
```diff
@@ -13,14 +13,19 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Tuple
+from typing import Dict, Generator, List, Tuple, Union
 
 import numpy
 import torch
 from torch import Tensor
 from tqdm import tqdm
 
+from safetensors import safe_open
 from sparseml.transformers.compression.compressors import ModelCompressor
+from sparseml.transformers.compression.utils import (
+    get_nested_weight_mappings,
+    merge_names,
+)
 
 
 __all__ = [
@@ -42,6 +47,8 @@ class BitmaskCompressor(ModelCompressor):
     values tensor, with their locations stored in a 2d bitmask
     """
 
+    COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"]
+
     def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
@@ -50,12 +57,12 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
         :return: compressed state dict
         """
         compressed_dict = {}
-        _LOGGER.info(
+        _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
-        for name, value in tqdm(model_state.items()):
-            bitmask_tensor = BitmaskTensor(value)
-            bitmask_dict = bitmask_tensor.dict(name_prefix=name)
+        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+            bitmask_tensor = BitmaskTensor.from_dense(value)
+            bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
             for key in bitmask_dict.keys():
                 if key in compressed_dict:
                     _LOGGER.warn(
@@ -67,42 +74,68 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
 
         return compressed_dict
 
-    def decompress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def decompress(self, model_path: str) -> Generator:
         """
-        Uncompresses a bitmask compressed state dict back to dense
+        Reads a bitmask compressed state dict located at model_path and returns a
+        generator for sequentially decompressing back to a dense state dict
 
-        :param model_state: state dict of uncompressed model
-        :return: compressed state dict
+        :param model_path: path to compressed safetensors model
+        :return: iterator for generating decompressed weights
         """
-        raise NotImplementedError()
+        weight_mappings = get_nested_weight_mappings(
+            model_path, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, safe_path in weight_mappings[weight_name].items():
+                full_name = merge_names(weight_name, param_name)
+                with safe_open(safe_path, framework="pt", device="cpu") as f:
+                    weight_data[param_name] = f.get_tensor(full_name)
+            data = BitmaskTensor(**weight_data)
+            decompressed = data.decompress()
+            yield weight_name, decompressed
 
 
 class BitmaskTensor:
     """
     Owns compression and decompression for a single bitmask compressed tensor.
     Adapted from: https://github.com/mgoin/torch_bitmask/tree/main
 
-    :param tensor: Dense tensor to compress
+    :param shape: shape of dense tensor
+    :param compressed: flat tensor of non-zero values
+    :param bitmask: 2d bitmask of non-zero values
+    :param row_offsets: flat tensor indicating what index in values each dense row starts at
     """
 
-    def __init__(self, tensor: Tensor):
-        self.dense_device = tensor.device
-        self.shape = tensor.shape
-        self.values, self.bitmasks, self.row_offsets = bitmask_compress(tensor.cpu())
-
-    def decompress(self) -> Tensor:
-        """
-        :return: reconstructed dense tensor
-        """
-        return bitmask_decompress(self.values, self.bitmasks, self.shape)
+    def __init__(
+        self,
+        shape: Union[torch.Size, List],
+        compressed: Tensor,
+        bitmask: Tensor,
+        row_offsets: Tensor,
+    ):
+        self.shape = list(shape)
+        self.compressed = compressed
+        self.bitmask = bitmask
+        self.row_offsets = row_offsets
 
     @staticmethod
     def from_dense(tensor: Tensor) -> "BitmaskTensor":
         """
         :param tensor: dense tensor to compress
         :return: instantiated compressed tensor
         """
-        return BitmaskTensor(tensor)
+        shape = tensor.shape
+        compressed, bitmask, row_offsets = bitmask_compress(tensor.cpu())
+        return BitmaskTensor(
+            shape=shape, compressed=compressed, bitmask=bitmask, row_offsets=row_offsets
+        )
+
+    def decompress(self) -> Tensor:
+        """
+        :return: reconstructed dense tensor
+        """
+        return bitmask_decompress(self.compressed, self.bitmask, self.shape)
 
     def curr_memory_size_bytes(self):
         """
@@ -113,21 +146,21 @@ def sizeof_tensor(a):
             return a.element_size() * a.nelement()
 
         return (
-            sizeof_tensor(self.values)
-            + sizeof_tensor(self.bitmasks)
+            sizeof_tensor(self.compressed)
+            + sizeof_tensor(self.bitmask)
             + sizeof_tensor(self.row_offsets)
         )
 
-    def dict(self, name_prefix: str) -> Dict[str, Tensor]:
+    def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
         """
         :param name_prefix: name of original tensor to store compressed weight as
         :return: dict of compressed data for the stored weight
         """
         return {
-            name_prefix + ".compressed": self.values,
-            name_prefix + ".bitmask": self.bitmasks,
-            name_prefix + ".shape": torch.tensor(self.shape, device="cpu"),
-            name_prefix + ".row_offsets": self.row_offsets,
+            merge_names(name_prefix, "shape"): torch.tensor(self.shape, device=device),
+            merge_names(name_prefix, "compressed"): self.compressed.to(device),
+            merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
+            merge_names(name_prefix, "row_offsets"): self.row_offsets.to(device),
         }
 
     def __repr__(self):
```
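`bitmask_compress` and `bitmask_decompress` are defined further down in this file and do not appear in the diff. A minimal sketch of the scheme they implement, assuming 2d CPU tensors and little-endian bit packing (the real functions may differ in details such as bit order and dtype handling):

```python
import numpy
import torch
from torch import Tensor


def bitmask_compress_sketch(tensor: Tensor):
    # boolean mask of the non-zero locations
    bytemasks = tensor != 0
    # non-zeros per row -> offset of each dense row into the values tensor
    row_counts = bytemasks.sum(dim=-1)
    row_offsets = torch.cumsum(row_counts, 0) - row_counts
    # flat tensor of the non-zero values, in row-major order
    compressed = tensor[bytemasks]
    # pack 8 mask bits per byte along the last axis
    bitmask = torch.from_numpy(
        numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
    )
    return compressed, bitmask, row_offsets


def bitmask_decompress_sketch(values: Tensor, bitmask: Tensor, shape) -> Tensor:
    # unpack the bitmask into a full boolean mask, then scatter the values back
    bytemasks = torch.from_numpy(
        numpy.unpackbits(bitmask.numpy(), axis=-1, count=shape[-1], bitorder="little")
    ).to(torch.bool)
    decompressed = torch.zeros(shape, dtype=values.dtype)
    decompressed[bytemasks] = values
    return decompressed
```

For a 50% sparse fp16 layer this stores roughly half the elements plus one mask bit per element, which is where the peak-memory gap printed by the README example comes from.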
17 changes: 17 additions & 0 deletions src/sparseml/transformers/compression/utils/__init__.py
```diff
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+
+from .safetensors_load import *
```
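The two helpers re-exported here, `get_nested_weight_mappings` and `merge_names` from `safetensors_load`, are what the `BitmaskCompressor.decompress` generator relies on. Their implementations are not part of this diff; a sketch of plausible behavior, inferred only from how they are called above:

```python
import os
from typing import Dict, List

from safetensors import safe_open


def merge_names(parent_name: str, child_name: str) -> str:
    # "model.layers.0.weight" + "bitmask" -> "model.layers.0.weight.bitmask"
    return parent_name + "." + child_name


def get_nested_weight_mappings(
    model_path: str, params_to_nest: List[str]
) -> Dict[str, Dict[str, str]]:
    # map each original weight name to {compression param name -> shard path},
    # so one weight's compression tensors can be gathered across shards
    nested: Dict[str, Dict[str, str]] = {}
    for file_name in os.listdir(model_path):
        if not file_name.endswith(".safetensors"):
            continue
        safe_path = os.path.join(model_path, file_name)
        with safe_open(safe_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                for param_name in params_to_nest:
                    suffix = "." + param_name
                    if key.endswith(suffix):
                        weight_name = key[: -len(suffix)]
                        nested.setdefault(weight_name, {})[param_name] = safe_path
    return nested
```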