
Merge branch 'main' into clear-ml
horheynm authored Mar 27, 2024
2 parents cb5fe0b + 85b0e72 commit 5f7c4c2
Showing 22 changed files with 629 additions and 53 deletions.
6 changes: 5 additions & 1 deletion src/sparseml/pytorch/model_load/helpers.py
@@ -233,6 +233,7 @@ def save_model_and_recipe(
save_path: str,
tokenizer: Optional[Any] = None,
save_safetensors: bool = False,
save_compressed: bool = False,
):
"""
Save a model, tokenizer and the currently loaded recipe to file
@@ -241,9 +242,12 @@ def save_model_and_recipe(
:param save_path: path to save output to
:param tokenizer: model tokenizer to save
:param save_safetensors: whether to save as safetensors or pickle (bin)
:param save_compressed: whether to compress sparse weights on disk
"""

model.save_pretrained(save_path, safe_serialization=save_safetensors)
model.save_pretrained(
save_path, save_compressed=save_compressed, safe_serialization=save_safetensors
)

if tokenizer is not None:
tokenizer.save_pretrained(save_path)
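For reference, a minimal call sketch of the updated helper; the leading `model` argument is assumed from the docstring, since it is cut off in the hunk above:

```python
from sparseml.pytorch.model_load.helpers import save_model_and_recipe

# sketch: save a sparse model, its tokenizer, and the loaded recipe,
# compressing sparse weights on disk via the new save_compressed flag
save_model_and_recipe(
    model,  # assumed leading parameter, not shown in this hunk
    save_path="./oneshot_output",
    tokenizer=tokenizer,
    save_safetensors=True,
    save_compressed=True,
)
```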
2 changes: 1 addition & 1 deletion src/sparseml/pytorch/utils/sparsification.py
@@ -97,7 +97,7 @@ def params_sparse(self) -> int:
"""
return sum(
round(tensor_sparsity(param).item() * torch.numel(param))
for param in self.trainable_params
for param in tqdm(self.trainable_params, desc="Calculating model sparsity")
)

@property
133 changes: 107 additions & 26 deletions src/sparseml/transformers/compression/README.md
@@ -30,52 +30,133 @@ needed for decompression in the compressed state_dict:
}
```

Config information is stored in the HF config file:
```json
// config.json
{
"sparsity_config": {
"format": "sparse_bitmask", // "dense_sparsity" for original tensor format

// informational
"sparsity_structure": "unstructured", // or 2:4, 8:16 etc...
"global_sparsity": "0.5"
}
}
```
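One way to sanity-check what was written is to read the config back; a sketch, assuming the standard `transformers` behavior of exposing extra `config.json` keys as attributes on the loaded config:

```python
from transformers import AutoConfig

# sketch: extra keys in config.json surface as attributes on the loaded config
hf_config = AutoConfig.from_pretrained("/PATH/TO/COMPRESSED_MODEL")
print(getattr(hf_config, "sparsity_config", None))
```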

## Saving/Loading Interface

Loading a compressed model requires no interface changes:

```python
from sparseml.transformers.utils import SparseAutoModelForCausalLM

# should contain model.safetensors or model.safetensors.index.json
model_path = "/PATH/TO/COMPRESSED_MODEL"

model = SparseAutoModelForCausalLM.from_pretrained(
model_name_or_path=model_path,
**model_kwargs,
)
```

Saving a compressed model with an explicitly provided compression config writes the
config to the model's `config.json` file. **Note:** the model must have been
initialized with `SparseAutoModelForCausalLM.from_pretrained()`.

```python
from sparseml.transformers.compression import BitmaskConfig

output_dir = "/PATH/TO/SAVE/COMPRESSED_MODEL"
sparsity_config = BitmaskConfig()

model.save_pretrained(
save_directory=output_dir,
sparsity_config=sparsity_config,
)
```

Saving a compressed model, inferring the config from the model attributes

```python
model.save_pretrained(
save_directory=output_dir,
save_compressed=True
)
```

Saving a model in the dense format. If the model has at least 5% global sparsity, a
sparsity config will still be included in `config.json` with format `dense_sparsity`

```python
model.save_pretrained(
save_directory=output_dir
)
```

Saving a model in the dense format, bypassing the sparsity config calculation. When the
`skip_compression_stats` flag is set, no sparsity config will be written to
`config.json`

```python
model.save_pretrained(
    save_directory=output_dir,
    skip_compression_stats=True
)
```

## Enable Compression During One-Shot and Sparse Finetuning
Models that are saved in a supported compressed format on disk will automatically be
decompressed when loaded as input to `sparseml.transformers.oneshot` or
`sparseml.transformers.train`.

To enable compression on save after oneshot or finetuning, simply add the
`save_compressed=True` argument to `sparseml.transformers.oneshot` or
`sparseml.transformers.train`:

```python
from sparseml.transformers import train

train(
save_compressed=True,
model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4",
recipe=RECIPE,
dataset=DATASET
)
```
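The same flag applies to `sparseml.transformers.oneshot`; a sketch mirroring the call above, assuming `oneshot` accepts the same `model`/`recipe`/`dataset` arguments as `train` (`RECIPE` and `DATASET` remain placeholders):

```python
from sparseml.transformers import oneshot

oneshot(
    save_compressed=True,
    model="neuralmagic/TinyLlama-1.1B-Chat-v1.0-pruned2.4",
    recipe=RECIPE,
    dataset=DATASET,
)
```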


## Example Code

Loads a 60% sparse model, compresses it using the inferred bitmask compression, then
reloads the compressed model.

```python
from sparseml.transformers import SparseAutoModelForCausalLM
from sparseml.transformers.compression import BitmaskConfig, BitmaskCompressor
from sparseml.utils.pytorch.utils import measure_cuda_memory
from tqdm import tqdm
import torch

MODEL_PATH = "zoo:llama2-7b-gsm8k_llama2_pretrain-pruned50.oneshot"
MODEL_PATH = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60"
OUTPUT_PATH = "./test_compress_output"
RECIPE = "zoo:llama2-7b-open_platypus_orca_llama2_pretrain-pruned60"

torch.cuda.set_device(0)
with measure_cuda_memory() as m:
model = SparseAutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="cuda:0")
print(f"Load dense model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

sparsity_config = BitmaskConfig()
compressor = BitmaskCompressor(config=sparsity_config)

# compresses the model using Bitmask compression
sparsity_config = getattr(model, "sparsity_config", None)
print(f"Sparsity config before compression: {sparsity_config}")
with measure_cuda_memory() as m:
model_state_dict = model.state_dict()
sparse_state_dict = compressor.compress(model_state_dict)

# save the compressed model
model.save_pretrained(
OUTPUT_PATH,
safe_serialization=True,
state_dict=sparse_state_dict
)

model.save_pretrained(OUTPUT_PATH, save_compressed=True)
print(f"Save compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")

# use the dense state dict to reload the model
torch.cuda.set_device(1)
with measure_cuda_memory() as m:
model_again = SparseAutoModelForCausalLM.from_pretrained(
OUTPUT_PATH,
device_map="cuda:1"
OUTPUT_PATH, device_map="cuda:1"
)

# returns an iterator
dense_state_dict = compressor.decompress(OUTPUT_PATH)
for name, data in tqdm(dense_state_dict, desc="Decompressing model"):
BitmaskCompressor.replace_layer(name, data, model_again)

print(f"Load compressed model peak GPU {m.overall_peak_memory / float(2**30):.4f} GB")
sparsity_config = getattr(model_again, "sparsity_config", None)
print(f"Sparsity config after compression: {sparsity_config}")
```
src/sparseml/transformers/compression/compressors/__init__.py
@@ -15,4 +15,5 @@
# flake8: noqa

from .base import ModelCompressor
from .dense import DenseCompressor
from .sparse_bitmask import BitmaskCompressor
19 changes: 17 additions & 2 deletions src/sparseml/transformers/compression/compressors/base.py
@@ -13,12 +13,14 @@
# limitations under the License.

import operator
from typing import Dict, Generator
from typing import Dict, Generator, Tuple

from torch import Tensor
from torch.nn import Module, Parameter
from tqdm import tqdm

from sparseml.transformers.compression.config import CompressionConfig
from sparseml.transformers.utils.helpers import SPARSITY_CONFIG_NAME
from sparseml.utils.pytorch.module import set_layer
from sparsezoo.utils.registry import RegistryMixin

@@ -45,7 +47,7 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
raise NotImplementedError()

def decompress(self, model_path: str) -> Generator:
def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a compressed state dict located at model_path and returns a
generator for sequentially decompressing back to a dense state dict
@@ -67,3 +69,16 @@ def replace_layer(param_name: str, data: Tensor, model: Module):
"""
model_device = operator.attrgetter(param_name)(model).device
set_layer(param_name, Parameter(data.to(model_device)), model)

def overwrite_weights(self, pretrained_model_name_or_path: str, model: Module):
"""
Overwrites the weights in model with weights decompressed from
pretrained_model_name_or_path
:param pretrained_model_name_or_path: path to compressed weights
:param model: pytorch model to load decompressed weights into
"""
dense_gen = self.decompress(pretrained_model_name_or_path)
for name, data in tqdm(dense_gen, desc="Decompressing model"):
ModelCompressor.replace_layer(name, data, model)
setattr(model, SPARSITY_CONFIG_NAME, self.config)
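A usage sketch for the new helper, assuming `BitmaskConfig` defaults its `format` to `"sparse_bitmask"` and that compressors take a `config` keyword as in the README example:

```python
from sparseml.transformers.compression import BitmaskConfig
from sparseml.transformers.compression.compressors import ModelCompressor

# sketch: resolve a compressor from the registry by config format, then load
# decompressed weights into an already-initialized model (a torch.nn.Module)
config = BitmaskConfig()
compressor = ModelCompressor.load_from_registry(config.format, config=config)
compressor.overwrite_weights("/PATH/TO/COMPRESSED_MODEL", model)
```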
32 changes: 32 additions & 0 deletions src/sparseml/transformers/compression/compressors/dense.py
@@ -0,0 +1,32 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Generator, Tuple

from torch import Tensor

from sparseml.transformers.compression.compressors import ModelCompressor


@ModelCompressor.register(name="dense_sparsity")
class DenseCompressor(ModelCompressor):
"""
Identity compressor for dense models, returns the original state_dict
"""

def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
return model_state

def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]:
return iter([])
src/sparseml/transformers/compression/compressors/sparse_bitmask.py
@@ -74,7 +74,7 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:

return compressed_dict

def decompress(self, model_path: str) -> Generator:
def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, None]:
"""
Reads a bitmask compressed state dict located at model_path and returns a
generator for sequentially decompressing back to a dense state dict
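Since `decompress` is a generator, tensors can be streamed without materializing the full dense state dict; a sketch using the imports from the README example:

```python
from sparseml.transformers.compression import BitmaskConfig, BitmaskCompressor

# sketch: iterate (name, dense_tensor) pairs straight off disk
compressor = BitmaskCompressor(config=BitmaskConfig())
for name, tensor in compressor.decompress("/PATH/TO/COMPRESSED_MODEL"):
    print(name, tuple(tensor.shape))
```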
82 changes: 82 additions & 0 deletions src/sparseml/transformers/compression/config/base.py
@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from pydantic import BaseModel
from torch.nn import Module

import sparseml.core.session as session_manager
from sparseml.pytorch.utils import ModuleSparsificationInfo
from sparsezoo.utils.registry import RegistryMixin


@@ -25,6 +30,83 @@ class CompressionConfig(RegistryMixin, BaseModel):
Base data class for storing compression parameters
:param format: name of compression format
:param global_sparsity: average sparsity of the entire model
:param sparsity_structure: structure of the sparsity, such as
"unstructured", "2:4", "8:16" etc
"""

format: str
global_sparsity: Optional[float] = 0.0
sparsity_structure: Optional[str] = "unstructured"

@staticmethod
def infer_global_sparsity(model: Module) -> float:
"""
Calculates the global percentage of sparse zero weights in the model
:param model: pytorch model to infer sparsity of
:return: global sparsity of model
"""
info = ModuleSparsificationInfo(model)
global_sparsity = info.params_sparse_percent
return global_sparsity

@staticmethod
def infer_sparsity_structure() -> str:
"""
Determines what sparsity structure, if any, was applied in the currently active
sparse session
:return: sparsity structure as a string
"""
current_session = session_manager.active_session()
stage_modifiers = current_session.lifecycle.modifiers
sparsity_structure = "unstructured"

# check for applied pruning modifiers
for stage in stage_modifiers:
if stage.applied:
for modifier in stage.modifiers:
if hasattr(modifier, "mask_structure"):
sparsity_structure = modifier.mask_structure
break

return sparsity_structure

@staticmethod
def infer_config_from_model(
model: Module, compress: bool = False
) -> Optional["CompressionConfig"]:
"""
Determines compression type and informational parameters for a given model
:param model: pytorch model to calculate sparsity config for
:param compress: whether or not to compress the model on disk
:return: compression config inferred from the model
"""

global_sparsity = CompressionConfig.infer_global_sparsity(model)

if global_sparsity < 0.05:
return None

sparsity_structure = CompressionConfig.infer_sparsity_structure()
if compress:
format = "sparse_bitmask"
else:
format = "dense_sparsity"

return CompressionConfig.load_from_registry(
format,
global_sparsity=global_sparsity,
sparsity_structure=sparsity_structure,
)

def fill_config_details(self, model: Module):
"""
Fills in informational sparsity parameters from a given model
:param model: pytorch model to infer config parameters from
"""
self.global_sparsity = CompressionConfig.infer_global_sparsity(model)
self.sparsity_structure = CompressionConfig.infer_sparsity_structure()
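A sketch of how the new inference helpers compose, assuming `CompressionConfig` is exported from `sparseml.transformers.compression.config` and `model` is a loaded pytorch model:

```python
from sparseml.transformers.compression.config import CompressionConfig

# sketch: infer a sparsity config; returns None below the 5% sparsity threshold
config = CompressionConfig.infer_config_from_model(model, compress=True)
if config is not None:
    print(config.format, config.global_sparsity, config.sparsity_structure)
```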
4 changes: 2 additions & 2 deletions src/sparseml/transformers/compression/config/dense.py
@@ -32,5 +32,5 @@ class DenseSparsityConfig(CompressionConfig):
"""

format: str = "dense_sparsity"
global_sparsity: Optional[float] = None
sparsity_structure: Optional[str] = None
global_sparsity: Optional[float] = 0.0
sparsity_structure: Optional[str] = "unstructured"
2 changes: 2 additions & 0 deletions src/sparseml/transformers/compression/utils/__init__.py
@@ -14,4 +14,6 @@

# flake8: noqa

from .compress_save import *
from .helpers import *
from .safetensors_load import *