Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SparseML Compression Pt 1: saving w/compression configs #2177

Merged
merged 22 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
_deps = [
"setuptools<=59.5.0",
"pyyaml>=5.0.0",
"numpy>=1.0.0",
"numpy>=1.17.0",
Satrat marked this conversation as resolved.
Show resolved Hide resolved
"matplotlib>=3.0.0",
"merge-args>=0.1.0",
"onnx>=1.5.0,<1.15.0",
Expand Down
8 changes: 6 additions & 2 deletions src/sparseml/pytorch/model_load/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,21 @@ def reload_model_from_checkpoint(model: Module, checkpoint: Optional[str] = None


def save_model_and_recipe(
model: Module, save_path: str, tokenizer: Optional[Any] = None
model: Module,
save_path: str,
tokenizer: Optional[Any] = None,
save_safetensors: bool = False,
):
"""
Save a model, tokenizer and the currently loaded recipe to file

:param model: pytorch model to save
:param save_path: path to save output to
:param tokenizer: model tokenizer to save
:param save_safetensors: whether to save as safetensors or pickle(bin)
Satrat marked this conversation as resolved.
Show resolved Hide resolved
"""

model.save_pretrained(save_path)
model.save_pretrained(save_path, safe_serialization=save_safetensors)

if tokenizer is not None:
tokenizer.save_pretrained(save_path)
Expand Down
1 change: 1 addition & 0 deletions src/sparseml/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ def _check_transformers_install():
from .utils import *
from .export import *
from .finetune import *
from .compression import *
Satrat marked this conversation as resolved.
Show resolved Hide resolved
67 changes: 67 additions & 0 deletions src/sparseml/transformers/compression/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Save/Load Compressed SafeTensors

## Motivation

* Reduce disk space by saving in a compressed format for sparse models. Models in this compressed format will be loaded by vLLM for more efficient inference
* Set up the save/load architecture such that we can easily expand to additional compression formats in the future. The config should be human readable so users can understand the compression format at a quick glance

## SafeTensors File Format

For each parameter in the uncompressed state_dict, we store the following attributes
needed for decompression in the compressed state_dict:

* compressed tensor
* bitmask
* uncompressed shape
* row offsets

```python
# dense
{
PARAM_NAME: uncompressed_tensor
}

# compressed
{
PARAM_NAME.compressed: compressed_tensor # 1d tensor
PARAM_NAME.bitmask: value # 2d bitmask tensor (nrows x (ncols / 8))
Satrat marked this conversation as resolved.
Satrat marked this conversation as resolved.
Show resolved Hide resolved
PARAM_NAME.shape: value # uncompressed shape tensor
PARAM_NAME.row_offsets: value # 1d offsets tensor
}
```

## Example Code

```python
from sparseml.transformers import SparseAutoModelForCausalLM
from sparseml.transformers.compression import BitmaskConfig, BitmaskCompressor
from safetensors import safe_open
import os

MODEL_PATH = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned50.oneshot"
OUTPUT_PATH = "./test_compress_output"

model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH)

sparsity_config = BitmaskConfig()
compressor = BitmaskCompressor(config=sparsity_config)

model_state_dict = model.state_dict()
sparse_state_dict = compressor.compress(model_state_dict)


model.save_pretrained(OUTPUT_PATH, safe_serialization=True, state_dict=sparse_state_dict)

safetensors_path = os.path.join(OUTPUT_PATH, "model-00001-of-00002.safetensors")
with safe_open(safetensors_path, framework="pt", device=0) as f:
test_name = "model.layers.4.self_attn.k_proj.weight"
bitmask = f.get_tensor(test_name + ".bitmask")
shape = f.get_tensor(test_name + ".shape")
values = f.get_tensor(test_name + ".compressed")
row_offsets = f.get_tensor(test_name + ".row_offsets")
print(f"bitmask: {bitmask}")
print(f"shape: {shape}")
print(f"values: {values}")
print(f"row offsets: {row_offsets}")
```
18 changes: 18 additions & 0 deletions src/sparseml/transformers/compression/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .compressors import *
from .config import *
18 changes: 18 additions & 0 deletions src/sparseml/transformers/compression/compressors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import ModelCompressor
from .sparse_bitmask import BitmaskCompressor
52 changes: 52 additions & 0 deletions src/sparseml/transformers/compression/compressors/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from torch import Tensor

from sparseml.transformers.compression.config import CompressionConfig
from sparsezoo.utils.registry import RegistryMixin


__all__ = ["ModelCompressor"]


class ModelCompressor(RegistryMixin):
"""
Base class representing a model compression algorithm.

:param config: config specifying compression parameters
"""

def __init__(self, config: CompressionConfig):
self.config = config

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Compresses a dense state dict

:param model_state: state dict of uncompressed model
:return: compressed state dict
"""
raise NotImplementedError()

def decompress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Uncompresses a compressed state dict back to dense

:param model_state: state dict of uncompressed model
:return: compressed state dict
"""
raise NotImplementedError()
Loading
Loading