Composability with sparse and quantization compressors #948

Merged 21 commits on Jan 22, 2025
13 changes: 12 additions & 1 deletion examples/sparse_2of4_quantization_fp8/README.md
@@ -93,7 +93,7 @@ oneshot(
)
```

3. **Save the Compressed Model**
### Saving the Compressed Model

The compressed model and tokenizer are saved to the output directory:

@@ -106,6 +106,17 @@ Output Directories:
- Without FP8: `Meta-Llama-3-8B-Instruct-2of4-sparse`
- With FP8: `Meta-Llama-3-8B-Instruct-2of4-W8A8-FP8-Dynamic-Per-Token`

#### Saving Without Sparse Compression

To save the model on disk without sparse compression:

```python
model.save_pretrained(save_dir, save_compressed=True, disable_sparse_compression=True)
tokenizer.save_pretrained(save_dir)
```

> **Note:** Saving a model with both `save_compressed=True` and `disable_sparse_compression=True` compresses the model using the quantization compressor; however, instead of the more disk-efficient sparsity compressor(s), the `dense` sparsity compressor is used. The `dense` sparsity compressor stores model parameters as-is and does not leverage sparsity for disk-efficient storage. These options only affect how the model is saved on disk and do not impact the actual pruning or quantization processes.
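For comparison, a minimal sketch of the default save path, which leaves sparse compression enabled; `save_dir` is assumed to be one of the output directories listed above:

```python
# Default save: the quantization compressor is composed with the sparsity
# compressor inferred for the model (e.g. a 2:4-aware compressor when supported).
model.save_pretrained(save_dir, save_compressed=True)
tokenizer.save_pretrained(save_dir)
```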

### Validation

After compression, the script validates the model by generating a sample output:
28 changes: 24 additions & 4 deletions src/llmcompressor/transformers/compression/quantization_format.py
@@ -1,7 +1,7 @@
from typing import Optional

from compressed_tensors import CompressionFormat
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.config import SparsityStructure
from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
from compressed_tensors.quantization.utils import (
is_model_quantized,
@@ -16,10 +16,30 @@ def infer_quantization_format(
model,
quantization_format: Optional[str] = None,
save_compressed: bool = False,
sparsity_config: Optional[SparsityCompressionConfig] = None,
sparsity_structure: Optional[str] = None,
) -> str:
"""
Infers a quantization format based on model state and compression args
Infers the quantization format for a model based on its state and provided
compression arguments.

The following table outlines the possible quantization and sparsity formats
along with their corresponding compressor formats:

+---------------+----------+----------------------+---------------------+
| Quantization | Sparsity | Quant Compressor | Sparsity Compressor |
| | | Format | Format |
+---------------+----------+----------------------+---------------------+
| W8A8 - int | None | int_quantized | Dense |
| W8A8 - float | None | float_quantized | Dense |
| W4A16 - int | None | pack_quantized | Dense |
| W8A16 - int | None | pack_quantized | Dense |
| W8A16 - float | None | naive_quantized | Dense |
| W8A8 - int | 2:4 | int_quantized | Sparse24 |
| W8A8 - float | 2:4 | float_quantized | Sparse24 |
| W4A16 - int | 2:4 | marlin_24 | Dense |
| W8A16 - int | 2:4 | marlin_24 | Dense |
| W8A16 - float | 2:4 | naive_quantized | Dense |
+---------------+----------+----------------------+---------------------+

:param model: model to check for quantization, if the model is not quantized no
quantization format is returned
@@ -37,7 +57,7 @@ if save_compressed:
if save_compressed:
weight_args, input_args = _get_unique_quant_args(model)
is_24_structure = (
sparsity_config and sparsity_config.sparsity_structure == "2:4"
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
)
is_weight_only = len(input_args) == 0 and len(weight_args) > 0

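A minimal usage sketch of the updated signature, assuming an already-quantized `model` object is in scope; the expected return values follow the table in the docstring above:

```python
from llmcompressor.transformers.compression.quantization_format import (
    infer_quantization_format,
)

# For a W8A8-float model with 2:4 sparsity this should resolve to the
# float_quantized compressor format (see the table above).
quant_format = infer_quantization_format(
    model=model,
    save_compressed=True,
    sparsity_structure="2:4",
)
print(quant_format)
```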
102 changes: 94 additions & 8 deletions src/llmcompressor/transformers/compression/sparsity_config.py
@@ -1,7 +1,14 @@
from typing import Dict, Optional
from typing import Dict, List, Optional

from compressed_tensors import CompressionFormat, SparsityCompressionConfig
from compressed_tensors.quantization.utils import is_model_quantized
from compressed_tensors.config import SparsityStructure
from compressed_tensors.quantization import QuantizationType
from compressed_tensors.quantization.utils import (
is_model_quantized,
is_module_quantized,
iter_named_leaf_modules,
)
from loguru import logger
from torch import Tensor
from torch.nn import Module

@@ -20,7 +27,7 @@ class SparsityConfigMetadata:
metadata from the model
"""

SPARSITY_THRESHOLD: float = 0.4
SPARSITY_THRESHOLD: float = 0.5

@staticmethod
def infer_global_sparsity(
@@ -67,13 +74,15 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
if model and sparsity_structure is None:
sparsity_structure = infer_sparsity_structure_from_model(model)

return sparsity_structure or "unstructured"
return SparsityStructure(sparsity_structure).value

@staticmethod
def from_pretrained(
model: Module,
state_dict: Optional[Dict[str, Tensor]] = None,
compress: bool = False,
quantization_format: Optional[CompressionFormat] = None,
disable_sparse_compression: bool = False,
) -> Optional["SparsityCompressionConfig"]:
"""
Determines compression type and informational parameters for a given model
@@ -82,6 +91,11 @@ def from_pretrained(
:param state_dict: optional state_dict to replace that in model, used for
gathering global FSDP model info
:param compress: whether or not to compress the model on disk
:param quantization_format: the quantization compression format being used
for the model
:param disable_sparse_compression: whether or not to compress the model with
sparse compressors; if True, the sparse compression format will
be dense. Defaults to False.
:return: compression config inferred from the model
"""

@@ -95,11 +109,18 @@ sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
model=model
)
if is_model_quantized(model):
# compressing a sparse quantized model is not supported yet
if (
disable_sparse_compression
or quantization_format == CompressionFormat.marlin_24
):
# sparse compressor should be dense
# when disable_sparse_compression is True
# or when marlin_24 is used
format = CompressionFormat.dense.value
elif compress:
format = CompressionFormat.sparse_bitmask.value
elif compress and SparsityConfigMetadata.is_sparse24_bitmask_supported(
model, sparsity_structure
):
format = CompressionFormat.sparse_24_bitmask.value
else:
format = CompressionFormat.dense.value

@@ -135,3 +156,68 @@ def fill_config_details(
model, state_dict=state_dict
)
config.sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure()

@staticmethod
def is_sparse24_bitmask_supported(
model: Module,
sparsity_structure: Optional[str] = None,
) -> bool:
"""
Determines whether the sparse 24 bitmask compressor is supported in vLLM for a
given model and its sparsity structure

:param model: pytorch model to check for sparse 24 bitmask support
:param sparsity_structure: sparsity structure of the model; if
not supplied it will be inferred
:return: whether or not sparse 24 bitmask compression is supported
in vLLM for the given model
"""

if sparsity_structure is None:
sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)

if sparsity_structure != SparsityStructure.TWO_FOUR.value:
# only supported for 2:4 sparsity
return False

if not is_model_quantized(model):
# non-quantized 2:4 sparse models are supported
return True

# when model is quantized, and has 2:4 sparsity

supported_scheme_types: List[str] = [
QuantizationType.INT.value,
QuantizationType.FLOAT.value,
]

for _, submodule in iter_named_leaf_modules(model):
if is_module_quantized(submodule):
weight_scheme = submodule.quantization_scheme.weights
input_scheme = submodule.quantization_scheme.input_activations

if weight_scheme and input_scheme:
# weight and activation quantization
# check schemes are supported
for scheme in [weight_scheme, input_scheme]:
scheme_supported = (
scheme.num_bits == 8
and scheme.type in supported_scheme_types
)
if not scheme_supported:
logger.info(
"Quantization scheme not supported,"
" turning off sparse 24 compression."
f" Invalid Scheme: {scheme}"
)
return False

elif weight_scheme or input_scheme:
# weight only quantization
logger.info(
"Weight only quantization detected, "
"turning off sparse 24 compression."
)
return False

return True
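A minimal sketch of how the new arguments might be exercised, assuming `model` is a 2:4-pruned (and possibly quantized) module in scope:

```python
from compressed_tensors import CompressionFormat
from llmcompressor.transformers.compression.sparsity_config import (
    SparsityConfigMetadata,
)

# Supported 8-bit weight+activation schemes on a 2:4 model yield the
# sparse-24 bitmask format; disable_sparse_compression forces dense instead.
config = SparsityConfigMetadata.from_pretrained(
    model,
    compress=True,
    quantization_format=CompressionFormat.float_quantized,
    disable_sparse_compression=False,
)
print(config.format if config is not None else None)
```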
@@ -8,6 +8,7 @@
import transformers
from accelerate.accelerator import get_state_dict_offloaded_model
from compressed_tensors import (
CompressionFormat,
ModelCompressor,
SparsityCompressionConfig,
is_module_offloaded,
@@ -124,6 +125,7 @@ def save_pretrained_wrapper(
quantization_format: Optional[str] = None,
save_compressed: bool = True,
skip_compression_stats: bool = False,
disable_sparse_compression: bool = False,
**kwargs,
):
"""
Expand All @@ -133,13 +135,15 @@ def save_pretrained_wrapper(

:param save_directory: output directory to save model to
:param sparsity_config: optional sparsity config to compress model with,
if no config is provided it will be inferred from the model
if no config is provided it will be inferred from the model
:param quantization_format: optional compression format for quantized
models. If none is provided it will be inferred from the model
models. If none is provided it will be inferred from the model
:param save_compressed: whether or not to compress the model on disk
:param skip_compression_stats: whether to skip the calculation of
compression statistics (such as global sparsity and sparsity structure) when
saving a model in dense format
compression statistics (such as global sparsity and sparsity structure)
when saving a model in dense format
:param disable_sparse_compression: whether to skip sparse compression
during save, default is False
:param kwargs: additional kwargs to pass on to model.save_pretrained
"""

@@ -169,6 +173,7 @@ def skip(*args, **kwargs):
save_compressed=save_compressed,
skip_compression_stats=skip_compression_stats,
state_dict=state_dict,
disable_sparse_compression=disable_sparse_compression,
)

if compressor is None:
@@ -260,6 +265,7 @@ def get_model_compressor(
save_compressed: bool = True,
skip_compression_stats: bool = False,
state_dict: Optional[Dict] = None,
disable_sparse_compression: bool = False,
):
"""
Obtain the compressor based on the config and the
@@ -273,19 +279,26 @@
format
:param skip_compression_stats: bool allowing compression stats on std out
:param state_dict: state_dict of the model
:param disable_sparse_compression: bool to skip sparse compression
"""

# find offloaded state dict if none is provided
if state_dict is None:
state_dict = get_state_dict_offloaded_model(model)

sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)
quantization_format: Optional[CompressionFormat] = infer_quantization_format(
model=model,
quantization_format=quantization_format,
save_compressed=save_compressed,
sparsity_structure=sparsity_structure,
)

if sparsity_config is not None:
sparsity_config.global_sparsity = SparsityConfigMetadata.infer_global_sparsity(
model, state_dict=state_dict
)
sparsity_config.sparsity_structure = (
SparsityConfigMetadata.infer_sparsity_structure()
)
sparsity_config.sparsity_structure = sparsity_structure
elif not skip_compression_stats:
# try to infer a sparsity config from the model if none is provided
logger.info(
@@ -295,15 +308,13 @@
"skip_compression_stats=True"
)
sparsity_config = SparsityConfigMetadata.from_pretrained(
model, state_dict=state_dict, compress=save_compressed
model,
state_dict=state_dict,
compress=save_compressed,
quantization_format=quantization_format,
disable_sparse_compression=disable_sparse_compression,
)

quantization_format = infer_quantization_format(
model=model,
quantization_format=quantization_format,
save_compressed=save_compressed,
sparsity_config=sparsity_config,
)
return ModelCompressor.from_pretrained_model(
model,
sparsity_config=sparsity_config,
@@ -0,0 +1,7 @@
pruning_stage:
obcq_modifiers:
SparseGPTModifier:
sparsity: 0.5
sequential_update: true
mask_structure: "2:4"
targets: ['re:model.layers.\d*$']
@@ -0,0 +1,38 @@
pruning_stage:
obcq_modifiers:
SparseGPTModifier:
sparsity: 0.5
sequential_update: true
mask_structure: "2:4"
targets: ['re:model.layers.\d*$']
quant_stage:
quant_modifiers:
QuantizationModifier:
ignore: ["lm_head"]
config_groups:
group_0:
weights:
num_bits: 8
type: float
strategy: channel
dynamic: false
symmetric: true
input_activations:
num_bits: 8
type: float
strategy: token
dynamic: true
symmetric: true
targets: ["Linear"]
pruning_modifiers:
ConstantPruningModifier:
targets: [
're:.*q_proj.weight',
're:.*k_proj.weight',
're:.*v_proj.weight',
're:.*o_proj.weight',
're:.*gate_proj.weight',
're:.*up_proj.weight',
're:.*down_proj.weight',
]
start: 0
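A hedged sketch of one way the 2:4-plus-FP8 recipe above might be applied through llmcompressor's `oneshot` entrypoint; the recipe path, calibration dataset, and calibration settings below are hypothetical placeholders:

```python
from llmcompressor.transformers import oneshot

oneshot(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # base model, as in the example README
    dataset="open_platypus",                      # hypothetical calibration dataset
    recipe="2of4_w8a8_fp8_recipe.yaml",           # the recipe above; path is hypothetical
    output_dir="Meta-Llama-3-8B-Instruct-2of4-W8A8-FP8-Dynamic-Per-Token",
    max_seq_length=2048,
    num_calibration_samples=512,
)
```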