New rotation #1159

Merged: 6 commits, Jan 16, 2025
5 changes: 5 additions & 0 deletions src/brevitas/graph/equalize.py
@@ -34,6 +34,7 @@
from brevitas.graph.hadamard import random_hadamard_matrix
from brevitas.graph.utils import get_module
from brevitas.graph.utils import get_node
from brevitas.nn import ScaledDotProductAttention
from brevitas.nn.equalized_layer import EqualizedModule
from brevitas.nn.equalized_layer import functional_rotate_input
from brevitas.nn.equalized_layer import INPUT_NAMES
@@ -1584,6 +1585,10 @@ def find_sink(node):
name_to_module={
'src0': src_module, 'sink0': sink_module})
regions.append(region)
for m in graph_module.modules():
if isinstance(m, ScaledDotProductAttention):
m.pre_process_q = functional_rotate_input
m.pre_process_k = functional_rotate_input
return regions

def apply(self,
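For context on why query and key can share the same functional_rotate_input pre-processing: applying one orthogonal rotation to both sides leaves Q @ K^T, and hence the attention scores, unchanged. A minimal sketch of that invariance (normalized_hadamard is a hypothetical helper standing in for Brevitas' Hadamard utilities; a power-of-two dimension is assumed):

import torch

def normalized_hadamard(d: int) -> torch.Tensor:
    # Sylvester construction; columns are orthonormal after scaling by 1/sqrt(d).
    h = torch.ones(1, 1)
    while h.shape[0] < d:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / h.shape[0] ** 0.5

d = 64
q, k = torch.randn(8, d), torch.randn(8, d)
h = normalized_hadamard(d)
# Rotating both query and key with the same orthogonal matrix preserves the scores.
assert torch.allclose(q @ k.t(), (q @ h) @ (k @ h).t(), atol=1e-4)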
2 changes: 1 addition & 1 deletion src/brevitas/graph/quantize.py
@@ -50,7 +50,7 @@ def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True
self.enabled = enabled
for stateless_function, stateless_module in quant_map.items():
if not hasattr(model, str(stateless_function)):
setattr(model, str(stateless_function), stateless_module())
model.add_module(str(stateless_function), stateless_module())

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
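The switch from setattr to add_module registers the stateless wrapper explicitly as a child module under the function's string key, so it is tracked like any other submodule. A hedged, self-contained illustration (Wrapper is a hypothetical stand-in for a quant_map module):

import torch
import torch.nn as nn
import torch.nn.functional as F

class Wrapper(nn.Module):
    # Hypothetical stand-in for a stateless quant_map module.
    def forward(self, *args, **kwargs):
        return F.scaled_dot_product_attention(*args, **kwargs)

model = nn.Linear(4, 4)
key = str(F.scaled_dot_product_attention)
model.add_module(key, Wrapper())
# The wrapper is now a tracked child: it shows up in named_modules(), is moved by
# .to(device), and can be retrieved later with getattr(model, key).
print(key in dict(model.named_modules()))  # True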
4 changes: 2 additions & 2 deletions src/brevitas/nn/equalized_layer.py
@@ -81,13 +81,13 @@ def forward(self, inp, **kwargs):
def functional_rotate_input(inp, transpose=False):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if transpose:
inp = inp.t()
inp = inp.transpose(-2, -1)
if is_cuda and fast_hadamard_transform is not None:
had_K, K = get_hadK(inp.shape[-1])
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)

if transpose:
inp = inp.t()
inp = inp.transpose(-2, -1)
return inp
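The .t() to .transpose(-2, -1) change matters because torch.Tensor.t() only accepts tensors with at most two dimensions, while the tensors reaching functional_rotate_input through the SDPA pre-processing hooks are batched, e.g. (batch, heads, seq_len, head_dim). A short sketch of the difference (shapes are illustrative):

import torch

x = torch.randn(2, 8, 16, 64)    # (batch, heads, seq_len, head_dim)
y = x.transpose(-2, -1)          # swaps only the last two dims -> (2, 8, 64, 16)
assert y.shape == (2, 8, 64, 16)
try:
    x.t()                        # t() is restricted to <= 2-D tensors
except RuntimeError as e:
    print("t() on 4-D input fails:", e)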
23 changes: 20 additions & 3 deletions src/brevitas/nn/quant_sdpa.py
@@ -49,6 +49,8 @@
from torch.nn import Parameter
import torch.nn.functional as F

from brevitas.core.function_wrapper.misc import Identity
from brevitas.function import identity
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Uint8ActPerTensorFloat

@@ -57,6 +59,12 @@

class ScaledDotProductAttention(Module):

def __init__(self, pre_process_q=identity, pre_process_k=identity, pre_process_v=identity):
super().__init__()
self.pre_process_q = pre_process_q
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v

def forward(
self,
query: Tensor,
@@ -103,9 +111,9 @@ def forward(
if enable_gqa:
kwargs["enable_gqa"] = enable_gqa
return F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
query=self.pre_process_q(query),
key=self.pre_process_k(key),
value=self.pre_process_v(value),
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
@@ -116,6 +124,9 @@ class QuantScaledDotProductAttention(Module):

def __init__(
self,
pre_process_q=identity,
pre_process_k=identity,
pre_process_v=identity,
softmax_input_quant=None,
attn_output_weights_quant=Uint8ActPerTensorFloat,
q_scaled_quant=Int8ActPerTensorFloat,
@@ -125,6 +136,11 @@ def __init__(
**kwargs) -> None:
super(QuantScaledDotProductAttention, self).__init__()

self.pre_process_q = pre_process_q
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v
print(self.pre_process_q)

def filter_kwargs(prefix):
return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

@@ -196,6 +212,7 @@ def forward(
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value)
q_scaled = self.q_scaled_quant(query * scale_factor)
k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
attn_weight = q_scaled @ k_transpose
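A hedged usage sketch of the new pre-processing hooks: by default the module behaves like F.scaled_dot_product_attention (identity pre-processing), and a rotation pass such as the equalize.py change above can later swap in functional_rotate_input for query and key. Shapes and the manual reassignment below are illustrative only:

import torch
from brevitas.nn import ScaledDotProductAttention

sdpa = ScaledDotProductAttention()   # pre_process_q/k/v default to identity
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)
out = sdpa(q, k, v)                  # same result as F.scaled_dot_product_attention

# A rotation pass can later replace the hooks, e.g. (illustrative callable):
sdpa.pre_process_q = lambda t: t     # would be functional_rotate_input in practice
sdpa.pre_process_k = lambda t: t

In quantize_llm these modules are inserted automatically by running one calibration batch under functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}), as shown in the main.py diff below.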
8 changes: 6 additions & 2 deletions src/brevitas_examples/common/generative/quantize.py
@@ -4,15 +4,18 @@
"""
import re

from dependencies import this
import torch
from torch import nn

from brevitas import nn as qnn
from brevitas.core.function_wrapper import CeilSte
from brevitas.core.function_wrapper import FloorSte
from brevitas.core.restrict_val import RoundSte
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.graph.quantize import layerwise_quantize
from brevitas.quant.base import ParameterFromRuntimeZeroPoint
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
@@ -68,6 +71,7 @@
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat

@@ -388,10 +392,10 @@ def generate_quantizers(
elif input_quant_granularity == 'per_group':
q_scaled_quant = sym_input_quant.let(
**{
'group_dim': 2, 'group_size': input_group_size})
'group_dim': -1, 'group_size': input_group_size})
k_transposed_quant = sym_input_quant.let(
**{
'group_dim': 1, 'group_size': input_group_size})
'group_dim': -1, 'group_size': input_group_size})
v_quant = q_scaled_quant
attn_output_weights_quant = q_scaled_quant
else:
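With group_dim=-1 the per-group input quantizers for the SDPA tensors define the grouping axis relative to the end of the shape rather than as a fixed positional axis, which presumably keeps the grouping valid for the batched tensors flowing through the new SDPA modules. A hedged illustration of grouping along the last dimension (per_group_absmax_scales and the shapes are illustrative, not Brevitas internals):

import torch

def per_group_absmax_scales(x: torch.Tensor, group_size: int) -> torch.Tensor:
    # Split the final dimension into contiguous groups and compute one scale per group.
    *lead, last = x.shape
    assert last % group_size == 0
    grouped = x.reshape(*lead, last // group_size, group_size)
    return grouped.abs().amax(dim=-1, keepdim=True)

x = torch.randn(2, 8, 128, 64)               # works for any leading shape
print(per_group_absmax_scales(x, 32).shape)  # torch.Size([2, 8, 128, 2, 1])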
33 changes: 19 additions & 14 deletions src/brevitas_examples/llm/main.py
@@ -84,6 +84,10 @@ def set_seed(seed):
def fused_rotation_no_fx(model, calibration_loader, args):
with torch.no_grad():
new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)

apply_layernorm_affine_merge(new_model)
new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
rewriters = fix_rewriter(rewriters, model, 'weight')
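A hedged sketch of what the new lines in fused_rotation_no_fx do: if the eager model carries a module registered under str(torch.nn.functional.scaled_dot_product_attention) (added earlier by functional_quantization_mode), it is copied onto the dynamo-exported GraphModule so later passes can still find it. copy_sdpa_module is a hypothetical name used only for illustration:

import torch
import torch.nn.functional as F

def copy_sdpa_module(src: torch.nn.Module, exported: torch.nn.Module) -> None:
    # Re-register the SDPA wrapper module on the exported model under the same key.
    key = str(F.scaled_dot_product_attention)
    if hasattr(src, key):
        exported.add_module(key, getattr(src, key))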
@@ -303,20 +307,6 @@ def quantize_llm(args):
apply_layernorm_to_rmsnorm(model)
print("Layernorm To RMSNorm applied.")

if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.rotation_sdpa_regions)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)

# Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
# with all the variability in HF implementations
if args.replace_mha:
@@ -333,6 +323,21 @@
with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
model(**calibration_loader[0])
remove_hooks(model)

if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.rotation_sdpa_regions)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)

if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
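The rotation block (fx, layerwise, fused_no_fx) now runs after the replace_mha and SDPA-replacement steps, presumably so that the ScaledDotProductAttention modules already exist when GraphRotationEqualization with sdpa_regions looks for them. A small hypothetical sanity check for that ordering:

import torch.nn as nn

from brevitas.nn import ScaledDotProductAttention

def count_sdpa_modules(model: nn.Module) -> int:
    # With rotation_sdpa_regions enabled, this count should be non-zero by the
    # time the rotation pass runs.
    return sum(isinstance(m, ScaledDotProductAttention) for m in model.modules())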