diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py
index e6f14a6c7..12e5c9414 100644
--- a/src/llmcompressor/transformers/compression/quantization_format.py
+++ b/src/llmcompressor/transformers/compression/quantization_format.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
 from compressed_tensors import CompressionFormat
-from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
 from compressed_tensors.quantization.utils import (
     is_model_quantized,
@@ -16,7 +16,7 @@ def infer_quantization_format(
     model,
     quantization_format: Optional[str] = None,
     save_compressed: bool = False,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
+    sparsity_structure: Optional[str] = None,
 ) -> str:
     """
     Infers a quantization format based on model state and compression args
@@ -37,7 +37,7 @@ def infer_quantization_format(
     if save_compressed:
         weight_args, input_args = _get_unique_quant_args(model)
         is_24_structure = (
-            sparsity_config and sparsity_config.sparsity_structure == "2:4"
+            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
         )
         is_weight_only = len(input_args) == 0 and len(weight_args) > 0
 
diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py
index d6ed9f7e7..cd0021d99 100644
--- a/src/llmcompressor/transformers/compression/sparsity_config.py
+++ b/src/llmcompressor/transformers/compression/sparsity_config.py
@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
-from compressed_tensors.quantization.utils import is_model_quantized
+from compressed_tensors.config import SparsityStructure
 from torch import Tensor
 from torch.nn import Module
 
@@ -20,7 +20,7 @@ class SparsityConfigMetadata:
     metadata from the model
     """
 
-    SPARSITY_THRESHOLD: float = 0.4
+    SPARSITY_THRESHOLD: float = 0.5
 
     @staticmethod
     def infer_global_sparsity(
@@ -67,13 +67,14 @@ def infer_sparsity_structure(model: Optional[Module] = None) -> str:
         if model and sparsity_structure is None:
             sparsity_structure = infer_sparsity_structure_from_model(model)
 
-        return sparsity_structure or "unstructured"
+        return SparsityStructure(sparsity_structure).value
 
     @staticmethod
     def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
+        is_marlin: bool = False,
     ) -> Optional["SparsityCompressionConfig"]:
         """
         Determines compression type and informational parameters for a given model
@@ -82,6 +83,7 @@ def from_pretrained(
         :param state_dict: optional state_dict to replace that in model, used for
             gathering global FSDP model info
         :param compress: whether or not to compress the model on disk
+        :param is_marlin: whether or not marlin compression is being used
         :return: compression config inferred from the model
         """
 
@@ -95,11 +97,17 @@ def from_pretrained(
         sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
             model=model
         )
-        if is_model_quantized(model):
-            # compressing a sparse quantized model is not supported yet
+        if is_marlin:
+            # sparse compressor should be dense for marlin
+            # compression
             format = CompressionFormat.dense.value
         elif compress:
-            format = CompressionFormat.sparse_bitmask.value
+            format = (
+                CompressionFormat.sparse_24_bitmask.value
+                if sparsity_structure == SparsityStructure.TWO_FOUR.value
+                else CompressionFormat.sparse_bitmask.value
+            )
+
         else:
             format = CompressionFormat.dense.value
 
diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index 6de89dd8b..bef8227ef 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -8,6 +8,7 @@
 import transformers
 from accelerate.accelerator import get_state_dict_offloaded_model
 from compressed_tensors import (
+    CompressionFormat,
     ModelCompressor,
     SparsityCompressionConfig,
     is_module_offloaded,
@@ -272,13 +273,20 @@ def get_model_compressor(
     if state_dict is None:
         state_dict = get_state_dict_offloaded_model(model)
 
+    sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(model)
+    quantization_format = infer_quantization_format(
+        model=model,
+        quantization_format=quantization_format,
+        save_compressed=save_compressed,
+        sparsity_structure=sparsity_structure,
+    )
+    is_marlin = quantization_format == CompressionFormat.marlin_24.value
+
     if sparsity_config is not None:
         sparsity_config.global_sparsity = SparsityConfigMetadata.infer_global_sparsity(
             model, state_dict=state_dict
         )
-        sparsity_config.sparsity_structure = (
-            SparsityConfigMetadata.infer_sparsity_structure()
-        )
+        sparsity_config.sparsity_structure = sparsity_structure
     elif not skip_compression_stats:
         # try to infer a sparsity config from the model if none is provided
         logger.info(
@@ -288,15 +296,12 @@ def get_model_compressor(
             "skip_compression_stats=True"
         )
         sparsity_config = SparsityConfigMetadata.from_pretrained(
-            model, state_dict=state_dict, compress=save_compressed
+            model,
+            state_dict=state_dict,
+            compress=save_compressed,
+            is_marlin=is_marlin,
         )
 
-    quantization_format = infer_quantization_format(
-        model=model,
-        quantization_format=quantization_format,
-        save_compressed=save_compressed,
-        sparsity_config=sparsity_config,
-    )
     return ModelCompressor.from_pretrained_model(
         model,
         sparsity_config=sparsity_config,
diff --git a/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml b/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml
new file mode 100644
index 000000000..6f8ba86be
--- /dev/null
+++ b/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml
@@ -0,0 +1,38 @@
+pruning_stage:
+  obcq_modifiers:
+    SparseGPTModifier:
+      sparsity: 0.5
+      sequential_update: true
+      mask_structure: "2:4"
+      targets: ['re:model.layers.\d*$']
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      ignore: ["lm_head"]
+      config_groups:
+        group_0:
+          weights:
+            num_bits: 8
+            type: float
+            strategy: tensor
+            dynamic: false
+            symmetric: true
+          input_activations:
+            num_bits: 8
+            type: float
+            strategy: tensor
+            dynamic: true
+            symmetric: true
+          targets: ["Linear"]
+  pruning_modifiers:
+    ConstantPruningModifier:
+      targets: [
+        're:.*q_proj.weight',
+        're:.*k_proj.weight',
+        're:.*v_proj.weight',
+        're:.*o_proj.weight',
+        're:.*gate_proj.weight',
+        're:.*up_proj.weight',
+        're:.*down_proj.weight',
+      ]
+      start: 0
\ No newline at end of file
diff --git a/tests/llmcompressor/transformers/compression/recipes/sparse_24_int8.yaml b/tests/llmcompressor/transformers/compression/recipes/sparse_24_int8.yaml
new file mode 100644
index 000000000..73279db2c
--- /dev/null
+++ b/tests/llmcompressor/transformers/compression/recipes/sparse_24_int8.yaml
@@ -0,0 +1,38 @@
+pruning_stage:
+  obcq_modifiers:
+    SparseGPTModifier:
+      sparsity: 0.5
+      sequential_update: true
+      mask_structure: "2:4"
+      targets: ['re:model.layers.\d*$']
+quant_stage:
+  quant_modifiers:
+    QuantizationModifier:
+      ignore: ["lm_head"]
+      config_groups:
+        group_0:
+          weights:
+            num_bits: 8
+            type: int
+            strategy: tensor
+            dynamic: false
+            symmetric: true
+          input_activations:
+            num_bits: 8
+            type: int
+            strategy: tensor
+            dynamic: true
+            symmetric: true
+          targets: ["Linear"]
+  pruning_modifiers:
+    ConstantPruningModifier:
+      targets: [
+        're:.*q_proj.weight',
+        're:.*k_proj.weight',
+        're:.*v_proj.weight',
+        're:.*o_proj.weight',
+        're:.*gate_proj.weight',
+        're:.*up_proj.weight',
+        're:.*down_proj.weight',
+      ]
+      start: 0
\ No newline at end of file
diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
index df9726647..42c965b7a 100644
--- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
+++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py
@@ -6,7 +6,7 @@
 import torch
 from accelerate import cpu_offload
 from accelerate.accelerator import get_state_dict_offloaded_model
-from compressed_tensors import QUANTIZATION_CONFIG_NAME
+from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig
 from compressed_tensors.quantization import QuantizationStatus
@@ -364,3 +364,111 @@ def test_model_shared_tensors_gpu(
     test_model_shared_tensors(
         offload, torch_dtype, tie_word_embeddings, device_map, tmp_path
     )
+
+
+@pytest.mark.parametrize(
+    "model_stub, recipe, sparse_format, quant_format",
+    [
+        (
+            "Xenova/llama2.c-stories110M",
+            "tests/llmcompressor/transformers/compression/recipes/sparse_24_int8.yaml",
+            CompressionFormat.sparse_24_bitmask.value,
+            CompressionFormat.int_quantized.value,
+        ),
+    ],
+)
+def test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path):
+    from llmcompressor.pytorch.model_load.helpers import get_session_model
+
+    device = "cuda"
+    if not torch.cuda.is_available():
+        device = "cpu"
+    dataset = "open_platypus"
+    concatenate_data = False
+    num_calibration_samples = 64
+    splits = {"calibration": "train[:10%]"}
+    empty_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype="auto")
+
+    oneshot(
+        model=model_stub,
+        dataset=dataset,
+        num_calibration_samples=num_calibration_samples,
+        recipe=recipe,
+        concatenate_data=concatenate_data,
+        splits=splits,
+        oneshot_device=device,
+        clear_sparse_session=False,
+    )
+
+    # Fetch the oneshot model
+    model = get_session_model()
+    og_state_dict = model.state_dict()
+    path = tmp_path / "compressed"
+
+    # Compress and save
+    model.save_pretrained(
+        path,
+        quantization_format=quant_format,
+        save_compressed=True,
+    )
+
+    # Verify config on disk
+    config = AutoConfig.from_pretrained(path)
+    compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+    quant_config = ModelCompressor.parse_quantization_config(compression_config)
+
+    # As HFQuantizer doesn't decompress the model, use the compressor to decompress
+    # the model instead
+    compressor = ModelCompressor.from_compression_config(compression_config)
+
+    assert (
+        compressor.sparsity_compressor is not None
+    ), "Sparse compressor not initialized"
+    assert compressor.sparsity_config.format == sparse_format
+
+    assert (
+        compressor.quantization_compressor is not None
+    ), "Quantization compressor not initialized"
+    assert quant_config["format"] == quant_format
+
+    compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
+    compressor.decompress(model_path=path, model=empty_model)
+
+    # Verify the abs difference between the decompressed model
+    # and the original model
+    reconstructed_state_dict = empty_model.state_dict()
+    assert len(og_state_dict) == len(reconstructed_state_dict)
+    for key in og_state_dict.keys():
+        dense_tensor = og_state_dict[key].to(device)
+        reconstructed_tensor = reconstructed_state_dict[key].to(device)
+        assert dense_tensor.dtype == reconstructed_tensor.dtype
+        if key.endswith("weight") and quant_format != "dense":
+            # we don't expect an exact match for compressed
+            diff = torch.abs(dense_tensor - reconstructed_tensor)
+            assert not torch.any(
+                diff > 0.01
+            ).item(), f"{key} has a diff greater than 0.01"
+        else:
+            assert torch.equal(dense_tensor, reconstructed_tensor)
+    shutil.rmtree(tmp_path)
+
+
+# This parameterization should be added to the test_compressor_stacking test
+# once the lossy nature of FP8 compress-decompress is resolved.
+# Until then, this test is marked as xfail.
+@pytest.mark.xfail(reason="Known issue with FP8 compress-decompress")
+@pytest.mark.parametrize(
+    "model_stub, recipe, sparse_format, quant_format",
+    [
+        (
+            "Xenova/llama2.c-stories110M",
+            "tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml",
+            CompressionFormat.sparse_24_bitmask.value,
+            CompressionFormat.float_quantized.value,
+        ),
+    ],
+)
+def test_compressor_stacking_fp8(
+    model_stub, recipe, sparse_format, quant_format, tmp_path
+):
+    test_compressor_stacking(model_stub, recipe, sparse_format, quant_format, tmp_path)
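
For reviewers, a minimal end-to-end sketch (not part of the patch) of the stacked 2:4-sparse plus INT8 flow these changes enable. It mirrors the test_compressor_stacking flow above; the output directory name is an illustrative placeholder, and the oneshot / get_session_model import paths are assumed to match the ones this test module already uses.

    # Sketch only: exercises the updated get_model_compressor() path, which now
    # infers the quantization format first and selects the sparse-24-bitmask
    # sparsity compressor for 2:4-pruned weights.
    from llmcompressor.pytorch.model_load.helpers import get_session_model
    from llmcompressor.transformers import oneshot

    model_stub = "Xenova/llama2.c-stories110M"
    recipe = "tests/llmcompressor/transformers/compression/recipes/sparse_24_int8.yaml"

    # One-shot 2:4 pruning followed by INT8 quantization, per the recipe above
    oneshot(
        model=model_stub,
        dataset="open_platypus",
        recipe=recipe,
        num_calibration_samples=64,
        splits={"calibration": "train[:10%]"},
        clear_sparse_session=False,
    )

    # save_compressed=True stacks the sparsity and quantization compressors in
    # the saved checkpoint; both formats are inferred, so none is passed here.
    model = get_session_model()
    model.save_pretrained("llama2.c-stories110M-2of4-int8", save_compressed=True)  # placeholder path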