Commit

temp
Giuseppe5 committed Jan 15, 2025
1 parent a7efcba commit 398ce35
Showing 8 changed files with 110 additions and 27 deletions.
30 changes: 30 additions & 0 deletions src/brevitas/core/zero_point.py
@@ -344,3 +344,33 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
        # pre-zero centering before rounding and clipping
        z = self.get_zero_center(x) / scale  # need to scale the norm by s
        return z


class RuntimeDynamicGroupZeroScaling(brevitas.jit.ScriptModule):

    def __init__(
            self,
            group_size: int,
            group_dim: int,
            input_view_impl: Module,
            zero_point_stats_impl: Module,
            int_quant,
            quantize_zero_point) -> None:
        super(RuntimeDynamicGroupZeroScaling, self).__init__()

        self.group_size = group_size
        self.group_dim = group_dim
        self.zero_point_stats_impl = zero_point_stats_impl
        self.input_view_impl = input_view_impl
        self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

    @brevitas.jit.script_method
    def forward(
            self,
            stats_input: torch.Tensor,
            scale,
            bit_width) -> torch.Tensor:

        stats_input_reshaped = self.input_view_impl(stats_input)
        out = self.zero_point_stats_impl(stats_input_reshaped)
        return self.scale_shift_zero_point(-out, scale, bit_width)
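
For orientation, a minimal stand-alone sketch (plain PyTorch, not part of this commit) of what the new module computes, assuming grouping over the last dimension, the NegativeMinOrZero statistic, and an unsigned integer target range:

import torch

def dynamic_group_zero_point(x: torch.Tensor, scale: torch.Tensor,
                             group_size: int, bit_width: int) -> torch.Tensor:
    # view the last dimension as (..., num_groups, group_size), mirroring input_view_impl
    grouped = x.unflatten(-1, (x.shape[-1] // group_size, group_size))
    # per-group negative-min-or-zero statistic, i.e. zero_point_stats_impl
    stat = torch.clamp(grouped.min(dim=-1, keepdim=True).values, max=0.0)
    # negate, rescale by the (per-group) scale and round onto the unsigned grid,
    # roughly what _ScaleShiftZeroPoint does when the zero point is quantized
    return torch.clamp(torch.round(-stat / scale), 0, 2 ** bit_width - 1)

Because the statistic is recomputed from the tensor being quantized, no zero-point parameter needs to be stored or learned.
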
6 changes: 6 additions & 0 deletions src/brevitas/graph/base.py
@@ -7,6 +7,7 @@
from inspect import getcallargs
from typing import Any, Callable, Dict, Optional, Type, Union

from brevitas.nn import ScaledDotProductAttention
import torch
from torch import Tensor
from torch.nn import Module
@@ -121,6 +122,9 @@ def _map_origin_vars(self, vars: dict):

def _module_attributes(self, module):
attrs = vars(module)
if isinstance(module, ScaledDotProductAttention):
print(attrs)

# workaround since bias doesn't show up on vars of Linear
if hasattr(module, 'bias'):
attrs['bias'] = module.bias
@@ -147,6 +151,8 @@ def _init_new_module(self, old_module: Module, name=None):
new_kwargs = self._module_attributes(old_module)
# transforms attribute of original module, e.g. bias Parameter -> bool
new_kwargs = self._map_origin_vars(new_kwargs)
if isinstance(old_module, ScaledDotProductAttention):
print(new_kwargs)
# restrict to only values that are in the init of the new module
new_module_signature_keys = signature_keys(self.new_module_class)
new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys}
6 changes: 6 additions & 0 deletions src/brevitas/graph/equalize.py
@@ -10,6 +10,7 @@
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import warnings

from brevitas.nn import ScaledDotProductAttention
import packaging
import packaging.version
import torch
@@ -1584,6 +1585,11 @@ def find_sink(node):
name_to_module={
'src0': src_module, 'sink0': sink_module})
regions.append(region)
for m in graph_module.modules():
if isinstance(m, ScaledDotProductAttention):
m.pre_process_q = functional_rotate_input
m.pre_process_k = functional_rotate_input
# m.pre_process_v = partial(functional_rotate_input, transpose=True)
return regions

def apply(self,
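
Attaching the same rotation to pre_process_q and pre_process_k leaves the attention scores unchanged, because an orthogonal transform applied to both query and key cancels in Q·Kᵀ. A quick numerical check (illustrative only, with a random orthogonal matrix standing in for the Hadamard rotation):

import torch

d = 64
q = torch.randn(2, 8, 16, d, dtype=torch.float64)
k = torch.randn(2, 8, 16, d, dtype=torch.float64)
H, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))  # random orthogonal rotation

scores = q @ k.transpose(-2, -1)
scores_rot = (q @ H) @ (k @ H).transpose(-2, -1)
print(torch.allclose(scores, scores_rot))  # True: the rotation cancels in the scores

Rotating the value path would also rotate the attention output and would have to be undone downstream, which is presumably why pre_process_v is left commented out here.
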
2 changes: 1 addition & 1 deletion src/brevitas/graph/quantize.py
@@ -50,7 +50,7 @@ def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True
self.enabled = enabled
for stateless_function, stateless_module in quant_map.items():
if not hasattr(model, str(stateless_function)):
setattr(model, str(stateless_function), stateless_module())
model.add_module(str(stateless_function), stateless_module())

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
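
The setattr-to-add_module change registers the stateless handler as a named child module under the string key of the functional it replaces; the same key is used later (see fused_rotation_no_fx in main.py below) to fetch the handler with getattr and carry it over to the exported graph module. A hypothetical stand-alone illustration (nn.Identity() stands in for the instantiated handler):

import torch
from torch import nn

model = nn.Linear(8, 8)
sdpa_key = str(torch.nn.functional.scaled_dot_product_attention)

model.add_module(sdpa_key, nn.Identity())      # register under the functional's string key
handler = getattr(model, sdpa_key)             # retrievable later by the same key
print(handler in list(model.modules()))        # True: tracked like any other child module
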
4 changes: 2 additions & 2 deletions src/brevitas/nn/equalized_layer.py
@@ -81,13 +81,13 @@ def forward(self, inp, **kwargs):
def functional_rotate_input(inp, transpose=False):
is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
if transpose:
inp = inp.t()
inp = inp.transpose(-2, -1)
if is_cuda and fast_hadamard_transform is not None:
had_K, K = get_hadK(inp.shape[-1])
inp = matmul_hadU_cuda(inp, had_K, K)
else:
inp = matmul_hadU(inp)

if transpose:
inp = inp.t()
inp = inp.transpose(-2, -1)
return inp
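
The switch from t() to transpose(-2, -1) matters because the SDPA inputs being rotated are batched 3-D/4-D tensors, and Tensor.t() only accepts tensors with at most two dimensions:

import torch

x = torch.randn(2, 16, 64)       # (batch, seq, hidden)
# x.t() would raise a RuntimeError for a 3-D tensor
y = x.transpose(-2, -1)          # rank-agnostic: swaps only the last two dims
print(y.shape)                   # torch.Size([2, 64, 16])
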
26 changes: 21 additions & 5 deletions src/brevitas/nn/quant_sdpa.py
@@ -43,6 +43,8 @@
import math
from typing import Optional, Tuple, Union

from brevitas.core.function_wrapper.misc import Identity
from brevitas.function import identity
import torch
from torch import Tensor
from torch.nn import Module
@@ -57,6 +59,12 @@

class ScaledDotProductAttention(Module):

def __init__(self, pre_process_q = identity, pre_process_k = identity, pre_process_v = identity):
super().__init__()
self.pre_process_q = pre_process_q
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v

def forward(
self,
query: Tensor,
@@ -103,9 +111,9 @@ def forward(
if enable_gqa:
kwargs["enable_gqa"] = enable_gqa
return F.scaled_dot_product_attention(
query=query,
key=key,
value=value,
query=self.pre_process_q(query),
key=self.pre_process_k(key),
value=value,#self.pre_process_v(value),
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
@@ -116,6 +124,7 @@ class QuantScaledDotProductAttention(Module):

def __init__(
self,
pre_process_q = Identity(), pre_process_k = Identity(), pre_process_v = Identity(),
softmax_input_quant=None,
attn_output_weights_quant=Uint8ActPerTensorFloat,
q_scaled_quant=Int8ActPerTensorFloat,
@@ -125,6 +134,11 @@ def __init__(
**kwargs) -> None:
super(QuantScaledDotProductAttention, self).__init__()

self.pre_process_q = pre_process_q
self.pre_process_k = pre_process_k
self.pre_process_v = pre_process_v
print(self.pre_process_q)

def filter_kwargs(prefix):
return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}

@@ -196,14 +210,16 @@ def forward(
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
q_scaled = self.q_scaled_quant(query * scale_factor)
query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value)
q_scaled = query * scale_factor#self.q_scaled_quant(query * scale_factor)
k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
attn_weight = q_scaled @ k_transpose
attn_weight += attn_bias
attn_weight = self.softmax_input_quant(attn_weight)
attn_weight = torch.softmax(attn_weight, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
attn_weight = self.attn_output_weights_quant(attn_weight)
# attn_weight = self.pre_process_q(attn_weight)
# attn_weight = self.attn_output_weights_quant(attn_weight)
attn_output = attn_weight @ self.v_quant(value)
attn_output = self.sdpa_output_quant(attn_output)
return attn_output
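
For reference, a pure-PyTorch mirror of where the new pre-processing hooks sit in this forward (masking, dropout and the quant stubs omitted; a sketch, not the commit's implementation):

import math
import torch

def sdpa_reference(q, k, v, pre_q=lambda t: t, pre_k=lambda t: t, pre_v=lambda t: t):
    # hooks run on the raw inputs, before scaling, transposition or any quant stub
    q, k, v = pre_q(q), pre_k(k), pre_v(v)
    scale = 1.0 / math.sqrt(q.size(-1))
    attn = torch.softmax((q * scale) @ k.transpose(-2, -1), dim=-1)
    return attn @ v
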
23 changes: 18 additions & 5 deletions src/brevitas_examples/common/generative/quantize.py
@@ -4,14 +4,17 @@
"""
import re

from brevitas.core.stats import NegativeMinOrZero
from brevitas.quant.base import ParameterFromRuntimeZeroPoint
from dependencies import this
import torch
from torch import nn

from brevitas import nn as qnn
from brevitas.core.function_wrapper import CeilSte
from brevitas.core.function_wrapper import FloorSte
from brevitas.core.restrict_val import RoundSte
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint, RuntimeDynamicGroupZeroScaling
from brevitas.graph.quantize import layerwise_quantize
from brevitas.quant.experimental.float import Fp8e4m3Act
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
@@ -57,7 +60,7 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat, RuntimeDynamicStatsZeroPoint
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint
from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE
@@ -149,6 +152,15 @@
'per_channel': {
'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}}

class Test(Int8DynamicActPerGroupFloat):
# zero_point_impl = RuntimeDynamicStatsZeroPoint
zero_point_impl = RuntimeDynamicGroupZeroScaling
zero_point_stats_impl = NegativeMinOrZero
scaling_stats_op = 'min_max'
signed = False
# zero_point_shape = this.scaling_shape
# zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl

INPUT_QUANT_MAP = {
'int': {
'static': {
@@ -177,7 +189,8 @@
'sym': Int8DynamicActPerRowFloat,
'asym': ShiftedUint8DynamicActPerRowFloat},
'per_group': {
'sym': Int8DynamicActPerGroupFloat}}},
'sym': Int8DynamicActPerGroupFloat,
'asym': Test}}},
'po2_scale': {
'stats': {
'per_row': {
@@ -388,10 +401,10 @@ def generate_quantizers(
elif input_quant_granularity == 'per_group':
q_scaled_quant = sym_input_quant.let(
**{
'group_dim': 2, 'group_size': input_group_size})
'group_dim': -1, 'group_size': input_group_size})
k_transposed_quant = sym_input_quant.let(
**{
'group_dim': 1, 'group_size': input_group_size})
'group_dim': -1, 'group_size': input_group_size})
v_quant = q_scaled_quant
attn_output_weights_quant = q_scaled_quant
else:
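
The new Test quantizer extends the symmetric dynamic per-group activation quantizer with a runtime zero point (NegativeMinOrZero statistic, min/max scaling statistics, signed = False), and both SDPA input quantizers now group over the last dimension (group_dim=-1) instead of fixed positional axes. Numerically, the resulting asymmetric per-group scheme behaves roughly like this stand-alone sketch (names and defaults are illustrative):

import torch

def dynamic_group_asym_fake_quant(x: torch.Tensor, group_size: int, bit_width: int = 8):
    g = x.unflatten(-1, (x.shape[-1] // group_size, group_size))
    g_min = torch.clamp(g.min(dim=-1, keepdim=True).values, max=0.0)
    g_max = torch.clamp(g.max(dim=-1, keepdim=True).values, min=0.0)
    scale = torch.clamp((g_max - g_min) / (2 ** bit_width - 1), min=1e-8)
    zero_point = torch.round(-g_min / scale)
    q = torch.clamp(torch.round(g / scale) + zero_point, 0, 2 ** bit_width - 1)
    return ((q - zero_point) * scale).flatten(-2)   # fake-quantized activations
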
40 changes: 26 additions & 14 deletions src/brevitas_examples/llm/main.py
@@ -82,8 +82,17 @@ def set_seed(seed):


def fused_rotation_no_fx(model, calibration_loader, args):
print("Here")
with torch.no_grad():
new_model, guards = torch._dynamo.export(model)(**calibration_loader[0])
print(getattr(model, str(torch.nn.functional.scaled_dot_product_attention)))
if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)
# for m in new_model.modules():
# print(type(m))
# if hasattr(m, 'pre_process_q'):
# raise
apply_layernorm_affine_merge(new_model)
new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True)
rewriters = fix_rewriter(rewriters, model, 'weight')
@@ -303,19 +312,7 @@ def quantize_llm(args):
apply_layernorm_to_rmsnorm(model)
print("Layernorm To RMSNorm applied.")

if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.rotation_sdpa_regions)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)


# Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
# with all the variability in HF implementations
@@ -333,6 +330,21 @@
with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
model(**calibration_loader[0])
remove_hooks(model)

if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.rotation_mode,
sdpa_regions=args.rotation_sdpa_regions)
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)

if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
@@ -521,7 +533,7 @@ def quantize_llm(args):
print(f"Saving checkpoint to {args.checkpoint_name}")
torch.save(model.state_dict(), args.checkpoint_name)

if args.eval and not args.no_quantize:
if args.eval:# and not args.no_quantize:
print("Model eval...")
with torch.no_grad(), quant_inference_mode(model):
model(**calibration_loader[0])
