diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 939991ab6..f1545c992 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -9,7 +9,7 @@ model = SparseAutoModelForCausalLM.from_pretrained( MODEL_ID, - device_map="auto", + device_map="cuda:0", torch_dtype="auto", ) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) @@ -20,7 +20,7 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 160 #2048 MAX_SEQUENCE_LENGTH = 2048 # Load dataset and preprocess. @@ -55,7 +55,7 @@ def tokenize(sample): # Configure the quantization algorithm to run. # * quantize the weights to 4 bit with GPTQ with a group size 128 -recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]) +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=-1, dampening_frac=0.5) # Apply algorithms. oneshot( diff --git a/examples/quantization_w4a16/vision2_example.py b/examples/quantization_w4a16/vision2_example.py new file mode 100644 index 000000000..1f57bb9f9 --- /dev/null +++ b/examples/quantization_w4a16/vision2_example.py @@ -0,0 +1,83 @@ +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot + +# Select model and load it. +MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" + +model = MllamaForConditionalGeneration.from_pretrained( + MODEL_ID, + device_map="cuda:0", + torch_dtype="auto", +) +processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) + +# Select calibration dataset. +DATASET_ID = "HuggingFaceH4/ultrachat_200k" +DATASET_SPLIT = "train_sft" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 160 #2048 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + + +def preprocess(example): + return { + "text": processor.apply_chat_template( + example["messages"], + tokenize=False, + ) + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return processor( + None, + sample["text"], + padding=False, + max_length=MAX_SEQUENCE_LENGTH, + truncation=True, + add_special_tokens=False, + ) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=1, dampening_frac=0.5) + +# Apply algorithms. +oneshot( + model=model, + tokenizer=MODEL_ID, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(processor.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. 
+SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/examples/quantization_w4a16/vision_example.py b/examples/quantization_w4a16/vision_example.py new file mode 100644 index 000000000..f89ada21a --- /dev/null +++ b/examples/quantization_w4a16/vision_example.py @@ -0,0 +1,88 @@ +from datasets import load_dataset +from transformers import AutoProcessor + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot + +# Select model and load it. +MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" + +model = SparseAutoModelForCausalLM.from_pretrained( + MODEL_ID, + device_map="cuda:0", + torch_dtype="auto", +) +breakpoint() +processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True) + +# Select calibration dataset. +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:165]" + +# Select number of samples. 512 samples is a good place to start. +# Increasing the number of samples can improve accuracy. +NUM_CALIBRATION_SAMPLES = 165 #2048 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + + +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt", max_length=MAX_SEQUENCE_LENGTH) + + +ds = ds.map(tokenize, remove_columns=ds.column_names) + +# Configure the quantization algorithm to run. +# * quantize the weights to 4 bit with GPTQ with a group size 128 +recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"], batch_size=-1, dampening_frac=0.5) + +# Apply algorithms. +oneshot( + model=model, + tokenizer=MODEL_ID, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, +) + +# Confirm generations of the quantized model look sane. +print("\n\n") +print("========== SAMPLE GENERATION ==============") +input_ids = processor("Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=100) +print(processor.decode(output[0])) +print("==========================================\n\n") + +# Save to disk compressed. +SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" +model.save_pretrained(SAVE_DIR, save_compressed=True) +processor.save_pretrained(SAVE_DIR) diff --git a/shubhra.py b/shubhra.py new file mode 100644 index 000000000..4996c8277 --- /dev/null +++ b/shubhra.py @@ -0,0 +1,92 @@ +from datasets import load_dataset +from transformers import AutoProcessor, MllamaForConditionalGeneration, LlavaForConditionalGeneration + +from llmcompressor.modifiers.quantization import GPTQModifier +from llmcompressor.transformers import oneshot, wrap_hf_model_class +import os + +# Load model. 
+#model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" +model_id = "mgoin/pixtral-12b" +model_class = wrap_hf_model_class(LlavaForConditionalGeneration) +model = model_class.from_pretrained(model_id, device_map="auto", torch_dtype="auto", trust_remote_code=True, _attn_implementation="eager",) +processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) + +print("Loading dataset") +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = "test[:128]" + +NUM_CALIBRATION_SAMPLES = 1#128 +MAX_SEQUENCE_LENGTH = 2048 + +# Load dataset and preprocess. +ds = load_dataset(DATASET_ID, split=DATASET_SPLIT) +ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES)) + +print("Preprocessing samples") +def preprocess(example): + messages = [ + [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What does the image show?"} + ] + } + ], + ] + return { + "text": processor.apply_chat_template( + messages, + add_generation_prompt=True, + ), + } + +ds = ds.map(preprocess) + + +# Tokenize inputs. +def tokenize(sample): + return processor(sample["image"], sample["text"], add_special_tokens=False, return_tensors="pt") + + +ds = ds.map(tokenize, remove_columns=ds.column_names) +print(ds) + +print("Setting up quantization params") +# Configure the quantization algorithm and scheme. +# In this case, we: +# * quantize the weights to fp8 with per channel via ptq +# * quantize the activations to fp8 with dynamic per token +#ignore=["re:.*lm_head", "re:model.vision_embed_tokens.*"] +#ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*", "re:language_model.*cross_attn.*"], +ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"] + +recipe = [ + # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore), + GPTQModifier(targets="Linear", scheme="W8A8", ignore=ignore), +] + +save_name = model_id.split("/")[1] + "-W8A8" +save_path = os.path.join("./my_test/", save_name) +print("Starting quantization") +oneshot( + model=model, + tokenizer=model_id, + dataset=ds, + recipe=recipe, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + trust_remote_code_model=True, + output_dir=save_path, +) + +#processor.save_pretrained(save_path) + +# Confirm generations of the quantized model look sane. 
+print("========== SAMPLE GENERATION ==============") +input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda") +output = model.generate(input_ids, max_new_tokens=20) +print(processor.decode(output[0])) +print("==========================================") diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index b6dbda485..74877bf93 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,36 +1,41 @@ import warnings -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch +import math from compressed_tensors.quantization import ( QuantizationScheme, - disable_quantization, - enable_quantization, ) +from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization from loguru import logger -from pydantic import Field, field_validator -from torch.nn import Module +from pydantic import Field, PrivateAttr, field_validator from llmcompressor.core import State from llmcompressor.modifiers import Modifier, ModifierFactory -from llmcompressor.modifiers.quantization.calibration import freeze_module_quantization -from llmcompressor.modifiers.quantization.gptq.utils import ( - GPTQWrapper, - get_output_error, -) -from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor +from llmcompressor.modifiers.quantization.gptq.utils.gptq_quantize import quantize_weight +from llmcompressor.modifiers.quantization.quantization.base import QuantizationModifier +from llmcompressor.modifiers.utils.hooks import LayerCompressorMixin from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.utils.fsdp.context import fix_fsdp_module_name -from llmcompressor.utils.pytorch.module import ( - get_layers, - get_no_split_params, - qat_active, +from llmcompressor.observers.base import Observer +from llmcompressor.transformers.finetune.data.data_helpers import ( + create_batch_dataloader, +) +from llmcompressor.utils.fsdp.helpers import delete_offload_parameter, register_offload_parameter, update_offload_parameter +from llmcompressor.utils.helpers import ( + align_module, + calibration_forward_context, + getattr_chain, +) +from compressed_tensors.quantization import ( + fake_quantize, ) +from llmcompressor.utils.pytorch.module import qat_active + __all__ = ["GPTQModifier"] -class GPTQModifier(Modifier): +class GPTQModifier(Modifier, LayerCompressorMixin): """ Modifier for applying the one-shot OBCQ algorithm to a model @@ -48,6 +53,7 @@ class GPTQModifier(Modifier): | test_stage: | obcq_modifiers: | GPTQModifier: + | true_sequential: False | dampening_frac: 0.001 | block_size: 128 | config_groups: @@ -67,8 +73,8 @@ class GPTQModifier(Modifier): :param sequential_update: Whether or not to update weights sequentially by layer. 
This option is depreciated and setting to False is no longer supported - :param targets: list of layer names to compress during GPTQ, or '__ALL__' - to compress every layer in the model + :param sequential_targets: list of layer names to compress during GPTQ, or + '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass :param quantize: Set to True to quantize using an existing quantization modifier, or pass in the configuration for a quantization modifier if one does not @@ -97,21 +103,22 @@ class GPTQModifier(Modifier): """ sequential_update: bool = True # DEPRECIATED - targets: Union[str, List[str], None] = None + batch_size: int = -1 sequential_targets: Union[str, List[str], None] = None block_size: int = 128 - quantize: Union[bool, Dict] = True dampening_frac: Optional[float] = 0.01 + quantize: Union[bool, Dict] = True + + # arguments used for quant modifier config_groups: Optional[Dict[str, QuantizationScheme]] = None + scheme: Optional[Union[str, Dict[str, Any]]] = None + targets: Union[str, List[str], None] = None ignore: List[str] = Field(default_factory=list) - disable_quantization_observer_epoch: Optional[float] = None num_calibration_steps: Optional[int] = None - scheme: Optional[Union[str, Dict[str, Any]]] = None + disable_quantization_observer_epoch: Optional[float] = None - model: Optional[Any] = None - layer_compressors_: Optional[List[Any]] = None - compressible_layers_: Optional[List] = None - quantization_modifier_: Any = None + _quantization_modifier: Optional[QuantizationModifier] = PrivateAttr() + _num_batches: int = PrivateAttr() @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: @@ -174,8 +181,8 @@ def on_initialize_structure(self, state: State, **kwargs): self._build_quant_modifier_from_dict(self.quantize) self.quantize = True - if self.quantization_modifier_: - self.quantization_modifier_.on_initialize_structure(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.on_initialize_structure(state, **kwargs) def on_initialize(self, state: "State", **kwargs) -> bool: """ @@ -185,25 +192,62 @@ def on_initialize(self, state: "State", **kwargs) -> bool: """ if not self.initialized_structure_: self.on_initialize_structure(state, **kwargs) - if self.quantization_modifier_: - self.quantization_modifier_.initialize(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.initialize(state, **kwargs) if not self.quantize: raise ValueError("To use the GPTQModifier, quantization must be enabled.") - - modifiable_model = state.model - calibration_dataloader = state.data.calib - - if self.sequential_targets is None: - # if no targets are provided, default to the modules that shouldn't be - # split by FSDP. 
For Transformers models this is equivalent to the - # decoder layers (ie LlamaDecoderLayer) - self.sequential_targets = get_no_split_params(modifiable_model) - - self.initialize_compression(modifiable_model, calibration_dataloader) - self.apply_compression(calibration_dataloader) + + if self.batch_size <= 0: + batch_size = len(state.data.calib.dataset) + else: + batch_size = self.batch_size + self._num_batches = math.ceil(len(state.data.calib.dataset) / batch_size) + + self.register_hooks(state.model) + #torch.cuda.memory._record_memory_history(max_entries=1_000_000) + try: + self.calibration_forward(state.model, state.data.calib) + finally: + pass + #torch.cuda.memory._dump_snapshot("bs10.pickle") + #torch.cuda.memory._record_memory_history(enabled=None) + #exit(0) + + self.remove_hooks() + self.finish_compression(state.model) + + # freeze quantization state.model.apply(freeze_module_quantization) return True + + def finish_compression(self, model: torch.nn.Module): + for module in model.modules(): + quant_args = getattr_chain(module, "quantization_scheme.weights", None) + if quant_args is None: + continue + + with align_module(module): + + if self.batch_size != -1: + weight = module.weight_acc / self._num_batches + delete_offload_parameter(module, "weight_acc") + else: + weight = module.weight + + observer = Observer.load_from_registry( + quant_args.observer, quantization_args=quant_args + ) + scale, zero_point = observer(weight) + weight = fake_quantize( + weight, + scale, + zero_point, + quant_args, + ) + update_offload_parameter(module, "weight", weight) + update_offload_parameter(module, "weight_scale", scale) + update_offload_parameter(module, "weight_zero_point", zero_point) def on_finalize(self, state: "State", **kwargs) -> bool: """ @@ -211,117 +255,83 @@ def on_finalize(self, state: "State", **kwargs) -> bool: :param state: session state storing input model and calibration data """ - if self.quantization_modifier_: - self.quantization_modifier_.finalize(state, **kwargs) + if self._quantization_modifier: + self._quantization_modifier.finalize(state, **kwargs) return True - def compressible_layers(self) -> Dict: + def calibration_forward( + self, model: torch.nn.Module, dataloader: torch.utils.data.DataLoader + ): """ - Retrieves the modules corresponding to a list of - compressible layer names + Perform calibration forward pass with one batch whose size is the size + of the dataset - :precondition: self.model is set and is a torch.nn.Module - :return: dictionary of modules to compress + :param model: model to perform forward pass with + :param dataloader: dataloader containing calibration dataset """ - if not isinstance(self.model, Module): - raise ValueError( - "`self.model` must be a torch.nn.Module to use " - f"the {self.__class__.__qualname__} modifier but got " - f"{type(self.model)} instead" - ) + if self.batch_size <= 0: + batch_size = len(dataloader.dataset) + else: + batch_size = self.batch_size + dataloader = create_batch_dataloader(dataloader, batch_size=batch_size) + with calibration_forward_context(model): + run_calibration_forward(model, dataloader, mask_padding=True) - return get_layers(self.sequential_targets, self.model) + def pre_compress_module(self, module: torch.nn.Module): + if self.batch_size != -1: + print("created aux buffers") + register_offload_parameter(module, "weight_acc", torch.nn.Parameter(torch.zeros_like(module.weight.data), requires_grad=False)) - def initialize_compression( + def compress_module( self, - model: Module, - dataloader: 
Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None, - ): + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + ) -> float: """ - Setup for GPTQ, initializes the model - and other parameters, also initilializes the - compressible layers of model, and sets the device + Quantize a module's weight according to the GPTQ algorithm - :param model: model to initialize for compression - :param dataloader: calibration data, not used by GPTQ in this function - """ - self.model = model - self.compressible_layers_ = self.compressible_layers() - self.layer_compressors_ = [] - - for idx, (name, layer) in enumerate(self.compressible_layers_.items()): - name = fix_fsdp_module_name(name) - logger.info(f"Preparing {name} for compression") - args = self._pruning_arguments() - comp_cls = self._compression_class() - compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args) - self.layer_compressors_.append(compressor) - - # for the initial forward data pass, add an early stop exception in order - # to capture inputs right before being compressed by first module - first_layer_compressor = self.layer_compressors_[0] - first_layer_compressor.set_early_stop() - - @torch.no_grad() - def apply_compression( - self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None - ) -> Dict: - """ - Run GPTQ on the loaded model, using dataloader as calibration data + :param name: name of module being quantized + :param module: module being quantized + :param args: input arguments for module forward pass - :param dataloader: calibration data for GPTQ + :return: total loss from applying weight quantization to this module """ - class_name = self.__class__.__name__.replace("PyTorch", "") - logger.info( - f"Running {class_name} calibration with " f"{len(dataloader)} samples..." 
- ) - - # quantization scales and zp are already initialized but we do not - # want to calibrate wrt to these - self.model.apply(disable_quantization) - - forward_pass_use_cache = self.model.config.use_cache - self.model.config.use_cache = False - - # run_calibration_forward uses the early stop exception to capture values - # as intermediates right before the forward pass of the first module - intermediates = run_calibration_forward( - self.model, dataloader, mask_padding=True - ) - self.layer_compressors_[0].clear_early_stop() - - num_layers = len(self.compressible_layers_) - for idx, layer_compressor in enumerate(self.layer_compressors_): - logger.info(f"\n===== Compressing layer {idx+1}/{num_layers} " " =====") - - # run the forward pass for each transformer layer (block) one at a time - logger.info(f"Calibrating {layer_compressor.name}...") - layer_compressor.pre_compress() - unquantized_outputs = layer_compressor.calibrate_layer(intermediates) - - layer_compressor.compress() - layer_compressor.post_compress() - layer_compressor.revert_layer_wrappers() - - # perform a second forward pass of the module to calculate weight-quantized - # outputs for use as inputs to the next layer (block) - quantized_outputs = layer_compressor.calibrate_layer(intermediates) - error = get_output_error(unquantized_outputs, quantized_outputs) - logger.info(f"Mean output error from quantization: {error:.3f}") - intermediates = quantized_outputs + logger.info(f"Quantizing {name}...") + + # Assume that first argument is the input + inp = args[0] + quant_args = getattr_chain(module, "quantization_scheme.weights") + logger.info(f"Using {inp.size(0)} samples") + + with align_module(module): + print(inp.shape) + loss, quantized_weight, _scale, _zero_point, _g_idx = quantize_weight( + module.weight.data, + inp, + quant_args, + blocksize=self.block_size, + percdamp=self.dampening_frac, + module_class=type(module), + ) - self.model.config.use_cache = forward_pass_use_cache + if self.batch_size != -1: + module.weight_acc += quantized_weight + update_offload_parameter(module, "weight_acc") + else: + module.weight -= module.weight + module.weight += quantized_weight + update_offload_parameter(module, "weight") - # re-enable quantization - self.model.apply(enable_quantization) + return loss def _build_quant_modifier(self): """ Build a quantization modifier based on the specified config_groups, ignore list, and num_calibration_steps. 
- :postcondition: self.quantization_modifier_ is set to the built + :postcondition: self._quantization_modifier is set to the built quantization modifier """ @@ -347,26 +357,9 @@ def _build_quant_modifier(self): def _build_quant_modifier_from_dict(self, quant_config): modifier_type = list(quant_config.keys())[0] modifier_args = quant_config[modifier_type] - self.quantization_modifier_ = ModifierFactory.create( + self._quantization_modifier = ModifierFactory.create( modifier_type, allow_registered=True, allow_experimental=True, **modifier_args, ) - - def _pruning_arguments(self): - """ - Gather the parameters needed for root module compression in a dict - - :return: dict of params for pruning - """ - return { - "blocksize": self.block_size, - "percdamp": self.dampening_frac, - } - - def _compression_class(self): - """ - :return: wrapper class used for root modules of this compression class - """ - return GPTQWrapper diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py index a8673dfc2..ec39da973 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/__init__.py @@ -1,4 +1,3 @@ # flake8: noqa -from .gptq_wrapper import * -from .helpers import * +from .gptq_quantize import * diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py new file mode 100644 index 000000000..a625e8a7b --- /dev/null +++ b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_quantize.py @@ -0,0 +1,272 @@ +import math +from copy import copy +from typing import Tuple, Union, Optional, Type + +from llmcompressor.observers.base import Observer +import torch +import transformers +from compressed_tensors.quantization import ( + ActivationOrdering, + QuantizationArgs, + QuantizationStrategy, + fake_quantize, +) + +from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD +from llmcompressor.pytorch.utils.helpers import tensor_sparsity + +GPTQ_PRECISION = torch.float32 + + +def compute_hessian(inp: torch.Tensor, module_class: Type[torch.nn.Module], device) -> torch.Tensor: + """ + Calculate the hessian with respect to the module inputs + + :param inp: module inputs + :param module_class: class of module, likely torch.nn.Linear + :return: hessian w.r.t. 
module inputs + """ + inp = inp.to(device=device) + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + + nsamples = inp.shape[0] # note this is the number of dataset samples, not + # multiplied by the sequence length + + if module_class in (torch.nn.Linear, transformers.Conv1D): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + inp = inp.to(dtype=GPTQ_PRECISION) + inp = math.sqrt(2 / nsamples) * inp + return inp.matmul(inp.t()) + + +def invert_hessian(H: torch.Tensor, percdamp: float) -> torch.Tensor: + """ + Performs in-place inversion of the hessian in order to save memory + + :param H: hessian being inverted + :param percdamp: dampening factor on hessian diagonal + :return: inverted hessian + """ + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(H.shape[0], device=H.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + return H + + +def quantize_weight( + weight: torch.Tensor, + inp: torch.Tensor, + quant_args: QuantizationArgs, + blocksize: int = 128, + percdamp: float = 0.01, + module_class: Type[torch.nn.Module] = torch.nn.Linear, + weight_original: Optional[torch.Tensor] = None, +) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: + """ + Quantize a module weight according to the GPTQ algorithm + + TODO + :param weight: weight being quantized + :param inp: module inputs used to calculate hessian + :param quant_args: quantization arguments used to find quantization parameters + :param blocksize: chunk size of quantization updates + :param percdamp: dampening factor on hessian diagonal + :param module_class: class of module, likely torch.nn.Linear + :return: loss, quantized_weight, scale, zero_point, g_idx + """ + strategy = quant_args.strategy + actorder = quant_args.actorder + final_shape = weight.shape + final_dtype = weight.dtype + W = weight.data.clone() + + # create observer for calculating quantization parameters + observer = Observer.load_from_registry( + "minmax", + quantization_args=quant_args, + averaging_constant=1.0, # ignore moving average + ) + + if weight_original is not None: + raise NotImplementedError() + + # standardize shape and dtype + if module_class == torch.nn.Conv2d: + W = W.flatten(1) + elif module_class == transformers.Conv1D: + W.transpose_(0, 1) + W = W.to(dtype=GPTQ_PRECISION) + num_rows = W.shape[0] + num_columns = W.shape[1] + + H = compute_hessian(inp, module_class, device=weight.device) + + if strategy == QuantizationStrategy.GROUP: + # mapping from column index to group index + g_idx = ( + torch.arange(num_columns, device=W.device, dtype=torch.int) + // quant_args.group_size + ) + + if actorder == ActivationOrdering.GROUP: + # permute by activation order first, then update groups + W, H, perm = _apply_activation_ordering(W, H) + scale, zero_point = observer(W, g_idx=None) + + # use identity g_idx (invert permutation later) + + elif actorder == ActivationOrdering.WEIGHT: + # update groups first, then permute by activation order + scale, zero_point = observer(W, g_idx=None) + W, H, perm = _apply_activation_ordering(W, H) + + # permute g_idx to maintain identity mapping after unpermutation + g_idx = g_idx[perm] + + else: + scale, zero_point = observer(W, g_idx=None) + else: + scale, zero_point = observer(W, g_idx=None) + + # sparsity mask + sparsity = tensor_sparsity(W) + preserve_zeros = sparsity >= SPARSITY_THRESHOLD + W_nz_mask = ( + (~torch.isclose(W, torch.zeros(1, 
device=W.device).float())).float() + if preserve_zeros + else None + ) + + losses = torch.zeros(num_rows, device=weight.device) + + # mask dead hessian values + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + # compute inverse hessian in place to save memory + # TODO: check in place + Hinv = invert_hessian(H, percdamp) + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + for i1 in range(0, num_columns, blocksize): + i2 = min(i1 + blocksize, num_columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + if preserve_zeros: + W1_nz_mask = W_nz_mask[:, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = w.clone() + + # quantize column + if strategy == QuantizationStrategy.TENSOR: + q = fake_quantize( + q, + scale, + zero_point, + quant_args, + ) + elif strategy == QuantizationStrategy.CHANNEL: + q = fake_quantize( + q, + scale[:, 0], + zero_point[:, 0], + quant_args, + ) + elif strategy == QuantizationStrategy.GROUP: + # get the group index for the current column + column_idx = i1 + i + group_index = g_idx[column_idx] + + # Since we're only applying quantization to a slice, this + # ends up being a channelwise application + altered_qargs = copy(quant_args) + altered_qargs.strategy = QuantizationStrategy.CHANNEL + q = fake_quantize( + q, + scale[:, group_index], + zero_point[:, group_index], + altered_qargs, + ) + else: + raise ValueError( + f"Quantization strategy is not supported for GPTQ: {strategy}" + ) + + # propagate column error + Q1[:, i] = q + losses1[:, i] = (w - q) ** 2 / d**2 + + err1 = (w - q) / d + w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + if preserve_zeros: + W1[:, i:] -= w1_err * W1_nz_mask[:, i:] + else: + W1[:, i:] -= w1_err + Err1[:, i] = err1 + + # propagate block error + W[:, i1:i2] = Q1 + losses += torch.sum(losses1, 1) / 2 + + w_err = Err1.matmul(Hinv[i1:i2, i2:]) + if preserve_zeros: + W[:, i2:] -= w_err * W_nz_mask[:, i2:] + else: + W[:, i2:] -= w_err + + has_gidx = False + if strategy == QuantizationStrategy.GROUP: + if actorder == ActivationOrdering.WEIGHT: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + + elif actorder == ActivationOrdering.GROUP: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + g_idx = g_idx[invperm] + + # only save g_idx if mapping is not identity + has_gidx = True + + if not has_gidx: + g_idx = None + + if module_class == transformers.Conv1D: + W.transpose_(0, 1) + W = W.reshape(final_shape).to(final_dtype) + + loss = torch.sum(losses).item() + return loss, W, scale, zero_point, g_idx + + +def _apply_activation_ordering( + W: torch.Tensor, H: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Permute weight and hessian in order of greatest outupt activations + + :param W: weight to permute + :param H: hessian used to determine activation ordering + :return: permuted weight, permuted hessian, permutation map + """ + perm = torch.argsort(torch.diag(H), descending=True) + return W[:, perm], H[perm][:, perm], perm diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py deleted file mode 100644 index bc8b43284..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/gptq_wrapper.py +++ /dev/null @@ -1,350 +0,0 @@ -import time -from typing import Tuple - -from 
compressed_tensors.quantization import ( - ActivationOrdering, - QuantizationArgs, - QuantizationStrategy, -) -from compressed_tensors.quantization.lifecycle.forward import fake_quantize - -from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD -from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper -from llmcompressor.observers import Observer -from llmcompressor.pytorch.utils.helpers import tensor_sparsity -from llmcompressor.utils import getattr_chain -from llmcompressor.utils.metric_logging import ( - get_GPU_memory_usage, - get_layer_size_bytes, -) - -try: - import transformers -except ImportError as err: - transformers = None - transformers_err = err - -import math -from copy import copy - -import torch -import torch.nn as nn -from compressed_tensors.utils import ( - get_offloaded_device, - is_module_offloaded, - update_parameter_data, - update_prefix_dict, -) -from loguru import logger - -__all__ = ["GPTQWrapper"] - - -class GPTQWrapper(ModuleCompressionWrapper): - """ - Runs GPTQ on a single module that contains no sub-modules - - Lifecycle: - - add_batch - - compress - - free - - :param name: name of module to run compression on - :param layer: module to run compression on - """ - - def __init__(self, name, layer): - super().__init__(name=name, layer=layer) - - # for Hessian calculation - self.register_buffer( - "H", - torch.zeros( - (self.columns, self.columns), device=self.dev, dtype=torch.float32 - ), - ) - - def add_batch(self, inp: torch.Tensor, out: torch.Tensor): - """ - Add a batch of layer input and output data to the Hessian calculation - - :param inp: tensor containing layer input - :param out: tensor containing layer output - """ - if len(inp.shape) == 2: - inp = inp.unsqueeze(0) - tmp = inp.shape[0] - if isinstance(self.layer, nn.Linear) or isinstance( - self.layer, transformers.Conv1D - ): - if len(inp.shape) == 3: - inp = inp.reshape((-1, inp.shape[-1])) - inp = inp.t() - self.H *= self.nsamples / (self.nsamples + tmp) - self.nsamples += tmp - inp = inp.to(dtype=self.H.dtype) - inp = math.sqrt(2 / self.nsamples) * inp - self.H += inp.matmul(inp.t()) - - def compress( - self, - blocksize: int = 128, - percdamp: float = 0.01, - ): - """ - Run pruning and quantization(if applicable) on the layer up to the target - sparsity value. 
- - :param blocksize: Number of columns to compress in one pass - :param percdamp: Amount of dampening to apply to H, as a fraction of the - diagonal norm - """ - args_loc = "quantization_scheme.weights" - weight_quant_args = getattr_chain(self.layer, args_loc, None) - if weight_quant_args is None: - logger.debug(f"Skipping unquantized layer {self.name}...") - return - - if is_module_offloaded(self.layer): - self.layer._hf_hook.pre_forward(self.layer) - - strategy = weight_quant_args.strategy - actorder = weight_quant_args.actorder - final_shape = self.layer.weight.shape - final_dtype = self.layer.weight.dtype - W = self.layer.weight.data.clone() - - # standardize shape and dtype - if isinstance(self.layer, nn.Conv2d): - W = W.flatten(1) - elif isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.float() - - tick = time.time() - - if strategy == QuantizationStrategy.GROUP: - # mapping from column index to group index - g_idx = ( - torch.arange(self.columns, device=W.device, dtype=torch.int) - // weight_quant_args.group_size - ) - - if actorder == ActivationOrdering.GROUP: - # permute by activation order first, then update groups - W, self.H, perm = self._apply_activation_ordering(W, self.H) - self._update_quantization_parameters(weight_quant_args, W) - - # use identity g_idx (invert permutation later) - - elif actorder == ActivationOrdering.WEIGHT: - # update groups first, then permute by activation order - self._update_quantization_parameters(weight_quant_args, W) - W, self.H, perm = self._apply_activation_ordering(W, self.H) - - # permute g_idx to maintain identity mapping after unpermutation - g_idx = g_idx[perm] - - scale = self.layer.weight_scale - zero_point = self.layer.weight_zero_point - - # sparsity mask - sparsity = tensor_sparsity(W) - preserve_zeros = sparsity >= SPARSITY_THRESHOLD - W_nz_mask = ( - (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float() - if preserve_zeros - else None - ) - - # mask dead hessian values - dead = torch.diag(self.H) == 0 - self.H[dead, dead] = 1 - W[:, dead] = 0 - - Losses = torch.zeros(self.rows, device=self.dev) - - # compute inverse hessian in place to save memory - try: - damp = percdamp * torch.mean(torch.diag(self.H)) - diag = torch.arange(self.columns, device=self.dev) - self.H[diag, diag] += damp - self.H = torch.linalg.cholesky(self.H) - self.H = torch.cholesky_inverse(self.H) - self.H = torch.linalg.cholesky(self.H, upper=True) - Hinv = self.H - except torch._C._LinAlgError: - raise ValueError( - "Failed to invert hessian due to numerical instability. 
Consider " - "increasing GPTQModifier.dampening_frac, increasing the number " - "of calibration samples, or shuffling the calibration dataset" - ) - - # See section 3.4 of https://arxiv.org/abs/2203.07259 - for i1 in range(0, self.columns, blocksize): - i2 = min(i1 + blocksize, self.columns) - count = i2 - i1 - - W1 = W[:, i1:i2].clone() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - Losses1 = torch.zeros_like(W1) - Hinv1 = Hinv[i1:i2, i1:i2] - - if preserve_zeros: - W1_nz_mask = W_nz_mask[:, i1:i2] - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - q = w.clone() - - # quantize column - if strategy == QuantizationStrategy.TENSOR: - q = fake_quantize( - q, - scale, - zero_point, - self.layer.quantization_scheme.weights, - ) - elif strategy == QuantizationStrategy.CHANNEL: - q = fake_quantize( - q, - scale[:, 0], - zero_point[:, 0], - weight_quant_args, - ) - elif strategy == QuantizationStrategy.GROUP: - # get the group index for the current column - column_idx = i1 + i - group_index = g_idx[column_idx] - - # Since we're only applying quantization to a slice, this - # ends up being a channelwise application - altered_qargs = copy(weight_quant_args) - altered_qargs.strategy = QuantizationStrategy.CHANNEL - q = fake_quantize( - q, - scale[:, group_index], - zero_point[:, group_index], - altered_qargs, - ) - else: - raise ValueError( - "Quantization strategy is not supported for GPTQ: " - f"{strategy}" - ) - - # propagate column error - Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d**2 - - err1 = (w - q) / d - w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - if preserve_zeros: - W1[:, i:] -= w1_err * W1_nz_mask[:, i:] - else: - W1[:, i:] -= w1_err - Err1[:, i] = err1 - - # propagate block error - W[:, i1:i2] = Q1 - Losses += torch.sum(Losses1, 1) / 2 - - w_err = Err1.matmul(Hinv[i1:i2, i2:]) - if preserve_zeros: - W[:, i2:] -= w_err * W_nz_mask[:, i2:] - else: - W[:, i2:] -= w_err - - if "METRIC" in logger._core.levels.keys(): - self._log_metrics(tick, Losses) - - if strategy == QuantizationStrategy.GROUP: - if actorder == ActivationOrdering.WEIGHT: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - - elif actorder == ActivationOrdering.GROUP: - # restore original permutation - invperm = torch.argsort(perm) - W = W[:, invperm] - g_idx = g_idx[invperm] - - # only save g_idx if mapping is not identity - update_parameter_data(self.layer, g_idx, "weight_g_idx") - - if isinstance(self.layer, transformers.Conv1D): - W.transpose_(0, 1) - W = W.reshape(final_shape).to(final_dtype) - - # This is a bit hacky, but FSDP updates only work if we change - # the weight in place, clone() or direct assignment won't work - self.layer.weight -= self.layer.weight - self.layer.weight += W - - if is_module_offloaded(self.layer): - device = get_offloaded_device(self.layer) - update_prefix_dict(self.layer, "weight", self.layer.weight.to(device)) - self.layer._hf_hook.post_forward(self.layer, None) - - def free(self): - """ - Free the Hessian memory after the layer is complete - """ - delattr(self, "H") - super().free() - - def _update_quantization_parameters(self, args: QuantizationArgs, W: torch.Tensor): - """ - Update layer quantization parameters with potentially permuted weight - - :param args: quantization arguments - :param W: weight to calculate quantization parameters from - """ - observer = args.get_observer() - observer = Observer.load_from_registry(observer, quantization_args=args) - _scale, _zero_point = observer(W, g_idx=None) - 
update_parameter_data(self.layer, _scale, "weight_scale") - update_parameter_data(self.layer, _zero_point, "weight_zero_point") - - def _apply_activation_ordering( - self, W: torch.Tensor, H: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Permute weight and hessian in order of greatest outupt activations - - :param W: weight to permute - """ - perm = torch.argsort(torch.diag(H), descending=True) - return W[:, perm], H[perm][:, perm], perm - - def _log_metrics(self, start_tick: float, losses: torch.Tensor): - """ - Log metrics related to compression algorithm - - :param start_tick: time when algorithm started" - :param losses: loss as result of algorithm - """ - patch = logger.patch(lambda r: r.update(function="compress")) - patch.log("METRIC", "time %.2f" % (time.time() - start_tick)) - patch.log("METRIC", "error %.2f" % torch.sum(losses).item()) - - gpu_usage = get_GPU_memory_usage() - if len(gpu_usage) > 0: - for i in range(len(gpu_usage)): - perc = gpu_usage[i][0] * 100 - total_memory = int(gpu_usage[i][1]) # GB - patch.log( - "METRIC", - ( - f"GPU {i} | usage: {perc:.2f}%" - f" | total memory: {total_memory} GB" - ), - ) - - patch.log( - "METRIC", - f"Compressed layer size: {get_layer_size_bytes(self.layer)} MB", - ) diff --git a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py b/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py deleted file mode 100644 index 58fedc634..000000000 --- a/src/llmcompressor/modifiers/quantization/gptq/utils/helpers.py +++ /dev/null @@ -1,51 +0,0 @@ -from typing import Any, Iterable, List, Tuple, Union - -import torch - -__all__ = ["get_output_error"] - - -def get_output_error( - unquantized: List[Tuple[Union[Iterable, torch.Tensor], Any]], - quantized: List[Tuple[Union[Iterable, torch.Tensor], Any]], -) -> torch.Tensor: - """ - Calculate mean l1 loss between weight-unquantized outputs and weight-quantized - outputs - - :param unquantized: unquantized-weight outputs - :param quantized: quantized-weight outputs - :return: mean l1 loss between outputs - """ - unquantized_outputs = sum( - [ - [output for output in outputs] - if isinstance(outputs, Iterable) - else [outputs] - for outputs, _ in unquantized - ], - start=[], - ) - - quantized_outputs = sum( - [ - [output for output in outputs] - if isinstance(outputs, Iterable) - else [outputs] - for outputs, _ in quantized - ], - start=[], - ) - - if len(unquantized_outputs) != len(quantized_outputs): - raise ValueError( - "Number of samples of weight-unquantized and weight-quantized " - "outputs differs" - ) - - return sum( - [ - torch.nn.functional.l1_loss(unq, q) - for unq, q in zip(unquantized_outputs, quantized_outputs) - ] - ) / len(unquantized_outputs) diff --git a/src/llmcompressor/modifiers/utils/hooks.py b/src/llmcompressor/modifiers/utils/hooks.py new file mode 100644 index 000000000..267de838c --- /dev/null +++ b/src/llmcompressor/modifiers/utils/hooks.py @@ -0,0 +1,198 @@ +import contextlib +from abc import abstractmethod +from functools import partial +from typing import Any, Callable, ClassVar, Dict, List, Set, Tuple, Union + +import torch +from loguru import logger +from pydantic import BaseModel +from torch.utils.hooks import RemovableHandle +from collections import defaultdict + +from llmcompressor.modifiers.utils.pytorch_helpers import EarlyStopException +from llmcompressor.utils.helpers import getattr_chain +from llmcompressor.utils.metric_logging import CompressionLogger +from llmcompressor.utils.pytorch.module import get_layers, 
get_no_split_params + +__all__ = ["HooksMixin", "LayerCompressorMixin"] + + +class HooksMixin(BaseModel): + """ + Class to manage the registration, disabling, and removal of hooks. Registering + and removing hooks should be handled by modifier classes which inherit from this + mixin, while disabling hooks should disable all hooks across modifiers. + + Modifiers which implement hooks should use the @HooksMixin.hook decorator + Modifiers must pass registered hook handles to self.register_hook() and must + remove hooks when finished using self.remove_hooks() + + Lifecycle: + - Modifier.register_hooks(model) + - model.forward() + - Modifier.remove_hooks() + """ + + _HOOKS_DISABLED: ClassVar[bool] = False + _hooks: List[RemovableHandle] = [] + + @classmethod + def hook(cls, func: Callable[[Any], Any]): + def wrapped(*args, **kwargs): + if cls._HOOKS_DISABLED: + return + + return func(*args, **kwargs) + + return wrapped + + @classmethod + @contextlib.contextmanager + def disable_hooks(cls): + """ + Disable all hooks across all modifiers + TODO: select which modifier hooks are disabled / kept enabled + """ + try: + cls._HOOKS_DISABLED = True + yield + finally: + cls._HOOKS_DISABLED = False + + def register_hook(self, handle: RemovableHandle): + """ + Usage: self.register_hook(module.register_forward_hook(...)) + + :param handle: handle of added hook + """ + self._hooks.append(handle) + + def remove_hooks(self): + """ + Remove all hooks belonging to a modifier + """ + for hook in self._hooks: + hook.remove() + + +class LayerCompressorMixin(HooksMixin): + """ + Apply a given compression function to a model during the model's calibration + forward pass + + Lifecycle: + - QuantizationModifier.initialize(model) + - Modifier.register_hooks(model) + - model.forward() + - compress_fn(name, target_module, args) + - Modifier.remove_hooks() + + :ivar true_sequential: Used to control the granularity of compression updates + through the forward pass. Set to True to use the weight-compressed outputs + of each module, set to False to use the weight-compressed outputs of each + layer (transformer block), defaults to False + :ivar sequential_targets: list of layer names to compress during GPTQ, or + '__ALL__' to compress every layer in the model + :ivar compress_module: Function to be called on target modules + """ + + sequential_targets: bool + + _layer_index = 0 + _num_layers = 0 + + @abstractmethod + def pre_compress_module( + self, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + ) -> float: + raise NotImplementedError() + + @abstractmethod + def compress_module( + self, + name: str, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + ) -> float: + raise NotImplementedError() + + def register_hooks(self, model: torch.nn.Module): + # find layers (used for printing even if true_sequential=True) + # if no targets are provided, default to the modules that shouldn't be + # split by FSDP. 
For Transformers models this is equivalent to the + # decoder layers (ie LlamaDecoderLayer) + # sequential_targets = self.sequential_targets + # if sequential_targets is None: + # sequential_targets = get_no_split_params(model) + # layers = get_layers(sequential_targets, model) + # self._num_layers = len(layers) + + for name, module in model.named_modules(): + if getattr_chain(module, "quantization_scheme.weights", None) is not None: + pre_hook = partial(self.target_pre_forward, name) + post_hook = partial(self.target_post_forward, name) + self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) + self.register_hook(module.register_forward_hook(post_hook)) + + self.pre_compress_module(module) + + if "head" in name: + def hook(module: torch.nn.Module, args: Tuple[Any, ...]): + raise EarlyStopException(None, None) + + self.register_hook(module.register_forward_pre_hook(hook, with_kwargs=False)) + + # if name in layers.keys(): + # pre_hook = partial(self.layer_pre_forward, name) + # post_hook = partial(self.layer_post_forward, name) + # self.register_hook(module.register_forward_pre_hook(pre_hook, with_kwargs=True)) + # self.register_hook( + # module.register_forward_hook(post_hook, with_kwargs=True) + # ) + + + @HooksMixin.hook + def target_pre_forward( + self, name: str, module: torch.nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ): + # compress + print(f"compressing {name}") + if True: #self.true_sequential: + with CompressionLogger(module) as comp_logger: + loss = self.compress_module(name, module, args) + comp_logger.set_loss(loss) + + @HooksMixin.hook + def target_post_forward( + self, + name: str, + module: torch.nn.Module, + args: Tuple[Any, ...], + output: Tuple[Any, ...], + ): + print(f"post {name}") + + + @HooksMixin.hook + def layer_pre_forward(self, name: str, layer: torch.nn.Module, args: Any, kwargs): + logger.info( + f"\n===== Compressing layer {self._layer_index}/{self._num_layers} =====" + ) + + + @HooksMixin.hook + def layer_post_forward( + self, + name: str, + layer: torch.nn.Module, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + output: Tuple[Any, ...], + ): + print(f"post {name}") + self._layer_index += 1 + + return output diff --git a/src/llmcompressor/modifiers/utils/layer_compressor.py b/src/llmcompressor/modifiers/utils/layer_compressor.py index 3dd3caa7e..3f3aa3d02 100644 --- a/src/llmcompressor/modifiers/utils/layer_compressor.py +++ b/src/llmcompressor/modifiers/utils/layer_compressor.py @@ -4,7 +4,6 @@ import torch from compressed_tensors import get_execution_device from loguru import logger -from torch.nn import Module from tqdm import tqdm from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper @@ -14,8 +13,7 @@ fix_fsdp_module_name, summon_full_params_context, ) -from llmcompressor.utils.pytorch import set_layer -from llmcompressor.utils.pytorch.module import get_prunable_layers +from llmcompressor.utils.pytorch.module import get_prunable_layers, set_layer __all__ = ["LayerCompressor"] @@ -45,8 +43,8 @@ class LayerCompressor: def __init__( self, module_compressor_class: ModuleCompressionWrapper, - model: Module, - layer: Module, + model: torch.nn.Module, + layer: torch.nn.Module, layer_index: int, name: str, args: Dict, diff --git a/src/llmcompressor/modifiers/utils/pytorch_helpers.py b/src/llmcompressor/modifiers/utils/pytorch_helpers.py index 9003ff22d..8c7b8b318 100644 --- a/src/llmcompressor/modifiers/utils/pytorch_helpers.py +++ b/src/llmcompressor/modifiers/utils/pytorch_helpers.py 
@@ -26,6 +26,8 @@ class EarlyStopException(Exception): """ def __init__(self, args: Tuple, kwargs: Dict): + if args is None: + return self.args = tensors_to_device(args, "cpu") self.kwargs = kwargs @@ -39,7 +41,7 @@ def apply_pad_mask_to_batch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.T :param batch: batch to apply padding to if it exists :return: batch with padding zeroed out in the input_ids """ - batch["input_ids"] = batch["input_ids"] * batch["attention_mask"] + batch["input_ids"].masked_fill_(batch["attention_mask"] == 0, 0) return batch @@ -98,7 +100,8 @@ def run_calibration_forward( except EarlyStopException as e: # model was stopped early, save last calculated output and # move on to next calibration sample - intermediates.append((e.args, e.kwargs)) + #intermediates.append((e.args, e.kwargs)) + pass # TODO: not ideal, figure out where we aren't freeing memory instead # currently without this we run OOM on the 2nd forward pass diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index d4c3a6222..941306180 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -51,17 +51,20 @@ def __init__( self.padding = False if self.tokenizer: + if hasattr(self.tokenizer, "tokenizer"): + self.tokenizer = self.tokenizer.tokenizer + if not self.tokenizer.pad_token: self.tokenizer.pad_token = self.tokenizer.eos_token # configure sequence length max_seq_length = data_args.max_seq_length - model_max_length = tokenizer.model_max_length if tokenizer else max_seq_length + model_max_length = self.tokenizer.model_max_length if self.tokenizer else max_seq_length if self.tokenizer and max_seq_length > model_max_length: logger.warning( f"The max_seq_length passed ({max_seq_length}) is larger than " - f"the maximum length for the model ({tokenizer.model_max_length}). " - f"Using max_seq_length={tokenizer.model_max_length}." + f"the maximum length for the model ({self.tokenizer.model_max_length}). " + f"Using max_seq_length={self.tokenizer.model_max_length}." 
) self.max_seq_length = min(data_args.max_seq_length, model_max_length)
diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py
index 23c70e561..caa6b07d3 100644
--- a/src/llmcompressor/transformers/finetune/data/data_helpers.py
+++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -1,9 +1,11 @@
 import logging
 import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, Optional
 
+import datasets
 import torch
 from datasets import Dataset, load_dataset
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from transformers.data import default_data_collator
 
@@ -11,20 +13,60 @@
 LABELS_MASK_VALUE = -100
 
 __all__ = [
+    "create_batch_dataloader",
     "format_calibration_data",
     "get_raw_dataset",
    "make_dataset_splits",
     "get_custom_datasets_from_path",
+    "LABELS_MASK_VALUE",
 ]
 
 
+def create_batch_dataloader(
+    dataloader: torch.utils.data.DataLoader,
+    batch_size: int,
+) -> torch.utils.data.DataLoader:
+    """
+    Create a new dataloader from an existing dataloader, rebatched to the given batch size with padding applied per batch
+
+    :param dataloader: existing dataloader whose dataset and sampler are reused
+    :param batch_size: batch size of new dataloader
+    :return: new dataloader yielding padded batches of the requested size
+    """
+    dataset = dataloader.dataset
+    sampler = dataloader.sampler.__class__(dataset)
+
+    def pad_sequences(batch):
+        # extract input_ids and attention_mask from the batch
+        input_ids = [torch.tensor(item["input_ids"]).squeeze(0) for item in batch]
+        masks = [torch.tensor(item["attention_mask"]).squeeze(0) for item in batch]
+
+        # while 0 is not necessarily the "correct" padding value, the padded
+        # input_ids are ignored according to the attention_mask
+        pad_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
+        pad_masks = pad_sequence(masks, batch_first=True, padding_value=0)
+
+        return {
+            "input_ids": pad_input_ids,
+            "attention_mask": pad_masks,
+        }
+
+    return torch.utils.data.DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler,
+        collate_fn=pad_sequences,
+        pin_memory=True,
+    )
+
+
 def format_calibration_data(
     tokenized_dataset: Dataset,
     num_calibration_samples: Optional[int] = None,
     do_shuffle: bool = True,
     collate_fn: Callable = default_data_collator,
     accelerator: Optional[Any] = None,
-) -> List[torch.Tensor]:
+) -> torch.utils.data.DataLoader:
     """
     Creates a dataloader out of the calibration dataset split, trimming it to
     the desired number of calibration samples
diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py
index 1856ca954..8db9ceac0 100644
--- a/src/llmcompressor/transformers/finetune/text_generation.py
+++ b/src/llmcompressor/transformers/finetune/text_generation.py
@@ -23,7 +23,7 @@
 from loguru import logger
 from transformers import (
     AutoConfig,
-    AutoTokenizer,
+    AutoProcessor,
     DefaultDataCollator,
     HfArgumentParser,
     set_seed,
@@ -221,12 +221,12 @@ def initialize_model_from_path(
 
 def initialize_tokenizer_from_path(model_args, model, teacher):
     tokenizer_src = model_args.tokenizer
     tokenizer_src = tokenizer_src or get_shared_tokenizer_src(model, teacher)
-    tokenizer = AutoTokenizer.from_pretrained(
+    tokenizer = AutoProcessor.from_pretrained(
         tokenizer_src,
-        cache_dir=model_args.cache_dir,
-        use_fast=True,
-        revision=model_args.model_revision,
-        use_auth_token=True if model_args.use_auth_token else None,
+        # cache_dir=model_args.cache_dir,
+        # use_fast=True,
+        # revision=model_args.model_revision,
+        # use_auth_token=True if model_args.use_auth_token else None,
         trust_remote_code=model_args.trust_remote_code_model,
     )
diff --git a/src/llmcompressor/utils/fsdp/helpers.py b/src/llmcompressor/utils/fsdp/helpers.py
index 8cc0f5405..e58b4f1c3 100644
--- a/src/llmcompressor/utils/fsdp/helpers.py
+++ b/src/llmcompressor/utils/fsdp/helpers.py
@@ -1,9 +1,14 @@
+import contextlib
+from functools import wraps
 import operator
 from pathlib import Path
 from typing import Optional
+import warnings
 
 from loguru import logger
 
+from llmcompressor.utils.helpers import getattr_chain
+
 try:
     from torch.distributed.fsdp import (
         FullStateDictConfig,
@@ -20,6 +25,16 @@
 from llmcompressor.pytorch.model_load.helpers import save_model_and_recipe
 from llmcompressor.utils.pytorch import set_layer
 
+try:
+    from accelerate.hooks import AlignDevicesHook
+    from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset, set_module_tensor_to_device
+    _has_accelerate = True
+except ImportError:
+    _has_accelerate = False
+    AlignDevicesHook = None
+    OffloadedWeightsLoader = None
+    PrefixedDataset = None
+
 __all__ = [
     "is_fsdp_model",
     "maybe_get_wrapped",
@@ -179,3 +194,240 @@ def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]:
     parent = operator.attrgetter(parent_name)(model)
 
     return parent
+
+# upstream candidate
+def has_offloaded_params(module: torch.nn.Module) -> bool:
+    """
+    Checks if a module has offloaded parameters by checking if the given module
+    has an AlignDevicesHook attached with offloading enabled
+
+    Args:
+        module (`torch.nn.Module`): The module to check for an offload hook.
+
+    Returns:
+        bool: `True` if the module has an offload hook and offloading is enabled,
+        `False` otherwise.
+    """
+    return (
+        hasattr(module, "_hf_hook") and
+        isinstance(module._hf_hook, AlignDevicesHook) and
+        module._hf_hook.offload
+    )
+
+
+# deprecation candidate
+@wraps(has_offloaded_params)
+def is_module_offloaded(module: torch.nn.Module) -> bool:
+    if not _has_accelerate:
+        return False
+
+    return has_offloaded_params(module)
+
+
+# deprecation candidate
+def get_execution_device(module: torch.nn.Module) -> torch.device:
+    """
+    :param module: module to check
+    :return: device module is loaded onto during forward pass
+    """
+    if is_module_offloaded(module):
+        return module._hf_hook.execution_device
+    device = next(module.parameters()).device
+
+    # offload only gets set for leaf modules, fallback to checking for device type
+    if device.type == "meta":
+        return module._hf_hook.execution_device
+
+    return device
+
+
+# upstream candidate
+def _infer_offload_device(module: torch.nn.Module) -> torch.device:
+    if not has_offloaded_params(module):
+        raise ValueError("Cannot infer offload device from non-offloaded module")
+
+    first_key = next(iter(module._hf_hook.weights_map.keys()), None)
+    if first_key is None:
+        raise ValueError("Cannot infer offload device from empty weights map")
+
+    prefix_dataset = module._hf_hook.weights_map.dataset
+    return prefix_dataset[first_key].device
+
+# deprecation candidate
+def get_offloaded_device(module: torch.nn.Module) -> torch.device:
+    """
+    :param module: module to check
+    :return: device module is offloaded onto after forward pass
+    """
+    return _infer_offload_device(module)
+
+
+# deprecation candidate
+def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
+    """
+    Updates the offloaded state dict for a given module. Parameter named key is replaced
+    by data. This is necessary because parameter updates for offloaded modules do not
+    persist automatically between loads. This function only affects the offloaded
+    state dict and not the current state of the loaded module.
+
+    :param module: module containing the parameter to update
+    :param key: name of parameter to update
+    :param data: tensor to update parameter with in the offloaded state dict
+    """
+    if not is_module_offloaded(module):
+        raise ValueError("Prefix dict is only applicable to offloaded modules")
+    prefix_dict = module._hf_hook.weights_map
+    prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data
+
+
+# upstream candidate?
+def update_offload_parameter(
+    module: torch.nn.Module,
+    name: str,
+    data: Optional[torch.Tensor] = None,
+    offload_device: Optional[torch.device] = None,
+):
+    """
+    :param module: module containing the parameter to update
+    :param name: name of module parameter to update
+    :param data: tensor to update parameter with
+    :param offload_device: offload device for newly registered parameters
+    """
+    param = getattr(module, name)
+    if data is not None:
+        if data.device.type == "meta":
+            raise ValueError("Cannot copy data from meta device. Consider calling within the align_module(module) context")
+
+        if param.data.dtype != data.dtype:
+            warnings.warn(f"dtype mismatch: copying {data.dtype} data into a {param.data.dtype} parameter")
+
+        param.data.copy_(data)
+
+    if has_offloaded_params(module):
+        weights_map = module._hf_hook.weights_map
+
+        # for upstreaming, probably better to modify the weight map types so that they can be written to?
+        if isinstance(weights_map, PrefixedDataset):
+            prefix_dict = getattr_chain(module, "_hf_hook.weights_map.dataset", None)
+            if prefix_dict is not None:
+                prefix = module._hf_hook.weights_map.prefix
+                key = f"{prefix}{name}"
+
+                offload_device = (
+                    prefix_dict[key].device if key in prefix_dict
+                    else offload_device if offload_device is not None
+                    else _infer_offload_device(module)
+                )
+                prefix_dict[key] = param.data.to(device=offload_device)
+
+        elif isinstance(weights_map, OffloadedWeightsLoader):
+            raise NotImplementedError()
+
+        else:
+            raise NotImplementedError()
+
+# deprecation candidate
+def update_parameter_data(
+    module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
+):
+    param = getattr(module, param_name)
+    new_param_data = new_param_data.to(device=param.device, dtype=param.dtype)
+    update_offload_parameter(module, param_name, new_param_data)
+
+
+# upstream candidate
+@contextlib.contextmanager
+def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None):
+    """
+    Moves a module's parameters to the specified execution device.
+
+    Args:
+        module (torch.nn.Module): Module with parameters to align.
+        execution_device (Optional[torch.device]): If provided, overrides the
+            module's execution device within the context.
+
+    Yields:
+        None: Yields control while the module's parameters are aligned to the execution device.
+    """
+    if has_offloaded_params(module):
+        if execution_device is not None:
+            original_device = module._hf_hook.execution_device
+            module._hf_hook.execution_device = execution_device
+
+        module._hf_hook.pre_forward(module)
+        yield
+        module._hf_hook.post_forward(module, None)
+
+        if execution_device is not None:
+            module._hf_hook.execution_device = original_device
+
+    elif execution_device is not None:
+        devices = {}
+        for name, param in module.named_parameters():
+            devices[name] = param.device
+            set_module_tensor_to_device(
+                module,
+                name,
+                execution_device,
+            )
+
+        yield
+
+        for name, param in module.named_parameters():
+            set_module_tensor_to_device(
+                module,
+                name,
+                devices[name],
+            )
+
+    else:
+        yield
+
+
+@contextlib.contextmanager
+def modify_offload_module(
+    module: torch.nn.Module,
+    execution_device: Optional[torch.device] = None,
+    offload_device: Optional[torch.device] = None,
+):
+    with align_module(module, execution_device):
+        yield
+
+        # there is little performance gain from checking if a parameter's data
+        # has been modified before copying since the new data must be copied
+        # to the offload device anyway; just update all module parameters
+        for name, param in module.named_parameters():
+            update_offload_parameter(module, name, param.data, offload_device)
+
+
+# upstream candidate?
+def register_offload_parameter(
+    module: torch.nn.Module,
+    name: str,
+    parameter: torch.nn.Parameter,
+    offload_device: Optional[torch.device] = None,
+):
+    module.register_parameter(name, parameter)
+    update_offload_parameter(module, name, parameter.data, offload_device)
+
+
+# upstream candidate?
+def delete_offload_parameter(module: torch.nn.Module, name: str):
+    delattr(module, name)
+
+    if has_offloaded_params(module):
+        weights_map = module._hf_hook.weights_map
+
+        # for upstreaming, probably better to modify the weight map types so that they can be written to?
+        if isinstance(weights_map, PrefixedDataset):
+            dataset = weights_map.dataset
+            prefix = weights_map.prefix
+            if dataset is not None:
+                del dataset[f"{prefix}{name}"]
+
+        elif isinstance(weights_map, OffloadedWeightsLoader):
+            raise NotImplementedError()
+
+        elif weights_map is not None:
+            raise NotImplementedError(f"Cannot delete parameter from weights_map of type {type(weights_map)}")
\ No newline at end of file
diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py
index 266acf973..a45db89c1 100644
--- a/src/llmcompressor/utils/helpers.py
+++ b/src/llmcompressor/utils/helpers.py
@@ -4,6 +4,7 @@
 """
 
 import ast
+import contextlib
 import errno
 import fnmatch
 import glob
@@ -18,10 +19,13 @@
 from collections import OrderedDict
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 from urllib.parse import urlparse
 
 import numpy
+import torch
+from compressed_tensors import is_module_offloaded
+from compressed_tensors.quantization import disable_quantization, enable_quantization
 from loguru import logger
 
 __all__ = [
@@ -59,6 +63,7 @@
     "is_package_available",
     "import_from_path",
     "getattr_chain",
+    "DisableKVCache",
 ]
 
 
@@ -1041,3 +1046,95 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
         res = getattr(res, attr_name)
 
     return res
+
+
+class DisableKVCache:
+    def __init__(self, model: torch.nn.Module):
+        if hasattr(model.config, "use_cache"):
+            self.config = model.config
+
+        # MllamaConfig
+        elif hasattr(model.config, "text_config") and hasattr(
+            model.config.text_config, "use_cache"
+        ):
+            self.config = model.config.text_config
+
+        # unknown config structure
+        else:
+            raise NotImplementedError(
+                f"Cannot find `use_cache` for config of type {type(model.config)}"
+            )
+
+        self.restore_value = self.config.use_cache
+
+    def __enter__(self):
+        self.restore_value = self.config.use_cache
+        self.config.use_cache = False
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        self.config.use_cache = self.restore_value
+
+
+@contextlib.contextmanager
+def DisableQuantization(model: torch.nn.Module):
+    """
+    Disable quantization from QuantizationModifier
+    """
+    model.apply(disable_quantization)
+    yield
+    model.apply(enable_quantization)
+
+
+@contextlib.contextmanager
+def calibration_forward_context(model: torch.nn.Module):
+    """
+    Context in which all calibration forward passes should occur.
+
+    - Remove gradient calculations
+    - Disable the KV cache
+    - Disable quantization from QuantizationModifier
+    """
+    model.eval()
+
+    with (
+        torch.no_grad(),
+        DisableKVCache(model),
+        DisableQuantization(model),
+    ):
+        yield
+
+
+@contextlib.contextmanager
+def align_module(module: torch.nn.Module, device: Optional[torch.device] = None):
+    """
+    Move an offloaded module's parameters to device or module execution device
+
+    :param module: module with parameters to align
+    :param device: optional device to move parameters to, if None is provided then
+        module execution device will be used
+    """
+    if is_module_offloaded(module):
+        if device is not None:
+            original_device = module._hf_hook.execution_device
+            module._hf_hook.execution_device = device
+
+        module._hf_hook.pre_forward(module)
+        yield
+        module._hf_hook.post_forward(module, torch.tensor([]))
+
+        if device is not None:
+            module._hf_hook.execution_device = original_device
+
+    elif device is not None:
+        devices = {}
+        for name, param in module.named_parameters(recurse=False):
+            devices[name] = param.device
+            setattr(module, name, param.to(device))
+
+        yield
+
+        for name, param in module.named_parameters(recurse=False):
+            setattr(module, name, param.to(devices[name]))
+
+    else:
+        yield
diff --git a/src/llmcompressor/utils/metric_logging.py b/src/llmcompressor/utils/metric_logging.py
index d0b3bb11e..b23ba200a 100644
--- a/src/llmcompressor/utils/metric_logging.py
+++ b/src/llmcompressor/utils/metric_logging.py
@@ -1,7 +1,10 @@
+import time
 from typing import List, Tuple
 
+import torch
 from loguru import logger
-from torch.nn import Module
+
+__all__ = ["CompressionLogger"]
 
 
 def get_GPU_memory_usage() -> List[Tuple]:
@@ -35,7 +38,7 @@ def get_GPU_memory_usage() -> List[Tuple]:
         return []
 
 
-def get_layer_size_bytes(module: Module) -> float:
+def get_module_size_bytes(module: torch.nn.Module) -> float:
     param_size = 0
     buffer_size = 0
 
@@ -49,3 +52,50 @@ def get_layer_size_bytes(module: Module) -> float:
     total_size_mb = total_size / (1024**2)  # Convert bytes to MB
 
     return total_size_mb
+
+
+class CompressionLogger:
+    """
+    Log metrics related to compression algorithm
+
+    :param start_tick: time when algorithm started
+    :param losses: loss as result of algorithm
+    """
+
+    def __init__(self, module: torch.nn.Module):
+        self.module = module
+        self.start_tick = None
+        self.loss = None
+
+    def set_loss(self, loss: float):
+        self.loss = loss
+
+    def __enter__(self) -> "CompressionLogger":
+        self.start_tick = time.time()
+        return self
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        stop_tick = time.time()
+        patch = logger.patch(lambda r: r.update(function="compress"))
+
+        if self.start_tick is not None:
+            duration = stop_tick - self.start_tick
+            patch.log("METRIC", f"time {duration:.2f}")
+        if self.loss is not None:
+            patch.log("METRIC", f"error {self.loss:.2f}")
+
+        gpu_usage = get_GPU_memory_usage()
+        if len(gpu_usage) > 0:
+            for i in range(len(gpu_usage)):
+                perc = gpu_usage[i][0] * 100
+                total_memory = int(gpu_usage[i][1])  # GB
+                patch.log(
+                    "METRIC",
+                    (
+                        f"GPU {i} | usage: {perc:.2f}%"
+                        f" | total memory: {total_memory} GB"
+                    ),
+                )
+
+        compressed_size = get_module_size_bytes(self.module)
+        patch.log("METRIC", f"Compressed module size: {compressed_size} MB")
diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
index 5421af4cf..3cdc25038 100644
--- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
+++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py
@@ -75,15 +75,15 @@ def test_create_default_quant_modifier(self):
         kwargs = dict(block_size=128)
 
         modifier = GPTQModifier(**kwargs)
-        assert modifier.quantization_modifier_ is None
+        assert modifier._quantization_modifier is None
 
         testing_harness = LifecyleTestingHarness(model=LinearNet())
         modifier.on_initialize_structure(testing_harness.get_state())
         assert modifier.quantize
-        assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
-        modifier.quantization_modifier_.create_init_config()
+        assert isinstance(modifier._quantization_modifier, QuantizationModifier)
+        modifier._quantization_modifier.create_init_config()
         default_config_group_name = "group_0"
-        should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[
+        should_be_default_quant_scheme = modifier._quantization_modifier.config_groups[
             default_config_group_name
         ]
         assert should_be_default_quant_scheme.input_activations is None
@@ -113,7 +113,7 @@ def test_set_quant_if_modifer_already_exists(self):
 
         kwargs = dict(block_size=128)
         modifier = GPTQModifier(**kwargs)
-        assert not modifier.quantization_modifier_
+        assert not modifier._quantization_modifier
 
         modifier.on_initialize_structure(testing_harness.get_state())
         # since quantization modifier is already applied, quantization must be set in
@@ -150,14 +150,14 @@ def test_set_quant_in_gptq(self):
         kwargs = dict(block_size=128, quantize=self.quant_config)
 
         modifier = GPTQModifier(**kwargs)
-        assert modifier.quantization_modifier_ is None
+        assert modifier._quantization_modifier is None
 
         testing_harness = LifecyleTestingHarness(model=LinearNet())
         modifier.on_initialize_structure(testing_harness.get_state())
         assert modifier.quantize
-        self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier)
+        self.assertIsInstance(modifier._quantization_modifier, QuantizationModifier)
 
-        dict_scheme = dict(modifier.quantization_modifier_.config_groups)
+        dict_scheme = dict(modifier._quantization_modifier.config_groups)
         self._check_config(
             dict(dict_scheme["config_group_0"].weights),
             self.quant_kwargs["config_groups"]["config_group_0"]["weights"],
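
A minimal usage sketch of the offload helpers added to src/llmcompressor/utils/fsdp/helpers.py above (illustrative only: the plain Linear layer and the round-to-nearest step are stand-ins for an accelerate-offloaded module and a real compression step such as GPTQ; with an offloaded module the same calls also keep the offloaded copy of the parameter in sync):

    import torch

    from llmcompressor.utils.fsdp.helpers import align_module, update_offload_parameter

    layer = torch.nn.Linear(16, 16)

    # move the layer's parameters onto the chosen execution device for the duration
    # of the block (already on CPU here, so nothing changes for this toy layer)
    with align_module(layer, execution_device=torch.device("cpu")):
        # stand-in for a real compression step
        rounded = layer.weight.detach().round()

        # copy the new data into the loaded parameter; for an offloaded module the
        # corresponding entry in the hook's weights map is updated as well
        update_offload_parameter(layer, "weight", rounded)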