From 2a5a71a2b4e14b923d7b0889b52aec8d7d607260 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Thu, 12 Dec 2024 21:55:46 +0000
Subject: [PATCH] Enable sparse24bytemask compressor

When compression is requested, select the 2:4-specific bytemask
compressor only if the inferred sparsity structure is 2:4; every other
structure keeps using the generic sparse bitmask compressor.

The inferred structure is wrapped in SparsityStructure so it can be
compared against the enum member directly.

---
 .../transformers/compression/sparsity_config.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/transformers/compression/sparsity_config.py b/src/llmcompressor/transformers/compression/sparsity_config.py
index 769e485d4..b15dc87e1 100644
--- a/src/llmcompressor/transformers/compression/sparsity_config.py
+++ b/src/llmcompressor/transformers/compression/sparsity_config.py
@@ -94,15 +94,19 @@ def from_pretrained(
         if global_sparsity < 0.05:
             return None
 
-        sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
-            model=model
+        sparsity_structure = SparsityStructure(
+            SparsityConfigMetadata.infer_sparsity_structure(model=model)
         )
         if is_marlin:
             # sparse compressor should be dense for marlin
             # compression
             format = CompressionFormat.dense.value
         if compress:
-            format = CompressionFormat.sparse_bitmask.value
+            format = (
+                CompressionFormat.sparse_24_bytemask.value
+                if sparsity_structure == SparsityStructure.TWO_FOUR
+                else CompressionFormat.sparse_bitmask.value
+            )
         else:
             format = CompressionFormat.dense.value
 