diff --git a/examples/sparse_2of4_quantization_fp8/README.md b/examples/sparse_2of4_quantization_fp8/README.md index 97b8e590e..99fc3c545 100644 --- a/examples/sparse_2of4_quantization_fp8/README.md +++ b/examples/sparse_2of4_quantization_fp8/README.md @@ -93,7 +93,7 @@ oneshot( ) ``` -3. **Save the Compressed Model** +### Saving the Compressed Model The compressed model and tokenizer are saved to the output directory: @@ -106,6 +106,17 @@ Output Directories: - Without FP8: `Meta-Llama-3-8B-Instruct-2of4-sparse` - With FP8: `Meta-Llama-3-8B-Instruct-2of4-W8A8-FP8-Dynamic-Per-Token` +#### Saving Without Sparse Compression + +To save the model on disk without sparse compression: + +```python +model.save_pretrained(save_dir, save_compressed=True, disable_sparse_compression=True) +tokenizer.save_pretrained(save_dir) +``` + +> **Note:** Saving a model with both the `save_compressed` and `disable_sparse_compression` options will compress the model using the quantization compressor; however, instead of using the more disk-efficient sparsity compressor(s), the dense sparsity compressor will be used. The `dense` sparsity compressor saves model params as is, and does not leverage sparsity for disk-efficient storage. These options only affect how the model(s) are saved on disk and do not impact the actual pruning or quantization processes. + ### Validation After compression, the script validates the model by generating a sample output: diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py index e6f14a6c7..86e1baf2c 100644 --- a/src/llmcompressor/transformers/compression/quantization_format.py +++ b/src/llmcompressor/transformers/compression/quantization_format.py @@ -1,7 +1,7 @@ from typing import Optional from compressed_tensors import CompressionFormat -from compressed_tensors.config import SparsityCompressionConfig +from compressed_tensors.config import SparsityStructure from compressed_tensors.quantization import QuantizationStrategy, QuantizationType from compressed_tensors.quantization.utils import ( is_model_quantized, @@ -16,10 +16,30 @@ def infer_quantization_format( model, quantization_format: Optional[str] = None, save_compressed: bool = False, - sparsity_config: Optional[SparsityCompressionConfig] = None, + sparsity_structure: Optional[str] = None, ) -> str: """ - Infers a quantization format based on model state and compression args + Infers the quantization format for a model based on its state and provided + compression arguments. 
+ + The following table outlines the possible quantization and sparsity formats + along with their corresponding compressor formats: + + +---------------+----------+----------------------+---------------------+ + | Quantization | Sparsity | Quant Compressor | Sparsity Compressor | + | | | Format | Format | + +---------------+----------+----------------------+---------------------+ + | W8A8 - int | None | int_quantized | Dense | + | W8A8 - float | None | float_quantized | Dense | + | W4A16 - int | None | pack_quantized | Dense | + | W8A16 - int | None | pack_quantized | Dense | + | W8A16 - float | None | naive_quantized | Dense | + | W8A8 - int | 2:4 | int_quantized | Sparse24 | + | W8A8 - float | 2:4 | float_quantized | Sparse24 | + | W4A16 - int | 2:4 | marlin_24 | Dense | + | W8A16 - int | 2:4 | marlin_24 | Dense | + | W8A16 - float | 2:4 | naive_quantized | Dense | + +---------------+----------+----------------------+---------------------+ :param model: model to check for quantization, if the model is not quantized no quantization format is returned @@ -37,7 +57,7 @@ def infer_quantization_format( if save_compressed: weight_args, input_args = _get_unique_quant_args(model) is_24_structure = ( - sparsity_config and sparsity_config.sparsity_structure == "2:4" + SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR ) is_weight_only = len(input_args) == 0 and len(weight_args) > 0 diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py index d6ed9f7e7..1183023b3 100644 --- a/src/llmcompressor/transformers/compression/sparsity_config.py +++ b/src/llmcompressor/transformers/compression/sparsity_config.py @@ -1,7 +1,14 @@ -from typing import Dict, Optional +from typing import Dict, List, Optional from compressed_tensors import CompressionFormat, SparsityCompressionConfig -from compressed_tensors.quantization.utils import is_model_quantized +from compressed_tensors.config import SparsityStructure +from compressed_tensors.quantization import QuantizationType +from compressed_tensors.quantization.utils import ( + is_model_quantized, + is_module_quantized, + iter_named_leaf_modules, +) +from loguru import logger from torch import Tensor from torch.nn import Module @@ -20,7 +27,7 @@ class SparsityConfigMetadata: metadata from the model """ - SPARSITY_THRESHOLD: float = 0.4 + SPARSITY_THRESHOLD: float = 0.5 @staticmethod def infer_global_sparsity( @@ -67,13 +74,15 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str: if model and sparsity_structure is None: sparsity_structure = infer_sparsity_structure_from_model(model) - return sparsity_structure or "unstructured" + return SparsityStructure(sparsity_structure).value @staticmethod def from_pretrained( model: Module, state_dict: Optional[Dict[str, Tensor]] = None, compress: bool = False, + quantization_format: Optional[CompressionFormat] = None, + disable_sparse_compression: bool = False, ) -> Optional["SparsityCompressionConfig"]: """ Determines compression type and informational parameters for a given model @@ -82,6 +91,11 @@ def from_pretrained( :param state_dict: optional state_dict to replace that in model, used for gathering global FSDP model info :param compress: whether or not to compress the model on disk + :param quantization_format: the quantization compression format being used + for the model + :param disable_sparse_compression: whether or not to compress the model with + sparse compressors, If True, the sparse compression 
format will + be dense; default is False. :return: compression config inferred from the model """ @@ -95,11 +109,18 @@ def from_pretrained( sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure( model=model ) - if is_model_quantized(model): - # compressing a sparse quantized model is not supported yet + if ( + disable_sparse_compression + or quantization_format == CompressionFormat.marlin_24 + ): + # sparse compressor should be dense + # when disable_sparse_compression is True + # or when marlin_24 is used format = CompressionFormat.dense.value - elif compress: - format = CompressionFormat.sparse_bitmask.value + elif compress and SparsityConfigMetadata.is_sparse24_bitmask_supported( + model, sparsity_structure + ): + format = CompressionFormat.sparse_24_bitmask.value else: format = CompressionFormat.dense.value @@ -135,3 +156,68 @@ def fill_config_details( model, state_dict=state_dict ) config.sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure() + + @staticmethod + def is_sparse24_bitmask_supported( + model: Module, + sparsity_structure: Optional[str] = None, + ) -> bool: + """ + Determines if the sparse 24 bitmask compressor is supported for a given model + and its sparsity structure in vLLM + + :param model: pytorch model to check for sparse 24 bitmask support + :param sparsity_structure: sparsity structure of the model, if + not supplied it will be inferred + :return: whether or not sparse 24 bitmask compression is supported + in vLLM for the given model + """ + + if sparsity_structure is None: + sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model) + + if sparsity_structure != SparsityStructure.TWO_FOUR.value: + # only supported for 2:4 sparsity + return False + + if not is_model_quantized(model): + # non-quantized 2:4 sparse models are supported + return True + + # the model is quantized and has 2:4 sparsity + + supported_scheme_types: List[str] = [ + QuantizationType.INT.value, + QuantizationType.FLOAT.value, + ] + + for _, submodule in iter_named_leaf_modules(model): + if is_module_quantized(submodule): + weight_scheme = submodule.quantization_scheme.weights + input_scheme = submodule.quantization_scheme.input_activations + + if weight_scheme and input_scheme: + # weight and activation quantization + # check schemes are supported + for scheme in [weight_scheme, input_scheme]: + scheme_supported = ( + scheme.num_bits == 8 + and scheme.type in supported_scheme_types + ) + if not scheme_supported: + logger.info( + "Quantization scheme not supported," + " turning off sparse 24 compression." + f" Invalid Scheme: {scheme}" + ) + return False + + elif weight_scheme or input_scheme: + # weight only quantization + logger.info( + "Weight only quantization detected, " + "turning off sparse 24 compression."
+ ) + return False + + return True diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py index 4c1e798b2..ec9951f6a 100644 --- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py @@ -8,6 +8,7 @@ import transformers from accelerate.accelerator import get_state_dict_offloaded_model from compressed_tensors import ( + CompressionFormat, ModelCompressor, SparsityCompressionConfig, is_module_offloaded, @@ -124,6 +125,7 @@ def save_pretrained_wrapper( quantization_format: Optional[str] = None, save_compressed: bool = True, skip_compression_stats: bool = False, + disable_sparse_compression: bool = False, **kwargs, ): """ @@ -133,13 +135,15 @@ def save_pretrained_wrapper( :param save_directory: output directory to save model to :param sparsity_config: optional sparsity config to compress model with, - if no config is provided it will be inferred from the model + if no config is provided it will be inferred from the model :param quantization_format: optional compression format for quantized - models. If none is provided it will be inferred from the model + models. If none is provided it will be inferred from the model :param save_compressed: whether or not to compress the model on disk :param skip_compression_stats: whether to skip the calculation of - compression statistics (such as global sparsity and sparsity structure) when - saving a model in dense format + compression statistics (such as global sparsity and sparsity structure) + when saving a model in dense format + :param disable_sparse_compression: whether to skip sparse compression + during save, default is False :param kwargs: additional kwargs to pass on to model.save_pretrained """ @@ -169,6 +173,7 @@ def skip(*args, **kwargs): save_compressed=save_compressed, skip_compression_stats=skip_compression_stats, state_dict=state_dict, + disable_sparse_compression=disable_sparse_compression, ) if compressor is None: @@ -260,6 +265,7 @@ def get_model_compressor( save_compressed: bool = True, skip_compression_stats: bool = False, state_dict: Optional[Dict] = None, + disable_sparse_compression: bool = False, ): """ Obtain the compressor based on the config and the @@ -273,19 +279,26 @@ def get_model_compressor( format :param skip_compression_stats: bool allowing compression stats on std out :param state_dict: state_dict of the model + :param disable_sparse_compression: bool to skip sparse compression """ # find offloaded state dict if none is provided if state_dict is None: state_dict = get_state_dict_offloaded_model(model) + sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model) + quantization_format: Optional[CompressionFormat] = infer_quantization_format( + model=model, + quantization_format=quantization_format, + save_compressed=save_compressed, + sparsity_structure=sparsity_structure, + ) + if sparsity_config is not None: sparsity_config.global_sparsity = SparsityConfigMetadata.infer_global_sparsity( model, state_dict=state_dict ) - sparsity_config.sparsity_structure = ( - SparsityConfigMetadata.infer_sparsity_structure() - ) + sparsity_config.sparsity_structure = sparsity_structure elif not skip_compression_stats: # try to infer a sparsity config from the model if none is provided logger.info( "Inferring a sparsity configuration requires a global sparsity " "calculation. This can be costly for large models. To skip the " "calculation of compression statistics set " "skip_compression_stats=True" ) sparsity_config =
SparsityConfigMetadata.from_pretrained( - model, state_dict=state_dict, compress=save_compressed + model, + state_dict=state_dict, + compress=save_compressed, + quantization_format=quantization_format, + disable_sparse_compression=disable_sparse_compression, ) - quantization_format = infer_quantization_format( - model=model, - quantization_format=quantization_format, - save_compressed=save_compressed, - sparsity_config=sparsity_config, - ) return ModelCompressor.from_pretrained_model( model, sparsity_config=sparsity_config, diff --git a/tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml b/tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml new file mode 100644 index 000000000..79d7184b5 --- /dev/null +++ b/tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml @@ -0,0 +1,7 @@ +pruning_stage: + obcq_modifiers: + SparseGPTModifier: + sparsity: 0.5 + sequential_update: true + mask_structure: "2:4" + targets: ['re:model.layers.\d*$'] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml b/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml new file mode 100644 index 000000000..5f423a111 --- /dev/null +++ b/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml @@ -0,0 +1,38 @@ +pruning_stage: + obcq_modifiers: + SparseGPTModifier: + sparsity: 0.5 + sequential_update: true + mask_structure: "2:4" + targets: ['re:model.layers.\d*$'] +quant_stage: + quant_modifiers: + QuantizationModifier: + ignore: ["lm_head"] + config_groups: + group_0: + weights: + num_bits: 8 + type: float + strategy: channel + dynamic: false + symmetric: true + input_activations: + num_bits: 8 + type: float + strategy: token + dynamic: true + symmetric: true + targets: ["Linear"] + pruning_modifiers: + ConstantPruningModifier: + targets: [ + 're:.*q_proj.weight', + 're:.*k_proj.weight', + 're:.*v_proj.weight', + 're:.*o_proj.weight', + 're:.*gate_proj.weight', + 're:.*up_proj.weight', + 're:.*down_proj.weight', + ] + start: 0 \ No newline at end of file diff --git a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py index 7db2f0687..1eb3bf202 100644 --- a/tests/llmcompressor/transformers/compression/test_infer_quant_format.py +++ b/tests/llmcompressor/transformers/compression/test_infer_quant_format.py @@ -1,5 +1,4 @@ import pytest -from compressed_tensors.config import SparsityCompressionConfig from compressed_tensors.quantization import preset_name_to_scheme from llmcompressor.transformers.compression.quantization_format import ( @@ -20,9 +19,6 @@ ], ) def test_infer_quant_format(preset, sparsity_structure, expected_format): - sparsity_config = SparsityCompressionConfig( - format="dense", sparsity_structure=sparsity_structure - ) quant_scheme = preset_name_to_scheme(preset, targets=["Linear"]) dummy_model = LinearNet() @@ -30,6 +26,6 @@ def test_infer_quant_format(preset, sparsity_structure, expected_format): module.quantization_scheme = quant_scheme inferred_format = infer_quantization_format( - dummy_model, save_compressed=True, sparsity_config=sparsity_config + dummy_model, save_compressed=True, sparsity_structure=sparsity_structure ) assert inferred_format.value == expected_format diff --git a/tests/llmcompressor/transformers/compression/test_sparsity_config.py b/tests/llmcompressor/transformers/compression/test_sparsity_config.py new file mode 100644 index 000000000..2c4ed0ca1 --- 
/dev/null +++ b/tests/llmcompressor/transformers/compression/test_sparsity_config.py @@ -0,0 +1,122 @@ +from unittest.mock import Mock, patch + +import pytest +from torch.nn import Module + +from llmcompressor.transformers.compression.sparsity_config import ( + SparsityConfigMetadata, +) + +SPARSITY_CONFIG_LOCATION = "llmcompressor.transformers.compression.sparsity_config" + + +# Mock classes and functions +class MockSparsityStructure: + TWO_FOUR = Mock(value="2:4") + + +class MockQuantizationType: + INT = Mock(value="int") + FLOAT = Mock(value="float") + + +class MockSparsityConfigMetadata: + @staticmethod + def infer_sparsity_structure(model): + return model.sparsity_structure + + +def mock_is_model_quantized(model): + return model.is_quantized + + +def mock_iter_named_leaf_modules(model): + for name, module in model.named_modules(): + yield name, module + + +# Mock model class +class MockModel(Module): + def __init__( + self, sparsity_structure=None, is_quantized=False, quantization_scheme=None + ): + super().__init__() + self.sparsity_structure = sparsity_structure + self.is_quantized = is_quantized + self.quantization_scheme = quantization_scheme + + def named_modules(self): + yield "mock_submodule", self + + +# Fixtures +@pytest.fixture +def models(): + return { + "non_sparse": MockModel(sparsity_structure=None), + "non_24_sparse": MockModel(sparsity_structure="unstructured"), + "non_quantized_24_sparse": MockModel( + sparsity_structure=MockSparsityStructure.TWO_FOUR.value, is_quantized=False + ), + "quantized_24_sparse_supported": MockModel( + sparsity_structure=MockSparsityStructure.TWO_FOUR.value, + is_quantized=True, + # W8A8 + quantization_scheme=Mock( + weights=Mock(num_bits=8, type=MockQuantizationType.FLOAT.value), + input_activations=Mock( + num_bits=8, type=MockQuantizationType.FLOAT.value + ), + ), + ), + "quantized_24_sparse_unsupported": MockModel( + sparsity_structure=MockSparsityStructure.TWO_FOUR.value, + is_quantized=True, + # W4A8 + quantization_scheme=Mock( + weights=Mock(num_bits=4, type=MockQuantizationType.INT.value), + input_activations=Mock( + num_bits=8, type=MockQuantizationType.FLOAT.value + ), + ), + ), + } + + +@pytest.mark.usefixtures("models") +class TestSparse24BitmaskSupport: + @pytest.fixture(autouse=True) + def setup_mocks(self, request): + patchers = [ + patch( + f"{SPARSITY_CONFIG_LOCATION}" + ".SparsityConfigMetadata.infer_sparsity_structure", + side_effect=MockSparsityConfigMetadata.infer_sparsity_structure, + ), + patch( + f"{SPARSITY_CONFIG_LOCATION}.is_model_quantized", + side_effect=mock_is_model_quantized, + ), + patch( + f"{SPARSITY_CONFIG_LOCATION}.iter_named_leaf_modules", + side_effect=mock_iter_named_leaf_modules, + ), + ] + for patcher in patchers: + patcher.start() + request.addfinalizer(patcher.stop) # for cleanup + + @pytest.mark.parametrize( + "model_key, expected", + [ + ("non_sparse", False), + ("non_24_sparse", False), + ("non_quantized_24_sparse", True), + ("quantized_24_sparse_supported", True), + ("quantized_24_sparse_unsupported", False), + ], + ) + def test_sparse24_bitmask_support(self, models, model_key, expected): + model = models[model_key] + result = SparsityConfigMetadata.is_sparse24_bitmask_supported(model) + assert result == expected diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index 3fddd254a..92e600de9 100644 --- 
a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -6,11 +6,16 @@ import torch from accelerate import cpu_offload from accelerate.accelerator import get_state_dict_offloaded_model -from compressed_tensors import QUANTIZATION_CONFIG_NAME +from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat from compressed_tensors.compressors import ModelCompressor from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig -from compressed_tensors.quantization import QuantizationStatus +from compressed_tensors.quantization import ( + QuantizationConfig, + QuantizationStatus, + quantize, +) from compressed_tensors.utils import get_offloaded_device, update_prefix_dict +from torch import nn from transformers import AutoConfig, AutoModelForCausalLM from transformers.utils.quantization_config import CompressedTensorsConfig @@ -21,6 +26,7 @@ SparsityConfigMetadata, ) from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( + get_model_compressor, modify_save_pretrained, patch_tied_tensors_bug, ) @@ -356,3 +362,342 @@ def test_model_shared_tensors_gpu( test_model_shared_tensors( offload, torch_dtype, tie_word_embeddings, device_map, tmp_path ) + + +@pytest.mark.parametrize( + "model_stub, recipe, sparse_format, quant_format", + [ + ( + "Xenova/llama2.c-stories15M", + "tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml", + CompressionFormat.sparse_24_bitmask.value, + CompressionFormat.float_quantized.value, + ), + ], +) +def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path): + from llmcompressor.pytorch.model_load.helpers import get_session_model + + device = "cuda" + if not torch.cuda.is_available(): + device = "cpu" + dataset = "open_platypus" + concatenate_data = False + num_calibration_samples = 64 + splits = {"calibration": "train[:10%]"} + empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto") + + oneshot( + model=model_stub, + dataset=dataset, + num_calibration_samples=num_calibration_samples, + recipe=recipe, + concatenate_data=concatenate_data, + splits=splits, + oneshot_device=device, + clear_sparse_session=False, + ) + + # Fetch the oneshot model + model = get_session_model() + og_state_dict = model.state_dict() + path = tmp_path / "compressed" + + # Compress and save + model.save_pretrained( + path, + quantization_format=quant_format, + save_compressed=True, + ) + + # Verify config on disk + config = AutoConfig.from_pretrained(path) + compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) + quant_config = ModelCompressor.parse_quantization_config(compression_config) + + # As HFQuantizer doesn't decompress the model, use the compressor to decompress + # the model instead + compressor = ModelCompressor.from_compression_config(compression_config) + + assert ( + compressor.sparsity_compressor is not None + ), "Sparse compressor not initialized" + assert compressor.sparsity_config.format == sparse_format + + assert ( + compressor.quantization_compressor is not None + ), "Quantization compressor not initialized" + assert quant_config["format"] == quant_format + + compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN + compressor.decompress(model_path=path, model=empty_model) + + # Verify the abs difference between the decompressed model + # and the original model + reconstructed_state_dict = empty_model.state_dict() 
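+ # every tensor should round-trip through decompression; quantized weights only need to match within a small tolerance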
+ assert len(og_state_dict) == len(reconstructed_state_dict) + for key in og_state_dict.keys(): + dense_tensor = og_state_dict[key].to(device) + reconstructed_tensor = reconstructed_state_dict[key].to(device) + assert dense_tensor.dtype == reconstructed_tensor.dtype + if key.endswith("weight") and quant_format != "dense": + # we don't expect an exact match for compressed + diff = torch.abs(dense_tensor - reconstructed_tensor) + # max diff value found empirically + assert not torch.any(diff > 0.022), f"Max diff: {torch.max(diff)}" + else: + assert torch.equal(dense_tensor, reconstructed_tensor) + shutil.rmtree(tmp_path) + + +@pytest.mark.parametrize( + "model_stub, recipe, sparse_format", + [ + ( + "Xenova/llama2.c-stories15M", + "tests/llmcompressor/transformers/compression/recipes/sparse_24.yaml", + CompressionFormat.sparse_24_bitmask.value, + ), + ], +) +def test_sparse_24_compressor_is_lossless(model_stub, recipe, sparse_format, tmp_path): + from llmcompressor.pytorch.model_load.helpers import get_session_model + + device = "cuda" + if not torch.cuda.is_available(): + device = "cpu" + dataset = "open_platypus" + concatenate_data = False + num_calibration_samples = 64 + splits = {"calibration": "train[:10%]"} + empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto") + + oneshot( + model=model_stub, + dataset=dataset, + num_calibration_samples=num_calibration_samples, + recipe=recipe, + concatenate_data=concatenate_data, + splits=splits, + oneshot_device=device, + clear_sparse_session=False, + ) + + # Fetch the oneshot model + model = get_session_model() + og_state_dict = model.state_dict() + path = tmp_path / "compressed" + + # Compress and save + model.save_pretrained( + path, + save_compressed=True, + ) + + # Verify config on disk + config = AutoConfig.from_pretrained(path) + compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) + + # As HFQuantizer doesn't decompress the model, use the compressor to decompress + # the model instead + compressor = ModelCompressor.from_compression_config(compression_config) + + assert ( + compressor.sparsity_compressor is not None + ), "Sparse compressor not initialized" + assert compressor.sparsity_config.format == sparse_format + + compressor.decompress(model_path=path, model=empty_model) + + # Verify the abs difference between the decompressed model + # and the original model + reconstructed_state_dict = empty_model.state_dict() + assert len(og_state_dict) == len(reconstructed_state_dict) + for key in og_state_dict.keys(): + dense_tensor = og_state_dict[key].to(device) + reconstructed_tensor = reconstructed_state_dict[key].to(device) + assert dense_tensor.dtype == reconstructed_tensor.dtype + if key.endswith("weight"): + assert torch.equal(dense_tensor, reconstructed_tensor) + shutil.rmtree(tmp_path) + + +def test_no_sparse_compression_flag(tmp_path): + two_four_sparse_model_id = "nm-testing/llama2.c-stories42M-pruned2.4" + two_four_sparse_model = AutoModelForCausalLM.from_pretrained( + two_four_sparse_model_id, torch_dtype="auto" + ) + modify_save_pretrained(two_four_sparse_model) + + save_path = tmp_path / "no_sparse_compression_model" + two_four_sparse_model.save_pretrained(save_path, disable_sparse_compression=True) + + config = AutoConfig.from_pretrained(save_path) + quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None) + + assert quantization_config + sparsity_config = quantization_config.get("sparsity_config") + + assert sparsity_config + assert sparsity_config["format"] == "dense" +
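# the flag only changes the on-disk format; the saved weights themselves remain 2:4 sparse +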
shutil.rmtree(tmp_path) + + +class DummyLinearModel(nn.Module): + """ + A dummy linear model for testing purposes, simulating a quantized linear layer. + """ + + def __init__(self, weights, weight_scale=None, weight_zero_point=None): + super().__init__() + out_features, in_features = weights.shape + + # Linear layer without bias + self.linear = nn.Linear(in_features, out_features, bias=False) + self.linear.weight = nn.Parameter(weights, requires_grad=False) + + # Attach scale and zero-point if provided + if weight_scale is not None: + self.linear.weight_scale = nn.Parameter( + torch.tensor(weight_scale), requires_grad=False + ) + if weight_zero_point is not None: + self.linear.weight_zero_point = nn.Parameter( + torch.tensor(weight_zero_point), requires_grad=False + ) + + def forward(self, x): + return self.linear(x) + + +def _create_quantization_config( + w_bits=8, + w_type="int", + w_strategy="tensor", + quantize_activations=False, + a_bits=8, + a_type="int", + a_strategy="tensor", +): + """ + Create a quantization configuration for testing. + """ + config_dict = { + "global_compression_ratio": 1.0, + "quant_method": "compressed-tensors", + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": w_bits, + "strategy": w_strategy, + "symmetric": True, + "type": w_type, + }, + } + }, + } + + if quantize_activations: + config_dict["config_groups"]["group_0"]["input_activations"] = { + "num_bits": a_bits, + "strategy": a_strategy, + "symmetric": True, + "type": a_type, + } + + return QuantizationConfig.model_validate(config_dict) + + +def _quantization_config_from_string(config_str, q_type): + """ + Parse quantization config from string and type. + """ + w_bits = int(config_str[1]) + a_bits = int(config_str[3:]) + quantize_activations = a_bits < 16 + + return _create_quantization_config( + w_bits=w_bits, + w_type=q_type, + w_strategy="channel", + quantize_activations=quantize_activations, + a_bits=a_bits, + a_type=q_type, + a_strategy="channel", + ) + + +def _make_24_sparse(tensor): + """ + Apply 2:4 sparsity pattern to the given tensor. + """ + reshaped_tensor = tensor.view(tensor.size(0), -1, 4) + mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool) + mask[..., :2] = True + sparsified_tensor = torch.where( + mask, reshaped_tensor, torch.tensor(0.0, dtype=tensor.dtype) + ) + return sparsified_tensor.view_as(tensor) + + +@pytest.mark.parametrize( + "quant_style, quant_type, is_24, expected_quant_compressor, " + "expected_sparsity_compressor", + [ + ("W8A8", "int", False, "int-quantized", "dense"), + ("W4A16", "int", False, "pack-quantized", "dense"), + ("W8A16", "int", False, "pack-quantized", "dense"), + ("W8A8", "int", True, "int-quantized", "sparse-24-bitmask"), + ("W4A16", "int", True, "marlin-24", "dense"), + ("W8A16", "int", True, "marlin-24", "dense"), + ("W8A8", "float", False, "float-quantized", "dense"), + ("W8A16", "float", False, "naive-quantized", "dense"), + ("W8A8", "float", True, "float-quantized", "sparse-24-bitmask"), + ("W8A16", "float", True, "naive-quantized", "dense"), + ], +) +def test_correct_compressor_inferred( + quant_style, + quant_type, + is_24, + expected_quant_compressor, + expected_sparsity_compressor, +): + """ + Test if the correct compressor is inferred based on + quantization and sparsity configurations. 
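+ The expected compressor pairs mirror the table in infer_quantization_format's docstring.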
+ """ + weights = torch.rand(10, 4) + if is_24: + weights = _make_24_sparse(weights) + + quantization_config = _quantization_config_from_string(quant_style, quant_type) + quantization_args = quantization_config.config_groups["group_0"].weights + + scale = ( + torch.ones((weights.shape[0], 1)) + if quantization_args.strategy == "channel" + else torch.tensor([1.0]) + ) + zero_point = torch.zeros_like(scale) + + quantized_weights = quantize( + weights, scale=scale, zero_point=zero_point, args=quantization_args + ) + + model = DummyLinearModel(quantized_weights, scale, zero_point) + model.linear.quantization_scheme = quantization_config.config_groups["group_0"] + model.linear.quantization_status = QuantizationStatus.FROZEN + + compressor = get_model_compressor(model) + + assert compressor.quantization_config.format == expected_quant_compressor + + if expected_sparsity_compressor == "dense": + assert ( + compressor.sparsity_config is None + or compressor.sparsity_config.format == expected_sparsity_compressor + ) + else: + assert compressor.sparsity_config.format == expected_sparsity_compressor