Feat (quantization): torch_function based quantization (#1147)
Giuseppe5 authored Jan 13, 2025
1 parent e3de228 commit c44566d
Showing 4 changed files with 173 additions and 107 deletions.
33 changes: 33 additions & 0 deletions src/brevitas/graph/quantize.py
@@ -1,9 +1,14 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Dict

from packaging import version
import torch
from torch import nn

from brevitas import config
from brevitas import torch_version
from brevitas.core.scaling.standalone import ConstScaling
from brevitas.core.scaling.standalone import ParameterScaling
from brevitas.fx.brevitas_tracer import symbolic_trace
@@ -33,6 +38,34 @@
from brevitas.quant import Uint8ActPerTensorFloatMaxInit
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat

if torch_version >= version.parse('1.12'):
    from torch.overrides import TorchFunctionMode

    class functional_quantization_mode(TorchFunctionMode):

        def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True):
            super().__init__()
            self.quant_map = quant_map
            self.model = model
            self.enabled = enabled
            # Attach one instance of each stateless module to the model, keyed by the
            # string form of the functional it replaces.
            for stateless_function, stateless_module in quant_map.items():
                if not hasattr(model, str(stateless_function)):
                    setattr(model, str(stateless_function), stateless_module())

        def __torch_function__(self, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = dict()

            # Dispatch intercepted functionals to their stateful counterpart on the model.
            if hasattr(self.model, str(func)) and self.enabled:
                module = getattr(self.model, str(func))
                out = module(*args, **kwargs)
            else:
                out = func(*args, **kwargs)

            return out
else:
    functional_quantization_mode = object()

COMPUTE_LAYER_MAP = {
    nn.AvgPool2d:
        None,
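The new mode needs no graph tracing or module rewriting: any functional listed in quant_map is intercepted at call time and rerouted to a module instance attached to the model. A minimal usage sketch, assuming torch >= 1.12; LoggingSDPA is a hypothetical stand-in for a stateless quantizable module and is not part of this commit:

import torch
import torch.nn.functional as F

from brevitas.graph.quantize import functional_quantization_mode

class LoggingSDPA(torch.nn.Module):
    # Hypothetical stand-in for e.g. brevitas.nn.quant_sdpa.ScaledDotProductAttention.
    def forward(self, *args, **kwargs):
        print("intercepted F.scaled_dot_product_attention")
        return F.scaled_dot_product_attention(*args, **kwargs)

model = torch.nn.Linear(8, 8)  # any module that should own the stateless submodule
q = k = v = torch.randn(1, 2, 4, 8)
with functional_quantization_mode(model, {F.scaled_dot_product_attention: LoggingSDPA}):
    out = F.scaled_dot_product_attention(q, k, v)  # dispatched through LoggingSDPA

Note that a TorchFunctionMode is disabled while its own __torch_function__ runs, so the real F.scaled_dot_product_attention call inside LoggingSDPA does not recurse. The enabled flag gates dispatch, so constructing the mode with enabled=False passes every call through unchanged, and on torch < 1.12 the name falls back to a bare object() so that imports keep working while the feature is unavailable.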
7 changes: 5 additions & 2 deletions src/brevitas_examples/llm/README.md
@@ -47,8 +47,8 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
[--no-quantize] [--no-float16]
[--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
[--functional-sdpa-quant] [--replace-mha]
[--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
[--rotation-mode {had,ort}] [--rotation-orphan-sink]
[--rotation-sdpa-regions]
[--act-equalization {None,layerwise,fx}]
@@ -171,6 +171,9 @@ options:
fp16 quantization.
--quant-sdpa Quantize `F.scaled_dot_product_attention` (default:
False)
--functional-sdpa-quant
Quantize `F.scaled_dot_product_attention` with a
stateless module and torch_function (default: False)
--replace-mha Replace HuggingFace Attention with a quantizable
version
--weight-equalization
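A hypothetical invocation combining the new flag with dynamic input quantization (the model name is a placeholder; --input-bit-width and --input-scale-type are assumed to follow the flag spellings used elsewhere by this script):

python main.py --model meta-llama/Llama-2-7b-hf --input-bit-width 8 --input-scale-type dynamic --functional-sdpa-quant

validate() in main.py rejects the flag unless activation quantization is dynamic or disabled, so static input scales should be avoided here.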
1 change: 1 addition & 0 deletions src/brevitas_examples/llm/config/default_template.yml
@@ -17,6 +17,7 @@ few_shot_tasks:
- winogrande
- piqa
few_shot_zeroshot: false
functional_sdpa_quant: false
fuse_sequences: false
gpfq: false
gptq: false
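Config-driven runs can presumably flip the same key in a YAML file passed via --config; a minimal sketch (file name hypothetical, input_scale_type assumed to mirror the CLI flag):

# my_overrides.yml (hypothetical): enable torch_function based SDPA quantization
functional_sdpa_quant: true
input_scale_type: dynamic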
239 changes: 134 additions & 105 deletions src/brevitas_examples/llm/main.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import argparse
from contextlib import nullcontext
from copy import deepcopy
import functools
import sys
@@ -23,8 +24,10 @@
from brevitas.graph import load_quant_model_mode
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
from brevitas.graph.quantize import functional_quantization_mode
from brevitas.graph.quantize import layerwise_quantize
from brevitas.graph.utils import get_module
from brevitas.nn.quant_sdpa import ScaledDotProductAttention
from brevitas.utils.python_utils import hooked_on_a_function
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
@@ -133,6 +136,8 @@ def model_export(model, ref_input, args):


def validate(args):
    if args.functional_sdpa_quant:
        assert args.input_scale_type == 'dynamic' or args.input_bit_width is None, "Functional SDPA Quant requires dynamic activation quantization"
    if args.rotation == 'fx':
        assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
        assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ required)'
@@ -319,7 +324,12 @@ def quantize_llm(args):
print("Replace `F.scaled_dot_product_attention` with QuantSDPA...")
model = replace_sdpa_with_quantizable_layers(model)
print("Replacing done.")

elif args.functional_sdpa_quant:
print("Inserting SDPA quantizable module")
model = offload_model(model)
with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
model(**calibration_loader[0])
remove_hooks(model)
if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
@@ -424,118 +434,131 @@ def quantize_llm(args):
                new_funct = functools.partial(update_internal_dict, m)
                m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct)

    # If we are doing functional SDPA quantization, we create the corresponding context
    # manager; otherwise we fall back to nullcontext. We would love to avoid the extra
    # indentation level, but it does not seem easy.
    if args.functional_sdpa_quant:
        quantization_cm = functional_quantization_mode(
            model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention})
    else:
        quantization_cm = nullcontext()

    with quantization_cm:
        with torch.no_grad():
            model(**calibration_loader[0])

        # We restore the original behaviour of the post-forward.
        for k, v in dict_hooks.items():
            k._hf_hook.post_forward = v

        if args.act_calibration and not args.load_checkpoint:
            print("Apply act calibration...")
            apply_calibration(model, calibration_loader)
            print("Act calibration applied.")

        if args.learned_round:
            print("Applying learned round...")
            if args.load_checkpoint:
                iters = 1
                loader = [calibration_loader[0]]
            else:
                iters = args.learned_round_iters
                loader = calibration_loader
            remove_hooks(model)
            apply_learned_round(
                model,
                loader,
                iters=iters,
                block_name_attribute=args.gpxq_block_name,
                learn_scale=args.learned_round_scale,
                scale_optimizer_class='sgd',
                optimizer_kwargs={'lr': args.learned_round_lr},
                scale_optimizer_kwargs={
                    'lr': args.learned_round_scale_lr,
                    'momentum': args.learned_round_scale_momentum},
                fast_update=args.learned_round_fast_update)
            print("Learned round applied.")
            model = offload_model(model)

        if args.load_checkpoint:
            remove_hooks(model)
            with load_quant_model_mode(model):
                model.load_state_dict(torch.load(args.checkpoint_name, map_location='cpu'))
            model = offload_model(model)

        if args.gptq and not args.load_checkpoint:
            print("Applying GPTQ...")
            apply_gptq(
                model,
                calibration_loader,
                act_order=args.gpxq_act_order,
                use_quant_activations=args.gpxq_use_quant_activations,
                create_weight_orig=args.gpxq_create_weight_orig,
                block_name=args.gpxq_block_name,
                max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
                max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
            print("GPTQ applied.")

        if args.gpfq and not args.load_checkpoint:
            print("Applying GPFQ...")
            apply_gpfq(
                model,
                calibration_loader,
                act_order=args.gpxq_act_order,
                block_name=args.gpxq_block_name,
                max_accumulator_bit_width=args.gpxq_max_accumulator_bit_width,
                max_accumulator_tile_size=args.gpxq_max_accumulator_tile_size)
            print("GPFQ applied.")

        if args.bias_corr and not args.load_checkpoint:
            print("Applying bias correction...")
            apply_bias_correction(model, calibration_loader)
            print("Bias correction applied.")

        if args.eval and not args.no_quantize:
            print("Model eval...")
            with torch.no_grad(), quant_inference_mode(model):
                model(**calibration_loader[0])
                quant_ppl = compute_perplexity(
                    model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
            print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")

        if args.checkpoint_name is not None and not args.load_checkpoint:
            print(f"Saving checkpoint to {args.checkpoint_name}")
            torch.save(model.state_dict(), args.checkpoint_name)

        if args.few_shot_eval:
            with torch.no_grad(), quant_inference_mode(model):
                model(**calibration_loader[0])
                if args.few_shot_compile:
                    remove_hooks(model)
                    model.cuda()
                    model = torch.compile(model)

                wrapped_model = HFLM(pretrained=model)  # need to wrap for LLM eval
                results = evaluator.simple_evaluate(
                    model=wrapped_model,
                    model_args=None,
                    tasks=list(args.few_shot_tasks),
                    device='cuda:0',
                    limit=args.few_shot_limit,
                    num_fewshot=0 if args.few_shot_zeroshot else None,
                    log_samples=False,
                    batch_size=None,
                    verbosity="ERROR")
                results = filter_results(results, args.few_shot_tasks)
                print("Few shot eval results")
                print(results)
            remove_hooks(model)

        if args.export_target:
            print(f"Export to {args.export_target}")
            # Currently we always export on CPU with a float32 container to avoid float16 CPU errors
            model = model.to(dtype=torch.float32)
            model_export(model, calibration_loader[0], args)

    return float_ppl, quant_ppl, model

@@ -765,6 +788,12 @@ def parse_args(args, override_defaults={}):
        '--quant-sdpa',
        action='store_true',
        help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)')
    parser.add_argument(
        '--functional-sdpa-quant',
        action='store_true',
        help=
        'Quantize `F.scaled_dot_product_attention` with a stateless module and torch_function (default: %(default)s)'
    )
    parser.add_argument(
        '--replace-mha',
        action='store_true',
