From 29f83bb7ff78ac6647ba327e1696f2d1df8aeb94 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Tue, 5 Mar 2024 21:13:04 +0000 Subject: [PATCH] Add state dict translation methods --- .../transformers/utils/transformations.py | 253 ++++++++++++++++++ .../transformers/utils/vllm_export_helpers.py | 170 ++++++++++++ src/sparseml/utils/helpers.py | 21 ++ 3 files changed, 444 insertions(+) create mode 100644 src/sparseml/transformers/utils/transformations.py create mode 100644 src/sparseml/transformers/utils/vllm_export_helpers.py diff --git a/src/sparseml/transformers/utils/transformations.py b/src/sparseml/transformers/utils/transformations.py new file mode 100644 index 00000000000..aceff72346d --- /dev/null +++ b/src/sparseml/transformers/utils/transformations.py @@ -0,0 +1,253 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa #F821,#E501 + +import functools +import logging +from typing import Dict + +import numpy +import torch +from torch import Tensor + + +__all__ = [ + "transform_names", + "add_tensors", + "transform_tensors", + "remove_unwanted_tensors", + "is_quantization_target", +] + +_LOGGER = logging.getLogger(__name__) + + +def _log_call(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + _LOGGER.info("Applying transformation: %s", func.__name__.upper()) + return_value = func(*args, **kwargs) + _LOGGER.info("Transformation: %s complete", func.__name__.upper()) + return return_value + + return wrapper + + +def is_quantization_target(key: str) -> bool: + """ + Assumes self_attn and mlp are the only quantization targets + in model layers of the state_dict. + + :param key: The key of the state_dict + :return: True if the key is a quantization target, False otherwise + """ + return "model.layers" in key and ("self_attn" in key or "mlp" in key) + + +@_log_call +def transform_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Transforms the state_dict keys to match with exllama format + + The renames include: + - weight_fake_quant.scale -> scales + - weight_fake_quant.zero_point -> qzeros + - weight -> qweight + + Note: does not transforms the actual tensor values + + :pre-condition: The state_dict should be for a quantized model + :pre-condition: Targets only the weights of the self_attn and mlp nodes + :param state_dict: The quantized state_dict to be transformed + :return: The transformed state_dict + """ + # mapping of the old names to the new names + name_map: Dict[str, str] = { + ".weight_fake_quant.scale": ".scales", + ".weight_fake_quant.zero_point": ".qzeros", + ".weight": ".qweight", + } + + new_state_dict: Dict[str, Tensor] = {} + for key, tensor in state_dict.items(): + if is_quantization_target(key) and any( + key.endswith(target_suffix := suffix) for suffix in name_map + ): + updated_key = key.replace(target_suffix, name_map[target_suffix]) + new_state_dict[updated_key] = tensor + else: + new_state_dict[key] = tensor + return new_state_dict + + +def pack(weight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor) -> Tensor: + """ + Quantize the weight tensor using the scales, zeros, and g_idx tensors + into 4 bit integers, and packs a group of 8 of them into a single 32 bit integer. + + Adapted from: + https://github.com/AutoGPTQ/AutoGPTQ/blob/ea4a99778f90b60c9b5177d7487af1b4ca87744f/auto_gptq/nn_modules/qlinear/qlinear_exllama.py#L118 + + :param weight: The weight tensor to be quantized of shape [x, 8y] + :param scales: The scales tensor + :param zeros: The zero points tensor + :param g_idx: The group index tensor + :return: The quantized weight tensor of int32 dtype and shape [x, y] + """ + g_idx = g_idx.clone() + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + scales = scales.clone().half() + bits = 4 + + intweight = [] + infeatures = weight.shape[1] + for idx in range(infeatures): + intweight.append( + torch.round( + (weight[:, idx] + scale_zeros[g_idx[idx]]) / scales[g_idx[idx]] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(numpy.uint32) + + i = 0 + row = 0 + qweight = numpy.zeros( + (intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=numpy.uint32 + ) + while row < qweight.shape[0]: + if bits in [4]: + for j in range(i, i + (32 // bits)): + qweight[row] |= intweight[j] << (bits * (j - i)) + i += 32 // bits + row += 1 + else: + raise NotImplementedError("Only 4 bits are supported.") + + qweight = qweight.astype(numpy.int32) + qweight = torch.from_numpy(qweight) + return qweight + + +@_log_call +def add_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: + + new_dict: Dict[str, Tensor] = {} + + for key, tensor in state_dict.items(): + if is_quantization_target(key) and key.endswith(".qweight"): + # add bias and g_idx tensors + bias_key = key.replace(".qweight", ".bias") + g_idx_key = key.replace(".qweight", ".g_idx") + + # bias tensor + bias_tensor = torch.zeros(tensor.shape[0], dtype=torch.float16) + new_dict[bias_key] = bias_tensor + + # g_idx tensor of shape [num_channels] dtype int32 filled + # with zeros + g_idx_tensor = torch.zeros(tensor.shape[1], dtype=torch.int32) + new_dict[g_idx_key] = g_idx_tensor + + # copy the original tensor, (qweight is also copied in this step) + new_dict[key] = tensor + return new_dict + + +@_log_call +def transform_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: + + new_dict: Dict[str, Tensor] = {} + + # auxillary dict to store transformed weights + weights_dict: Dict[str, Tensor] = {} + + # quantize qweights before scales, and qzeros + # because the ordering is not guaranteed + # in our implementation + for key, tensor in state_dict.items(): + if is_quantization_target(key) and key.endswith(".qweight"): + # quantize the weight tensor + qweight = pack( + weight=tensor, + scales=state_dict[key.replace("qweight", "scales")], + zeros=state_dict[key.replace("qweight", "qzeros")], + g_idx=state_dict[key.replace("qweight", "g_idx")], + ) + assert qweight.dtype == torch.int32 + weights_dict[key] = qweight + + # transform scales and zero points + for key, tensor in state_dict.items(): + if is_quantization_target(key) and key.endswith(".scales"): + # scales [x] should be reshaped to [1, x] + # and converted to fp16 + scales = tensor.reshape(1, -1).to(torch.float16) + new_dict[key] = scales + elif is_quantization_target(key) and key.endswith(".qzeros"): + # zero points [8x] should be reshaped to [1, x] + # of type int32 and filled with zeros (symmetric quantization) + zeros = torch.zeros(tensor.shape[0] // 8, dtype=torch.int32) + new_dict[key] = zeros.reshape(1, -1) + else: + new_dict[key] = tensor + + # overwrite old weights with the new quantized weights + new_dict.update(weights_dict) + + # auxillary weights_dict not needed anymore + del weights_dict + + return new_dict + + +@_log_call +def remove_unwanted_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: + """ + Remove unwanted tensors from the state_dict that are not necessary for inference. + These tensors include: + - eps + - min_val + - max_val + - fake_quant_enabled + - observer_enabled + + """ + to_delete = ["eps", "min_val", "max_val", "fake_quant_enabled", "observer_enabled"] + keys = list(state_dict.keys()) + for key in keys: + if any(key.endswith(suffix) for suffix in to_delete): + del state_dict[key] + return state_dict + + +def check_dicts(actual, expected): + assert len(actual) == len( + expected + ), "The number of tensors in the actual and expected state dicts do not match" + + for key, value in actual.items(): + assert ( + key in expected + ), f"The key {key} is not present in the expected state dict" + assert ( + value.shape == expected[key].shape + ), f"The shape of the tensor {key} in the actual state dict does not match the shape of the tensor in the expected state dict, expected {expected[key].shape} but got {value.shape}" + assert ( + value.dtype == expected[key].dtype + ), f"The dtype of the tensor {key} in the actual state dict does not match the dtype of the tensor in the expected state dict, expected {expected[key].dtype} but got {value.dtype}" diff --git a/src/sparseml/transformers/utils/vllm_export_helpers.py b/src/sparseml/transformers/utils/vllm_export_helpers.py new file mode 100644 index 00000000000..ddf24ebebc3 --- /dev/null +++ b/src/sparseml/transformers/utils/vllm_export_helpers.py @@ -0,0 +1,170 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +General utilities for exporting models to different formats using safe tensors. +""" +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Literal, Optional, Union + +from torch import Tensor +from transformers import PreTrainedModel, PreTrainedTokenizerBase + +from sparseml.transformers.utils.transformations import ( + add_tensors, + remove_unwanted_tensors, + transform_names, + transform_tensors, +) +from sparseml.utils import get_unique_dir_name + + +__all__ = [ + "export_safetensors", + "SUPPORTED_FORMAT_TYPES", +] + +SUPPORTED_FORMAT_TYPES = Literal["exllama", "marlin"] +_LOGGER = logging.getLogger(__name__) + + +def export_safetensors( + model: PreTrainedModel, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + format: SUPPORTED_FORMAT_TYPES = "exllama", + save_dir: Union[str, Path, None] = None, +): + """ + A utility function to export a model to safetensor format. + Calls the appropriate state dict translation function based on the format + and saves the translated state dict to the specified directory. + If the directory is defaults to cwd/exported_model. If the directory already exists, + a new directory is created with a unique name. + + :param model: The loaded model to be exported. + :param tokenizer: The tokenizer associated with the model. + :param format: The format to which the model should be exported. + Default is "exllama". + :param save_dir: The directory where the model should be saved. + """ + # original state dict + state_dict: Dict[Any, Any] = model.state_dict() + + # translation + _LOGGER.info(f"Adding {format} quantization info to config") + model.config.quantization_config = QuantizationConfig() + + _LOGGER.info(f"Translating state dict to {format} format.") + translated_state_dict: Dict[Any, Any] = translate_state_dict( + state_dict=state_dict, format=format + ) + + if save_dir is None: + save_dir = Path.cwd() / f"{format}_model" + + save_dir: str = get_unique_dir_name(dir_name=save_dir) + + # saving + save_checkpoint( + model=model, + tokenizer=tokenizer, + state_dict=translated_state_dict, + save_dir=save_dir, + ) + + +def save_checkpoint( + model: PreTrainedModel, + state_dict: Dict[Any, Any], + save_dir: str, + tokenizer: Optional[PreTrainedTokenizerBase] = None, +): + model.save_pretrained( + save_directory=save_dir, state_dict=state_dict, safe_serialization=True + ) + _LOGGER.info(f"Model and config saved to {save_dir}") + + if tokenizer: + tokenizer.save_pretrained(save_directory=save_dir) + _LOGGER.info(f"tokenizer saved to {save_dir}") + + +def translate_state_dict( + state_dict: Dict[Any, Any], format: SUPPORTED_FORMAT_TYPES +) -> Dict[Any, Any]: + """ + A utility function to translate the state dict to the specified format. + + :param state_dict: The state dict to be translated. + :param format: The format to which the state dict should be translated. + """ + if format == "exllama": + return _translate_state_dict_exllama(state_dict=state_dict) + + +def _translate_state_dict_exllama(state_dict: Dict[str, Any]) -> Dict[Any, Any]: + """ + Translate the state dict to the Exllama format. + + Changes made to quantized params in the passed state_dict: + - weight tensor renamed to qweight, and the corresponding tensor + value of shape [x, 8y] will be repacked to [x, y] + - scale tensor renamed to scales, and the corresponding tensor + value of shape [8x] will be reshaped to [1, 8x] and + then repacked to [1, x] + - zero_point tensor renamed to qzeros, and the corresponding tensor + value of shape [x] will be reshaped to [1, x] + - A g_idx tensor of shape [num_channels] will be added to the + state_dict, this tensor will be filled with zeros + - All fake quantization parameters will be removed from the state_dict + + + + + :param state_dict: The model state dict to be translated. + :return: The translated state dict compatible with Exllama. + """ + + transformations = ( + transform_names, + add_tensors, + transform_tensors, + remove_unwanted_tensors, + ) + for transformation in transformations: + state_dict_copy: Dict[str, Tensor] = transformation(state_dict=state_dict) + + return state_dict_copy + + +@dataclass +class QuantizationConfig: + bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]}) + group_size: int = field(default=-1) + damp_percent: float = field(default=0.01) + desc_act: bool = field(default=False) + sym: bool = field(default=True) + is_marlin_format: bool = field(default=False) + + def to_dict(self): + return { + "bits": self.bits, + "group_size": self.group_size, + "desc_act": self.desc_act, + "sym": self.sym, + "is_marlin_format": self.is_marlin_format, + "quant_method": "gptq", + } diff --git a/src/sparseml/utils/helpers.py b/src/sparseml/utils/helpers.py index 6c1d4f3ad6c..0c5b4d8330e 100644 --- a/src/sparseml/utils/helpers.py +++ b/src/sparseml/utils/helpers.py @@ -74,6 +74,7 @@ "parse_kwarg_tuples", "download_zoo_training_dir", "is_package_available", + "get_unique_dir_name", ] @@ -974,3 +975,23 @@ def is_package_available( return package_exists, package_version else: return package_exists + + +def get_unique_dir_name(dir_name: Union[str, Path]) -> str: + """ + A utility function to get a unique directory name by appending + a number to the directory name if the directory already exists + (Note: the function does not create the directory, it only + returns the unique directory name) + + :param dir_name: The directory name to get a unique name for + :return: The unique directory name + """ + dir_name: str = str(dir_name) + counter: str = 1 + new_dir_name: str = dir_name + + while Path(new_dir_name).exists(): + new_dir_name = f"{dir_name}_{counter}" + counter += 1 + return new_dir_name