From c44566dab86daeaaa3b2e5bf0f52153f10a4f55a Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 13 Jan 2025 23:19:06 +0100
Subject: [PATCH] Feat (quantization): torch_function based quantization
 (#1147)

---
 src/brevitas/graph/quantize.py          |  38 +++
 src/brevitas_examples/llm/README.md     |   7 +-
 .../llm/config/default_template.yml     |   1 +
 src/brevitas_examples/llm/main.py       | 240 ++++++++++--------
 4 files changed, 179 insertions(+), 107 deletions(-)

diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py
index ee035b9bd..7724e8f9d 100644
--- a/src/brevitas/graph/quantize.py
+++ b/src/brevitas/graph/quantize.py
@@ -1,9 +1,14 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+from typing import Dict
+
+from packaging import version
+import torch
 from torch import nn
 
 from brevitas import config
+from brevitas import torch_version
 from brevitas.core.scaling.standalone import ConstScaling
 from brevitas.core.scaling.standalone import ParameterScaling
 from brevitas.fx.brevitas_tracer import symbolic_trace
@@ -33,6 +38,39 @@
 from brevitas.quant import Uint8ActPerTensorFloatMaxInit
 from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
 
+if torch_version >= version.parse('1.12'):
+    from torch.overrides import TorchFunctionMode
+
+    class functional_quantization_mode(TorchFunctionMode):
+
+        def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True):
+            super().__init__()
+            self.quant_map = quant_map
+            self.model = model
+            self.enabled = enabled
+            # Attach one stateless quant module per intercepted functional,
+            # keyed by the string representation of the function.
+            for stateless_function, stateless_module in quant_map.items():
+                if not hasattr(model, str(stateless_function)):
+                    setattr(model, str(stateless_function), stateless_module())
+
+        def __torch_function__(self, func, types, args=(), kwargs=None):
+            if kwargs is None:
+                kwargs = dict()
+
+            # Dispatch to the quantized module when one is registered for this functional.
+            if hasattr(self.model, str(func)) and self.enabled:
+                module = getattr(self.model, str(func))
+                out = module(*args, **kwargs)
+            else:
+                out = func(*args, **kwargs)
+
+            return out
+else:
+    # TorchFunctionMode is only available from torch 1.12 onwards; this placeholder
+    # is not callable, so any use of the mode on older versions fails fast.
+    functional_quantization_mode = object()
+
 COMPUTE_LAYER_MAP = {
     nn.AvgPool2d: None,

diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index ee0e3df77..c1c9d9919 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -47,8 +47,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
                [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
                [--no-quantize] [--no-float16]
                [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
-               [--replace-mha] [--weight-equalization]
-               [--rotation {fx,layerwise,fused_no_fx}]
+               [--functional-sdpa-quant] [--replace-mha]
+               [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
                [--rotation-mode {had,ort}] [--rotation-orphan-sink]
                [--rotation-sdpa-regions]
                [--act-equalization {None,layerwise,fx}]
@@ -171,6 +171,9 @@ options:
                         fp16 quantization.
  --quant-sdpa          Quantize `F.scaled_dot_product_attention` (default:
                        False)
+  --functional-sdpa-quant
+                        Quantize `F.scaled_dot_product_attention` with a
+                        stateless module and torch_function (default: False)
   --replace-mha         Replace HuggingFace Attention with a quantizable
                         version
   --weight-equalization
diff --git a/src/brevitas_examples/llm/config/default_template.yml b/src/brevitas_examples/llm/config/default_template.yml
index f686a7b36..b7d1ab864 100644
--- a/src/brevitas_examples/llm/config/default_template.yml
+++ b/src/brevitas_examples/llm/config/default_template.yml
@@ -17,6 +17,7 @@ few_shot_tasks:
 - winogrande
 - piqa
 few_shot_zeroshot: false
+functional_sdpa_quant: false
 fuse_sequences: false
 gpfq: false
 gptq: false
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index cd997df11..4f03ba087 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import argparse
+from contextlib import nullcontext
 from copy import deepcopy
 import functools
 import sys
@@ -23,8 +24,10 @@
 from brevitas.graph import load_quant_model_mode
 from brevitas.graph.equalize import GraphRotationEqualization
 from brevitas.graph.equalize import LayerwiseActivationRotation
+from brevitas.graph.quantize import functional_quantization_mode
 from brevitas.graph.quantize import layerwise_quantize
 from brevitas.graph.utils import get_module
+from brevitas.nn.quant_sdpa import ScaledDotProductAttention
 from brevitas.utils.python_utils import hooked_on_a_function
 from brevitas_examples.common.accelerate_utils.accelerate import offload_model
 from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
@@ -133,6 +136,8 @@ def model_export(model, ref_input, args):
 
 
 def validate(args):
+    if args.functional_sdpa_quant:
+        assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA quantization requires dynamic activation quantization"
     if args.rotation == 'fx':
         assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
         assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
@@ -319,7 +324,13 @@ def quantize_llm(args):
         print("Replace `F.scaled_dot_product_attention` with QuantSDPA...")
         model = replace_sdpa_with_quantizable_layers(model)
         print("Replacing done.")
-
+    elif args.functional_sdpa_quant:
+        print("Inserting SDPA quantizable module...")
+        model = offload_model(model)
+        # A single forward pass under the mode attaches the stateless quant SDPA module.
+        with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
+            model(**calibration_loader[0])
+        remove_hooks(model)
     if args.weight_equalization:
         print("Apply weight equalization...")
         # In case of float16 model, we need to offload to account for missing ops
@@ -424,118 +435,131 @@ def quantize_llm(args):
             new_funct = functools.partial(update_internal_dict, m)
             m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)
 
-    with torch.no_grad():
-        model(**calibration_loader[0])
+    # If functional SDPA quantization is enabled, create the corresponding context manager;
+    # otherwise fall back to nullcontext. The extra indentation level below is hard to avoid.
+ if args.functional_sdpa_quant: + quantization_cm = functional_quantization_mode( + model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}) + else: + quantization_cm = nullcontext() - # We restore the original behaviour of the post-forward. - for k, v in dict_hooks.items(): - k._hf_hook.post_forward = v + with quantization_cm: + with torch.no_grad(): + model(**calibration_loader[0]) - if args.act_calibration and not args.load_checkpoint: - print("Apply act calibration...") - apply_calibration(model, calibration_loader) - print("Act calibration applied.") + # We restore the original behaviour of the post-forward. + for k, v in dict_hooks.items(): + k._hf_hook.post_forward = v - if args.learned_round: - print("Applying learned round...") - if args.load_checkpoint: - iters = 1 - loader = [calibration_loader[0]] - else: - iters = args.learned_round_iters - loader = calibration_loader - remove_hooks(model) - apply_learned_round( - model, - loader, - iters=iters, - block_name_attribute=args.gpxq_block_name, - learn_scale=args.learned_round_scale, - scale_optimizer_class='sgd', - optimizer_kwargs={'lr': args.learned_round_lr}, - scale_optimizer_kwargs={ - 'lr': args.learned_round_scale_lr, 'momentum': args.learned_round_scale_momentum}, - fast_update=args.learned_round_fast_update) - print("Learned round applied.") + if args.act_calibration and not args.load_checkpoint: + print("Apply act calibration...") + apply_calibration(model, calibration_loader) + print("Act calibration applied.") - model = offload_model(model) + if args.learned_round: + print("Applying learned round...") + if args.load_checkpoint: + iters = 1 + loader = [calibration_loader[0]] + else: + iters = args.learned_round_iters + loader = calibration_loader + remove_hooks(model) + apply_learned_round( + model, + loader, + iters=iters, + block_name_attribute=args.gpxq_block_name, + learn_scale=args.learned_round_scale, + scale_optimizer_class='sgd', + optimizer_kwargs={'lr': args.learned_round_lr}, + scale_optimizer_kwargs={ + 'lr': args.learned_round_scale_lr, + 'momentum': args.learned_round_scale_momentum}, + fast_update=args.learned_round_fast_update) + print("Learned round applied.") + model = offload_model(model) - if args.load_checkpoint: + if args.load_checkpoint: + remove_hooks(model) + with load_quant_model_mode(model): + model.load_state_dict(torch.load(args.checkpoint_name, map_location='cpu')) + model = offload_model(model) + + if args.gptq and not args.load_checkpoint: + print("Applying GPTQ...") + apply_gptq( + model, + calibration_loader, + act_order=args.gpxq_act_order, + use_quant_activations=args.gpxq_use_quant_activations, + create_weight_orig=args.gpxq_create_weight_orig, + block_name=args.gpxq_block_name, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) + print("GPTQ applied.") + + if args.gpfq and not args.load_checkpoint: + print("Applying GPFQ...") + apply_gpfq( + model, + calibration_loader, + act_order=args.gpxq_act_order, + block_name=args.gpxq_block_name, + max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, + max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) + print("GPFQ applied.") + + if args.bias_corr and not args.load_checkpoint: + print("Applying bias correction...") + apply_bias_correction(model, calibration_loader) + print("Bias correction applied.") + + if args.checkpoint_name is not None: + print(f"Saving checkpoint to {args.checkpoint_name}") + 
torch.save(model.state_dict(), args.checkpoint_name) + + if args.eval and not args.no_quantize: + print("Model eval...") + with torch.no_grad(), quant_inference_mode(model): + model(**calibration_loader[0]) + quant_ppl = compute_perplexity( + model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) + print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") + + if args.few_shot_eval: + with torch.no_grad(), quant_inference_mode(model): + model(**calibration_loader[0]) + if args.few_shot_compile: + remove_hooks(model) + model.cuda() + model = torch.compile(model) + + wrapped_model = HFLM(pretrained=model) # need to wrap for LLM eval + results = evaluator.simple_evaluate( + model=wrapped_model, + model_args=None, + tasks=list(args.few_shot_tasks), + device='cuda:0', + limit=args.few_shot_limit, + num_fewshot=0 if args.few_shot_zeroshot else None, + log_samples=False, + batch_size=None, + verbosity="ERROR") + results = filter_results(results, args.few_shot_tasks) + print("Few shot eval results") + print(results) remove_hooks(model) - with load_quant_model_mode(model): - model.load_state_dict(torch.load(args.checkpoint_name, map_location='cpu')) - model = offload_model(model) - if args.gptq and not args.load_checkpoint: - print("Applying GPTQ...") - apply_gptq( - model, - calibration_loader, - act_order=args.gpxq_act_order, - use_quant_activations=args.gpxq_use_quant_activations, - create_weight_orig=args.gpxq_create_weight_orig, - block_name=args.gpxq_block_name, - max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, - max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) - print("GPTQ applied.") - - if args.gpfq and not args.load_checkpoint: - print("Applying GPFQ...") - apply_gpfq( - model, - calibration_loader, - act_order=args.gpxq_act_order, - block_name=args.gpxq_block_name, - max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width, - max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size) - print("GPFQ applied.") - - if args.bias_corr and not args.load_checkpoint: - print("Applying bias correction...") - apply_bias_correction(model, calibration_loader) - print("Bias correction applied.") - - if args.eval and not args.no_quantize: - print("Model eval...") - with torch.no_grad(), quant_inference_mode(model): - model(**calibration_loader[0]) - quant_ppl = compute_perplexity( - model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}") + if args.checkpoint_name is not None and not args.load_checkpoint: + print(f"Saving checkpoint to {args.checkpoint_name}") + torch.save(model.state_dict(), args.checkpoint_name) - if args.few_shot_eval: - with torch.no_grad(), quant_inference_mode(model): - model(**calibration_loader[0]) - if args.few_shot_compile: - remove_hooks(model) - model.cuda() - model = torch.compile(model) - - wrapped_model = HFLM(pretrained=model) # need to wrap for LLM eval - results = evaluator.simple_evaluate( - model=wrapped_model, - model_args=None, - tasks=list(args.few_shot_tasks), - device='cuda:0', - limit=args.few_shot_limit, - num_fewshot=0 if args.few_shot_zeroshot else None, - log_samples=False, - batch_size=None, - verbosity="ERROR") - results = filter_results(results, args.few_shot_tasks) - print("Few shot eval results") - print(results) - remove_hooks(model) - - if args.checkpoint_name is not None and not args.load_checkpoint: - print(f"Saving checkpoint to {args.checkpoint_name}") - torch.save(model.state_dict(), 
args.checkpoint_name)
-
-    if args.export_target:
-        print(f"Export to {args.export_target}")
-        # Currently we always export on CPU with a float32 container to avoid float16 CPU errors
-        model = model.to(dtype=torch.float32)
-        model_export(model, calibration_loader[0], args)
+        if args.export_target:
+            print(f"Export to {args.export_target}")
+            # Currently we always export on CPU with a float32 container to avoid float16 CPU errors
+            model = model.to(dtype=torch.float32)
+            model_export(model, calibration_loader[0], args)
 
     return float_ppl, quant_ppl, model
 
@@ -765,6 +789,12 @@ def parse_args(args, override_defaults={}):
         '--quant-sdpa',
         action='store_true',
         help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)')
+    parser.add_argument(
+        '--functional-sdpa-quant',
+        action='store_true',
+        help=
+        'Quantize `F.scaled_dot_product_attention` with a stateless module and torch_function (default: %(default)s)'
+    )
     parser.add_argument(
         '--replace-mha',
        action='store_true',
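
For reference, below is a minimal sketch (not part of the patch) of how the new `functional_quantization_mode` is meant to be driven, mirroring its usage in `main.py`: entering the mode attaches a stateless `ScaledDotProductAttention` module to the model and redirects every `F.scaled_dot_product_attention` call to it via `__torch_function__`. It assumes torch >= 1.12; `ToyAttention` and the tensor shapes are illustrative only, and whether the redirected call is actually quantized depends on how the Brevitas module is configured.

import torch
import torch.nn.functional as F

from brevitas.graph.quantize import functional_quantization_mode
from brevitas.nn.quant_sdpa import ScaledDotProductAttention


class ToyAttention(torch.nn.Module):
    # Hypothetical module: the functional SDPA call in its forward is intercepted.

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)


model = ToyAttention()
q = k = v = torch.randn(1, 4, 8, 16)

# The mode registers one stateless module per functional listed in the map,
# then dispatches matching calls to it for the duration of the context.
with torch.no_grad(), functional_quantization_mode(
        model, {F.scaled_dot_product_attention: ScaledDotProductAttention}):
    out = model(q, k, v)

Note that inside `__torch_function__` the mode is temporarily popped, so the redirected module can itself call into torch without re-triggering the interception.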