-
Notifications
You must be signed in to change notification settings - Fork 149
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
121d7fe
commit 29f83bb
Showing
3 changed files
with
444 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# flake8: noqa #F821,#E501 | ||
|
||
import functools | ||
import logging | ||
from typing import Dict | ||
|
||
import numpy | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
__all__ = [ | ||
"transform_names", | ||
"add_tensors", | ||
"transform_tensors", | ||
"remove_unwanted_tensors", | ||
"is_quantization_target", | ||
] | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
def _log_call(func): | ||
@functools.wraps(func) | ||
def wrapper(*args, **kwargs): | ||
_LOGGER.info("Applying transformation: %s", func.__name__.upper()) | ||
return_value = func(*args, **kwargs) | ||
_LOGGER.info("Transformation: %s complete", func.__name__.upper()) | ||
return return_value | ||
|
||
return wrapper | ||
|
||
|
||
def is_quantization_target(key: str) -> bool: | ||
""" | ||
Assumes self_attn and mlp are the only quantization targets | ||
in model layers of the state_dict. | ||
:param key: The key of the state_dict | ||
:return: True if the key is a quantization target, False otherwise | ||
""" | ||
return "model.layers" in key and ("self_attn" in key or "mlp" in key) | ||
|
||
|
||
@_log_call | ||
def transform_names(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||
""" | ||
Transforms the state_dict keys to match with exllama format | ||
The renames include: | ||
- weight_fake_quant.scale -> scales | ||
- weight_fake_quant.zero_point -> qzeros | ||
- weight -> qweight | ||
Note: does not transforms the actual tensor values | ||
:pre-condition: The state_dict should be for a quantized model | ||
:pre-condition: Targets only the weights of the self_attn and mlp nodes | ||
:param state_dict: The quantized state_dict to be transformed | ||
:return: The transformed state_dict | ||
""" | ||
# mapping of the old names to the new names | ||
name_map: Dict[str, str] = { | ||
".weight_fake_quant.scale": ".scales", | ||
".weight_fake_quant.zero_point": ".qzeros", | ||
".weight": ".qweight", | ||
} | ||
|
||
new_state_dict: Dict[str, Tensor] = {} | ||
for key, tensor in state_dict.items(): | ||
if is_quantization_target(key) and any( | ||
key.endswith(target_suffix := suffix) for suffix in name_map | ||
): | ||
updated_key = key.replace(target_suffix, name_map[target_suffix]) | ||
new_state_dict[updated_key] = tensor | ||
else: | ||
new_state_dict[key] = tensor | ||
return new_state_dict | ||
|
||
|
||
def pack(weight: Tensor, scales: Tensor, zeros: Tensor, g_idx: Tensor) -> Tensor: | ||
""" | ||
Quantize the weight tensor using the scales, zeros, and g_idx tensors | ||
into 4 bit integers, and packs a group of 8 of them into a single 32 bit integer. | ||
Adapted from: | ||
https://github.com/AutoGPTQ/AutoGPTQ/blob/ea4a99778f90b60c9b5177d7487af1b4ca87744f/auto_gptq/nn_modules/qlinear/qlinear_exllama.py#L118 | ||
:param weight: The weight tensor to be quantized of shape [x, 8y] | ||
:param scales: The scales tensor | ||
:param zeros: The zero points tensor | ||
:param g_idx: The group index tensor | ||
:return: The quantized weight tensor of int32 dtype and shape [x, y] | ||
""" | ||
g_idx = g_idx.clone() | ||
|
||
scales = scales.t().contiguous() | ||
zeros = zeros.t().contiguous() | ||
scale_zeros = zeros * scales | ||
scales = scales.clone().half() | ||
bits = 4 | ||
|
||
intweight = [] | ||
infeatures = weight.shape[1] | ||
for idx in range(infeatures): | ||
intweight.append( | ||
torch.round( | ||
(weight[:, idx] + scale_zeros[g_idx[idx]]) / scales[g_idx[idx]] | ||
).to(torch.int)[:, None] | ||
) | ||
intweight = torch.cat(intweight, dim=1) | ||
intweight = intweight.t().contiguous() | ||
intweight = intweight.numpy().astype(numpy.uint32) | ||
|
||
i = 0 | ||
row = 0 | ||
qweight = numpy.zeros( | ||
(intweight.shape[0] // 32 * bits, intweight.shape[1]), dtype=numpy.uint32 | ||
) | ||
while row < qweight.shape[0]: | ||
if bits in [4]: | ||
for j in range(i, i + (32 // bits)): | ||
qweight[row] |= intweight[j] << (bits * (j - i)) | ||
i += 32 // bits | ||
row += 1 | ||
else: | ||
raise NotImplementedError("Only 4 bits are supported.") | ||
|
||
qweight = qweight.astype(numpy.int32) | ||
qweight = torch.from_numpy(qweight) | ||
return qweight | ||
|
||
|
||
@_log_call | ||
def add_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||
|
||
new_dict: Dict[str, Tensor] = {} | ||
|
||
for key, tensor in state_dict.items(): | ||
if is_quantization_target(key) and key.endswith(".qweight"): | ||
# add bias and g_idx tensors | ||
bias_key = key.replace(".qweight", ".bias") | ||
g_idx_key = key.replace(".qweight", ".g_idx") | ||
|
||
# bias tensor | ||
bias_tensor = torch.zeros(tensor.shape[0], dtype=torch.float16) | ||
new_dict[bias_key] = bias_tensor | ||
|
||
# g_idx tensor of shape [num_channels] dtype int32 filled | ||
# with zeros | ||
g_idx_tensor = torch.zeros(tensor.shape[1], dtype=torch.int32) | ||
new_dict[g_idx_key] = g_idx_tensor | ||
|
||
# copy the original tensor, (qweight is also copied in this step) | ||
new_dict[key] = tensor | ||
return new_dict | ||
|
||
|
||
@_log_call | ||
def transform_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||
|
||
new_dict: Dict[str, Tensor] = {} | ||
|
||
# auxillary dict to store transformed weights | ||
weights_dict: Dict[str, Tensor] = {} | ||
|
||
# quantize qweights before scales, and qzeros | ||
# because the ordering is not guaranteed | ||
# in our implementation | ||
for key, tensor in state_dict.items(): | ||
if is_quantization_target(key) and key.endswith(".qweight"): | ||
# quantize the weight tensor | ||
qweight = pack( | ||
weight=tensor, | ||
scales=state_dict[key.replace("qweight", "scales")], | ||
zeros=state_dict[key.replace("qweight", "qzeros")], | ||
g_idx=state_dict[key.replace("qweight", "g_idx")], | ||
) | ||
assert qweight.dtype == torch.int32 | ||
weights_dict[key] = qweight | ||
|
||
# transform scales and zero points | ||
for key, tensor in state_dict.items(): | ||
if is_quantization_target(key) and key.endswith(".scales"): | ||
# scales [x] should be reshaped to [1, x] | ||
# and converted to fp16 | ||
scales = tensor.reshape(1, -1).to(torch.float16) | ||
new_dict[key] = scales | ||
elif is_quantization_target(key) and key.endswith(".qzeros"): | ||
# zero points [8x] should be reshaped to [1, x] | ||
# of type int32 and filled with zeros (symmetric quantization) | ||
zeros = torch.zeros(tensor.shape[0] // 8, dtype=torch.int32) | ||
new_dict[key] = zeros.reshape(1, -1) | ||
else: | ||
new_dict[key] = tensor | ||
|
||
# overwrite old weights with the new quantized weights | ||
new_dict.update(weights_dict) | ||
|
||
# auxillary weights_dict not needed anymore | ||
del weights_dict | ||
|
||
return new_dict | ||
|
||
|
||
@_log_call | ||
def remove_unwanted_tensors(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: | ||
""" | ||
Remove unwanted tensors from the state_dict that are not necessary for inference. | ||
These tensors include: | ||
- eps | ||
- min_val | ||
- max_val | ||
- fake_quant_enabled | ||
- observer_enabled | ||
""" | ||
to_delete = ["eps", "min_val", "max_val", "fake_quant_enabled", "observer_enabled"] | ||
keys = list(state_dict.keys()) | ||
for key in keys: | ||
if any(key.endswith(suffix) for suffix in to_delete): | ||
del state_dict[key] | ||
return state_dict | ||
|
||
|
||
def check_dicts(actual, expected): | ||
assert len(actual) == len( | ||
expected | ||
), "The number of tensors in the actual and expected state dicts do not match" | ||
|
||
for key, value in actual.items(): | ||
assert ( | ||
key in expected | ||
), f"The key {key} is not present in the expected state dict" | ||
assert ( | ||
value.shape == expected[key].shape | ||
), f"The shape of the tensor {key} in the actual state dict does not match the shape of the tensor in the expected state dict, expected {expected[key].shape} but got {value.shape}" | ||
assert ( | ||
value.dtype == expected[key].dtype | ||
), f"The dtype of the tensor {key} in the actual state dict does not match the dtype of the tensor in the expected state dict, expected {expected[key].dtype} but got {value.dtype}" |
Oops, something went wrong.