From 87798fd002eab89704ed4993f2e87f8ed44dea06 Mon Sep 17 00:00:00 2001
From: Jack Zhang
Date: Fri, 19 Jul 2024 13:07:38 -0700
Subject: [PATCH] Update quantize.py to use AO's int4 quantizer (#919)

* Use ao's int4 quantizer

* Point AO to commit hash of Jerry's fix

* When device is cuda, only run for dtype==bfloat16

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Typo

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* Use tensor subclass for int4 weight only quant

* Fix bug

* Fix

* Use both quantizer and subclass API

* Bug

* unwrap tensor subclass for aoti

* Add import

* Eval fix

* Evaluate AOTI

---------

Co-authored-by: Mengwei Liu
---
 .ci/scripts/validate.sh  |  97 +++++++++++++++++----------------
 quantization/quantize.py | 114 ++++++---------------------------------
 2 files changed, 65 insertions(+), 146 deletions(-)

diff --git a/.ci/scripts/validate.sh b/.ci/scripts/validate.sh
index 12107f8d3c..55d37d3b5f 100644
--- a/.ci/scripts/validate.sh
+++ b/.ci/scripts/validate.sh
@@ -92,13 +92,16 @@ function generate_compiled_model_output() {
             python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
             .ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"

-            echo "******************************************"
-            echo "******** INT4 group-wise quantized *******"
-            echo "******************************************"
-            python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
-            .ci/scripts/check_gibberish "$MODEL_DIR/output_eager"
-            python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
-            .ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"
+            if [[ $TARGET_DEVICE != "cuda" || "$DTYPE" == "bfloat16" ]]; then
+                # For CUDA, only bfloat16 makes sense for int4 mm kernel
+                echo "******************************************"
+                echo "******** INT4 group-wise quantized *******"
+                echo "******************************************"
+                python3 -W ignore generate.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_eager" || exit 1
+                .ci/scripts/check_gibberish "$MODEL_DIR/output_eager"
+                python3 -W ignore generate.py --dtype ${DTYPE} --compile --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_compiled" || exit 1
+                .ci/scripts/check_gibberish "$MODEL_DIR/output_compiled"
+            fi
         fi
     done
 }
@@ -180,12 +183,11 @@ function generate_aoti_model_output() {
         echo "******************************************"
         echo "******** INT4 group-wise quantized *******"
         echo "******************************************"
-        if [ "$TARGET_DEVICE" == "cuda" ]; then
-          if [ "$DTYPE" != "float16" ]; then
-            python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-            python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
-            .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
-          fi
+        if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
+            # For CUDA, only bfloat16 makes sense for int4 mm kernel
+            python3 -W ignore export.py --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
+            python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+            .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
         fi
     done
 }
@@ -225,21 +227,23 @@ function eval_model() {
             echo "perplexity checking succeeded for non-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE"
         fi;

-        echo "******************************************"
-        echo "******** INT4 group-wise quantized *******"
-        echo "******************************************"
+        if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
+            echo "******************************************"
+            echo "******** INT4 group-wise quantized *******"
+            echo "******************************************"
-        export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
-        python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" > "$MODEL_DIR/eval" || exit 1
-        cat "$MODEL_DIR/eval"
-        export REF_PERPLEXITY=100000
-        export PERPLEXITY=cat "$MODEL_DIR/eval" | tail -n 1 log | awk -F '[, ]' '{print $4}'
-        # == 1 meaning the check succeeded
-        if [ "$(echo "$PERPLEXITY >= $REF_PERPLEXITY" | bc)" == 1]; then
-            echo "perplexity checking failed for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
-        else
-            echo "perplexity checking succeeded for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
-        fi;
+            export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
+            python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" > "$MODEL_DIR/eval" || exit 1
+            cat "$MODEL_DIR/eval"
+            export REF_PERPLEXITY=100000
+            export PERPLEXITY=$(tail -n 1 "$MODEL_DIR/eval" | awk -F '[, ]' '{print $4}')
+            # bc prints 1 when PERPLEXITY >= REF_PERPLEXITY, i.e. the perplexity check failed
+            if [ "$(echo "$PERPLEXITY >= $REF_PERPLEXITY" | bc)" == 1 ]; then
+                echo "perplexity checking failed for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
+            else
+                echo "perplexity checking succeeded for int4-quantized model $MODEL_NAME with $DTYPE $TARGET_DEVICE $QUANT_OPTIONS"
+            fi;
+        fi
     done
 }
@@ -260,32 +264,31 @@ function eval_model_sanity_check() {
         python -W ignore eval.py --compile --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
         cat "$MODEL_DIR/eval"

-        echo "******************************************"
-        echo "******** INT4 group-wise quantized *******"
-        echo "******************************************"
+        if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
+            echo "******************************************"
+            echo "******** INT4 group-wise quantized *******"
+            echo "******************************************"
-        export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
-        python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
-        cat "$MODEL_DIR/eval"
+            export QUANT_OPTIONS='{"linear:int4" : {"groupsize": 32}}'
+            python -W ignore eval.py --compile --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval" || exit 1
+            cat "$MODEL_DIR/eval"
-        echo "**************************************************"
-        echo "******** INT4 group-wise quantized (eager) *******"
-        echo "**************************************************"
+            echo "**************************************************"
+            echo "******** INT4 group-wise quantized (eager) *******"
+            echo "**************************************************"
-        if [ "$TARGET_DEVICE" == "cuda" ] && [ "$DTYPE" != "float16" ]; then
             python -W ignore eval.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/eval_eager" || exit 1
             cat "$MODEL_DIR/eval_eager"
-        fi;
-
-        # there is some issues with AOTI cpu and cuda, need to fix and enable the test for cuda as well
-        echo "*************************************************"
-        echo "******** INT4 group-wise quantized (AOTI) *******"
-        echo "*************************************************"
-        if [ "$DTYPE" != "float16" ]; then
-            python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-            python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
-            cat "$MODEL_DIR/output_eval_aoti"
+            # There are some issues with AOTI on CPU and CUDA; need to fix them and enable the test for CUDA as well
+            echo "*************************************************"
+            echo "******** INT4 group-wise quantized (AOTI) *******"
+            echo "*************************************************"
+            if [ "$DTYPE" != "float16" ]; then
+                python3 -W ignore export.py --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
+                python3 -W ignore eval.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
+                cat "$MODEL_DIR/output_eval_aoti"
+            fi;
         fi;
     done
diff --git a/quantization/quantize.py b/quantization/quantize.py
index 4b6ec25cc7..c72ef2aa1d 100644
--- a/quantization/quantize.py
+++ b/quantization/quantize.py
@@ -31,23 +31,19 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from build.utils import (
-    find_multiple,
-    get_device_str,
-    get_precision,
-    name_to_dtype,
-    state_dict_device,
-)
+from build.utils import get_device_str, get_precision, name_to_dtype, state_dict_device

-from quantization.qops import (
-    LinearInt4 as WeightOnlyInt4Linear,
-    LinearInt8 as WeightOnlyInt8Linear,
-    QuantizedEmbedding,
-)
+from quantization.qops import LinearInt8 as WeightOnlyInt8Linear, QuantizedEmbedding

 # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group'
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa
-from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+from torchao.quantization.quant_api import (
+    int4_weight_only,
+    Int4WeightOnlyQuantizer,
+    Int8DynActInt4WeightQuantizer,
+    quantize_,
+)
+from torchao.utils import unwrap_tensor_subclass


 #########################################################################
@@ -75,6 +71,11 @@ def quantize_model(model: nn.Module, device, quantize_options, tokenizer=None):
         ):
             raise RuntimeError(f"unknown quantizer {quantizer} specified")
         if quantizer in ao_quantizer_class_dict:
+            # Use tensor subclass API for int4 weight only.
+            if device == "cuda" and quantizer == "linear:int4":
+                quantize_(model, int4_weight_only(q_kwargs["groupsize"]))
+                unwrap_tensor_subclass(model)
+                continue
             # Use dtype precision specified in user config, else fallback on global precision.
             if "precision" in quantize_options:
                 dtype = quantize_options["precision"].get("dtype", str(get_precision()))
@@ -556,91 +557,6 @@ def quantized_model(self) -> nn.Module:
         return self.quantize(self.model_)


-#########################################################################
-#####  weight only int4 per channel groupwise quantized code  ######
-
-
-class WeightOnlyInt4QuantHandler(QuantHandler):
-    def __init__(
-        self,
-        model: nn.Module,
-        device=None,
-        *,
-        tokenizer=None,
-        groupsize=128,
-        inner_k_tiles=8,
-        padding_allowed=True,
-    ):
-        self.model_ = model
-        self.device = device
-        self.groupsize = groupsize
-        self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
-        assert groupsize in [32, 64, 128, 256]
-        assert inner_k_tiles in [2, 4, 8]
-
-    @torch.no_grad()
-    def quantize(self, module):
-        for name, child in module.named_children():
-            # print(f"name: {name}")
-            if isinstance(child, torch.nn.Linear):
-                assert not child.bias
-                out_features = child.out_features
-                in_features = child.in_features
-                assert out_features % 8 == 0, "require out_features % 8 == 0"
-                # print(f"linear: {fqn}, in={in_features}, out={out_features}")
-
-                weight = child.weight.data
-                if not WeightOnlyInt4Linear._check_k(
-                    k=in_features,
-                    groupsize=self.groupsize,
-                    inner_k_tiles=self.inner_k_tiles,
-                ):
-                    if self.padding_allowed:
-                        # print(
-                        #     f"warning: {name} is padded to satisfy in_features % 1024 == 0"
-                        # )
-                        padded_in_features = find_multiple(in_features, 1024)
-                        weight = F.pad(
-                            weight, pad=(0, padded_in_features - in_features)
-                        )
-                    else:
-                        print(
-                            f"warning: {name} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
-                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
-                        )
-                        continue
-                weight_int4pack, scales_and_zeros = (
-                    WeightOnlyInt4Linear._prepare_weight_and_scales_and_zeros(
-                        weight.to(torch.float), self.groupsize, self.inner_k_tiles
-                    )
-                )
-                weight_int4pack = weight_int4pack.to(device=self.device)
-                scales_and_zeros = scales_and_zeros.to(device=self.device)
-
-                setattr(
-                    module,
-                    name,
-                    WeightOnlyInt4Linear(
-                        child.in_features,
-                        child.out_features,
-                        bias=False,
-                        device=self.device,
-                        groupsize=self.groupsize,
-                        inner_k_tiles=self.inner_k_tiles,
-                        weight=weight_int4pack,
-                        scales_and_zeros=scales_and_zeros,
-                    ),
-                )
-            else:
-                self.quantize(child)
-
-        return module
-
-    def quantized_model(self) -> nn.Module:
-        return self.quantize(self.model_)
-
-
 ##########################################################################
 ###                     quantization dictionary                       ###
@@ -650,11 +566,11 @@ def quantized_model(self) -> nn.Module:
 quantizer_class_dict = {
     "embedding": EmbeddingOnlyQuantHandler,
     "linear:int8": WeightOnlyInt8QuantHandler,
-    "linear:int4": WeightOnlyInt4QuantHandler,
     "precision": PrecisionHandler,
     "executor": ExecutorHandler,
 }

 ao_quantizer_class_dict = {
+    "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }