From 398ce35b4c1ee1a61d051c4737331b08880c622a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Jan 2025 13:42:31 +0000 Subject: [PATCH 1/6] temp --- src/brevitas/core/zero_point.py | 30 ++++++++++++++ src/brevitas/graph/base.py | 6 +++ src/brevitas/graph/equalize.py | 6 +++ src/brevitas/graph/quantize.py | 2 +- src/brevitas/nn/equalized_layer.py | 4 +- src/brevitas/nn/quant_sdpa.py | 26 +++++++++--- .../common/generative/quantize.py | 23 ++++++++--- src/brevitas_examples/llm/main.py | 40 ++++++++++++------- 8 files changed, 110 insertions(+), 27 deletions(-) diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index f74fffae8..7038fe1a9 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -344,3 +344,33 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor: # pre-zero centering before rounding and clipping z = self.get_zero_center(x) / scale # need to scale the norm by s return z + + +class RuntimeDynamicGroupZeroScaling(brevitas.jit.ScriptModule): + + def __init__( + self, + group_size: int, + group_dim: int, + input_view_impl: Module, + zero_point_stats_impl: Module, + int_quant, + quantize_zero_point) -> None: + super(RuntimeDynamicGroupZeroScaling, self).__init__() + + self.group_size = group_size + self.group_dim = group_dim + self.zero_point_stats_impl = zero_point_stats_impl + self.input_view_impl = input_view_impl + self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point) + + @brevitas.jit.script_method + def forward( + self, + stats_input: torch.Tensor, + scale, + bit_width) -> torch.Tensor: + + stats_input_reshaped = self.input_view_impl(stats_input) + out = self.zero_point_stats_impl(stats_input_reshaped) + return self.scale_shift_zero_point(-out, scale, bit_width) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index d1631f34e..f9307e46e 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -7,6 +7,7 @@ from inspect import getcallargs from typing import Any, Callable, Dict, Optional, Type, Union +from brevitas.nn import ScaledDotProductAttention import torch from torch import Tensor from torch.nn import Module @@ -121,6 +122,9 @@ def _map_origin_vars(self, vars: dict): def _module_attributes(self, module): attrs = vars(module) + if isinstance(module, ScaledDotProductAttention): + print(attrs) + # workaround since bias doesn't show up on vars of Linear if hasattr(module, 'bias'): attrs['bias'] = module.bias @@ -147,6 +151,8 @@ def _init_new_module(self, old_module: Module, name=None): new_kwargs = self._module_attributes(old_module) # transforms attribute of original module, e.g. bias Parameter -> bool new_kwargs = self._map_origin_vars(new_kwargs) + if isinstance(old_module, ScaledDotProductAttention): + print(new_kwargs) # restrict to only values that are in the init of the new module new_module_signature_keys = signature_keys(self.new_module_class) new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys} diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 744130664..f8bd0f5a7 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Optional, Set, Tuple, Union import warnings +from brevitas.nn import ScaledDotProductAttention import packaging import packaging.version import torch @@ -1584,6 +1585,11 @@ def find_sink(node): name_to_module={ 'src0': src_module, 'sink0': sink_module}) regions.append(region) + for m in graph_module.modules(): + if isinstance(m, ScaledDotProductAttention): + m.pre_process_q = functional_rotate_input + m.pre_process_k = functional_rotate_input + # m.pre_process_v = partial(functional_rotate_input, transpose=True) return regions def apply(self, diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 7724e8f9d..6dbdae4cb 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -50,7 +50,7 @@ def __init__(self, model: torch.nn.Module, quant_map: Dict, enabled: bool = True self.enabled = enabled for stateless_function, stateless_module in quant_map.items(): if not hasattr(model, str(stateless_function)): - setattr(model, str(stateless_function), stateless_module()) + model.add_module(str(stateless_function), stateless_module()) def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: diff --git a/src/brevitas/nn/equalized_layer.py b/src/brevitas/nn/equalized_layer.py index 8413a8208..e3d930a50 100644 --- a/src/brevitas/nn/equalized_layer.py +++ b/src/brevitas/nn/equalized_layer.py @@ -81,7 +81,7 @@ def forward(self, inp, **kwargs): def functional_rotate_input(inp, transpose=False): is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None if transpose: - inp = inp.t() + inp = inp.transpose(-2, -1) if is_cuda and fast_hadamard_transform is not None: had_K, K = get_hadK(inp.shape[-1]) inp = matmul_hadU_cuda(inp, had_K, K) @@ -89,5 +89,5 @@ def functional_rotate_input(inp, transpose=False): inp = matmul_hadU(inp) if transpose: - inp = inp.t() + inp = inp.transpose(-2, -1) return inp diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index 43f99e827..728d924e9 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -43,6 +43,8 @@ import math from typing import Optional, Tuple, Union +from brevitas.core.function_wrapper.misc import Identity +from brevitas.function import identity import torch from torch import Tensor from torch.nn import Module @@ -57,6 +59,12 @@ class ScaledDotProductAttention(Module): + def __init__(self, pre_process_q = identity, pre_process_k = identity, pre_process_v = identity): + super().__init__() + self.pre_process_q = pre_process_q + self.pre_process_k = pre_process_k + self.pre_process_v = pre_process_v + def forward( self, query: Tensor, @@ -103,9 +111,9 @@ def forward( if enable_gqa: kwargs["enable_gqa"] = enable_gqa return F.scaled_dot_product_attention( - query=query, - key=key, - value=value, + query=self.pre_process_q(query), + key=self.pre_process_k(key), + value=value,#self.pre_process_v(value), attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, @@ -116,6 +124,7 @@ class QuantScaledDotProductAttention(Module): def __init__( self, + pre_process_q = Identity(), pre_process_k = Identity(), pre_process_v = Identity(), softmax_input_quant=None, attn_output_weights_quant=Uint8ActPerTensorFloat, q_scaled_quant=Int8ActPerTensorFloat, @@ -125,6 +134,11 @@ def __init__( **kwargs) -> None: super(QuantScaledDotProductAttention, self).__init__() + self.pre_process_q = pre_process_q + self.pre_process_k = pre_process_k + self.pre_process_v = pre_process_v + print(self.pre_process_q) + def filter_kwargs(prefix): return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)} @@ -196,14 +210,16 @@ def forward( attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask - q_scaled = self.q_scaled_quant(query * scale_factor) + query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value) + q_scaled = query * scale_factor#self.q_scaled_quant(query * scale_factor) k_transpose = self.k_transposed_quant(key.transpose(-2, -1)) attn_weight = q_scaled @ k_transpose attn_weight += attn_bias attn_weight = self.softmax_input_quant(attn_weight) attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - attn_weight = self.attn_output_weights_quant(attn_weight) + # attn_weight = self.pre_process_q(attn_weight) + # attn_weight = self.attn_output_weights_quant(attn_weight) attn_output = attn_weight @ self.v_quant(value) attn_output = self.sdpa_output_quant(attn_output) return attn_output diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 778955285..40e6063f4 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -4,6 +4,9 @@ """ import re +from brevitas.core.stats import NegativeMinOrZero +from brevitas.quant.base import ParameterFromRuntimeZeroPoint +from dependencies import this import torch from torch import nn @@ -11,7 +14,7 @@ from brevitas.core.function_wrapper import CeilSte from brevitas.core.function_wrapper import FloorSte from brevitas.core.restrict_val import RoundSte -from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint +from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint, RuntimeDynamicGroupZeroScaling from brevitas.graph.quantize import layerwise_quantize from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat @@ -57,7 +60,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear -from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat +from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat, RuntimeDynamicStatsZeroPoint from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE @@ -149,6 +152,15 @@ 'per_channel': { 'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}} +class Test(Int8DynamicActPerGroupFloat): + # zero_point_impl = RuntimeDynamicStatsZeroPoint + zero_point_impl = RuntimeDynamicGroupZeroScaling + zero_point_stats_impl = NegativeMinOrZero + scaling_stats_op = 'min_max' + signed = False + # zero_point_shape = this.scaling_shape + # zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl + INPUT_QUANT_MAP = { 'int': { 'static': { @@ -177,7 +189,8 @@ 'sym': Int8DynamicActPerRowFloat, 'asym': ShiftedUint8DynamicActPerRowFloat}, 'per_group': { - 'sym': Int8DynamicActPerGroupFloat}}}, + 'sym': Int8DynamicActPerGroupFloat, + 'asym': Test}}}, 'po2_scale': { 'stats': { 'per_row': { @@ -388,10 +401,10 @@ def generate_quantizers( elif input_quant_granularity == 'per_group': q_scaled_quant = sym_input_quant.let( **{ - 'group_dim': 2, 'group_size': input_group_size}) + 'group_dim': -1, 'group_size': input_group_size}) k_transposed_quant = sym_input_quant.let( **{ - 'group_dim': 1, 'group_size': input_group_size}) + 'group_dim': -1, 'group_size': input_group_size}) v_quant = q_scaled_quant attn_output_weights_quant = q_scaled_quant else: diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 21641a819..1ebcc9180 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -82,8 +82,17 @@ def set_seed(seed): def fused_rotation_no_fx(model, calibration_loader, args): + print("Here") with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) + print(getattr(model, str(torch.nn.functional.scaled_dot_product_attention))) + if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)): + m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention)) + new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add) + # for m in new_model.modules(): + # print(type(m)) + # if hasattr(m, 'pre_process_q'): + # raise apply_layernorm_affine_merge(new_model) new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -303,19 +312,7 @@ def quantize_llm(args): apply_layernorm_to_rmsnorm(model) print("Layernorm To RMSNorm applied.") - if args.rotation == 'fx': - model = offload_model(model) - eq = GraphRotationEqualization( - orphan_sink=args.rotation_orphan_sink, - full_rotation_method=args.rotation_mode, - sdpa_regions=args.rotation_sdpa_regions) - model = eq.apply(model) - remove_hooks(model) - elif args.rotation == 'layerwise': - eq = LayerwiseActivationRotation() - model = eq.apply(model) - elif args.rotation == 'fused_no_fx': - fused_rotation_no_fx(model, calibration_loader, args) + # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations @@ -333,6 +330,21 @@ def quantize_llm(args): with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}): model(**calibration_loader[0]) remove_hooks(model) + + if args.rotation == 'fx': + model = offload_model(model) + eq = GraphRotationEqualization( + orphan_sink=args.rotation_orphan_sink, + full_rotation_method=args.rotation_mode, + sdpa_regions=args.rotation_sdpa_regions) + model = eq.apply(model) + remove_hooks(model) + elif args.rotation == 'layerwise': + eq = LayerwiseActivationRotation() + model = eq.apply(model) + elif args.rotation == 'fused_no_fx': + fused_rotation_no_fx(model, calibration_loader, args) + if args.weight_equalization: print("Apply weight equalization...") # In case of float16 model, we need to offload to account for missing ops @@ -521,7 +533,7 @@ def quantize_llm(args): print(f"Saving checkpoint to {args.checkpoint_name}") torch.save(model.state_dict(), args.checkpoint_name) - if args.eval and not args.no_quantize: + if args.eval:# and not args.no_quantize: print("Model eval...") with torch.no_grad(), quant_inference_mode(model): model(**calibration_loader[0]) From a21b7714eebfd4784a0f53a273785914d0ec510e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Jan 2025 17:30:08 +0000 Subject: [PATCH 2/6] temp2 --- src/brevitas/graph/equalize.py | 3 +-- src/brevitas/nn/quant_sdpa.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index f8bd0f5a7..31b2d4f72 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -10,7 +10,6 @@ from typing import Callable, Dict, List, Optional, Set, Tuple, Union import warnings -from brevitas.nn import ScaledDotProductAttention import packaging import packaging.version import torch @@ -35,6 +34,7 @@ from brevitas.graph.hadamard import random_hadamard_matrix from brevitas.graph.utils import get_module from brevitas.graph.utils import get_node +from brevitas.nn import ScaledDotProductAttention from brevitas.nn.equalized_layer import EqualizedModule from brevitas.nn.equalized_layer import functional_rotate_input from brevitas.nn.equalized_layer import INPUT_NAMES @@ -1589,7 +1589,6 @@ def find_sink(node): if isinstance(m, ScaledDotProductAttention): m.pre_process_q = functional_rotate_input m.pre_process_k = functional_rotate_input - # m.pre_process_v = partial(functional_rotate_input, transpose=True) return regions def apply(self, diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index 728d924e9..9f82ccc0c 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -43,14 +43,14 @@ import math from typing import Optional, Tuple, Union -from brevitas.core.function_wrapper.misc import Identity -from brevitas.function import identity import torch from torch import Tensor from torch.nn import Module from torch.nn import Parameter import torch.nn.functional as F +from brevitas.core.function_wrapper.misc import Identity +from brevitas.function import identity from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Uint8ActPerTensorFloat @@ -59,7 +59,7 @@ class ScaledDotProductAttention(Module): - def __init__(self, pre_process_q = identity, pre_process_k = identity, pre_process_v = identity): + def __init__(self, pre_process_q=identity, pre_process_k=identity, pre_process_v=identity): super().__init__() self.pre_process_q = pre_process_q self.pre_process_k = pre_process_k @@ -113,7 +113,7 @@ def forward( return F.scaled_dot_product_attention( query=self.pre_process_q(query), key=self.pre_process_k(key), - value=value,#self.pre_process_v(value), + value=self.pre_process_v(value), attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, @@ -124,7 +124,9 @@ class QuantScaledDotProductAttention(Module): def __init__( self, - pre_process_q = Identity(), pre_process_k = Identity(), pre_process_v = Identity(), + pre_process_q=identity, + pre_process_k=identity, + pre_process_v=identity, softmax_input_quant=None, attn_output_weights_quant=Uint8ActPerTensorFloat, q_scaled_quant=Int8ActPerTensorFloat, @@ -211,15 +213,14 @@ def forward( else: attn_bias += attn_mask query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value) - q_scaled = query * scale_factor#self.q_scaled_quant(query * scale_factor) + q_scaled = self.q_scaled_quant(query * scale_factor) k_transpose = self.k_transposed_quant(key.transpose(-2, -1)) attn_weight = q_scaled @ k_transpose attn_weight += attn_bias attn_weight = self.softmax_input_quant(attn_weight) attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - # attn_weight = self.pre_process_q(attn_weight) - # attn_weight = self.attn_output_weights_quant(attn_weight) + attn_weight = self.attn_output_weights_quant(attn_weight) attn_output = attn_weight @ self.v_quant(value) attn_output = self.sdpa_output_quant(attn_output) return attn_output From 925c3a553c2df6d4504174eeb79da4b87fe461cf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Jan 2025 18:02:57 +0000 Subject: [PATCH 3/6] fix --- src/brevitas/graph/base.py | 6 +----- src/brevitas_examples/llm/main.py | 11 ++--------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index f9307e46e..13cd19d2e 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -7,7 +7,6 @@ from inspect import getcallargs from typing import Any, Callable, Dict, Optional, Type, Union -from brevitas.nn import ScaledDotProductAttention import torch from torch import Tensor from torch.nn import Module @@ -19,6 +18,7 @@ from brevitas.fx import immutable_dict from brevitas.fx import Node from brevitas.graph.utils import * +from brevitas.nn import ScaledDotProductAttention from brevitas.utils.python_utils import islambda from brevitas.utils.rotation_utils import RotationWeightParametrization @@ -122,8 +122,6 @@ def _map_origin_vars(self, vars: dict): def _module_attributes(self, module): attrs = vars(module) - if isinstance(module, ScaledDotProductAttention): - print(attrs) # workaround since bias doesn't show up on vars of Linear if hasattr(module, 'bias'): @@ -151,8 +149,6 @@ def _init_new_module(self, old_module: Module, name=None): new_kwargs = self._module_attributes(old_module) # transforms attribute of original module, e.g. bias Parameter -> bool new_kwargs = self._map_origin_vars(new_kwargs) - if isinstance(old_module, ScaledDotProductAttention): - print(new_kwargs) # restrict to only values that are in the init of the new module new_module_signature_keys = signature_keys(self.new_module_class) new_kwargs = {k: v for k, v in new_kwargs.items() if k in new_module_signature_keys} diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 1ebcc9180..5e7e3db93 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -82,17 +82,12 @@ def set_seed(seed): def fused_rotation_no_fx(model, calibration_loader, args): - print("Here") with torch.no_grad(): new_model, guards = torch._dynamo.export(model)(**calibration_loader[0]) - print(getattr(model, str(torch.nn.functional.scaled_dot_product_attention))) if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)): m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention)) new_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add) - # for m in new_model.modules(): - # print(type(m)) - # if hasattr(m, 'pre_process_q'): - # raise + apply_layernorm_affine_merge(new_model) new_model, rewriters = apply_layernorm_to_rmsnorm(new_model, return_rewriters=True) rewriters = fix_rewriter(rewriters, model, 'weight') @@ -312,8 +307,6 @@ def quantize_llm(args): apply_layernorm_to_rmsnorm(model) print("Layernorm To RMSNorm applied.") - - # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing # with all the variability in HF implementations if args.replace_mha: @@ -533,7 +526,7 @@ def quantize_llm(args): print(f"Saving checkpoint to {args.checkpoint_name}") torch.save(model.state_dict(), args.checkpoint_name) - if args.eval:# and not args.no_quantize: + if args.eval and not args.no_quantize: print("Model eval...") with torch.no_grad(), quant_inference_mode(model): model(**calibration_loader[0]) From 3ea4b0387810fbc885f2f76142db79fc3c2f0c5e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Jan 2025 18:08:49 +0000 Subject: [PATCH 4/6] cleanup --- src/brevitas/core/zero_point.py | 30 ------------------- .../common/generative/quantize.py | 22 +++++--------- 2 files changed, 7 insertions(+), 45 deletions(-) diff --git a/src/brevitas/core/zero_point.py b/src/brevitas/core/zero_point.py index 7038fe1a9..f74fffae8 100644 --- a/src/brevitas/core/zero_point.py +++ b/src/brevitas/core/zero_point.py @@ -344,33 +344,3 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor: # pre-zero centering before rounding and clipping z = self.get_zero_center(x) / scale # need to scale the norm by s return z - - -class RuntimeDynamicGroupZeroScaling(brevitas.jit.ScriptModule): - - def __init__( - self, - group_size: int, - group_dim: int, - input_view_impl: Module, - zero_point_stats_impl: Module, - int_quant, - quantize_zero_point) -> None: - super(RuntimeDynamicGroupZeroScaling, self).__init__() - - self.group_size = group_size - self.group_dim = group_dim - self.zero_point_stats_impl = zero_point_stats_impl - self.input_view_impl = input_view_impl - self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point) - - @brevitas.jit.script_method - def forward( - self, - stats_input: torch.Tensor, - scale, - bit_width) -> torch.Tensor: - - stats_input_reshaped = self.input_view_impl(stats_input) - out = self.zero_point_stats_impl(stats_input_reshaped) - return self.scale_shift_zero_point(-out, scale, bit_width) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 40e6063f4..0e85917b2 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -4,8 +4,6 @@ """ import re -from brevitas.core.stats import NegativeMinOrZero -from brevitas.quant.base import ParameterFromRuntimeZeroPoint from dependencies import this import torch from torch import nn @@ -14,8 +12,11 @@ from brevitas.core.function_wrapper import CeilSte from brevitas.core.function_wrapper import FloorSte from brevitas.core.restrict_val import RoundSte -from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint, RuntimeDynamicGroupZeroScaling +from brevitas.core.stats import NegativeMinOrZero +from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint +from brevitas.core.zero_point import RuntimeDynamicGroupZeroScaling from brevitas.graph.quantize import layerwise_quantize +from brevitas.quant.base import ParameterFromRuntimeZeroPoint from brevitas.quant.experimental.float import Fp8e4m3Act from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat @@ -60,7 +61,7 @@ from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear -from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat, RuntimeDynamicStatsZeroPoint +from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFixedPoint from brevitas_examples.common.generative.quantizers import FP8e4m3OCPDynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Fp8e4m3OCPWeightPerChannelFixedPointMSE @@ -71,6 +72,7 @@ from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant +from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat @@ -152,15 +154,6 @@ 'per_channel': { 'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}} -class Test(Int8DynamicActPerGroupFloat): - # zero_point_impl = RuntimeDynamicStatsZeroPoint - zero_point_impl = RuntimeDynamicGroupZeroScaling - zero_point_stats_impl = NegativeMinOrZero - scaling_stats_op = 'min_max' - signed = False - # zero_point_shape = this.scaling_shape - # zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl - INPUT_QUANT_MAP = { 'int': { 'static': { @@ -189,8 +182,7 @@ class Test(Int8DynamicActPerGroupFloat): 'sym': Int8DynamicActPerRowFloat, 'asym': ShiftedUint8DynamicActPerRowFloat}, 'per_group': { - 'sym': Int8DynamicActPerGroupFloat, - 'asym': Test}}}, + 'sym': Int8DynamicActPerGroupFloat}}}, 'po2_scale': { 'stats': { 'per_row': { From d26ca4229d4dfed4cb6fa3c5adf3737bb226dee1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Jan 2025 18:10:09 +0000 Subject: [PATCH 5/6] cleanup --- src/brevitas/graph/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index 13cd19d2e..d1631f34e 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -18,7 +18,6 @@ from brevitas.fx import immutable_dict from brevitas.fx import Node from brevitas.graph.utils import * -from brevitas.nn import ScaledDotProductAttention from brevitas.utils.python_utils import islambda from brevitas.utils.rotation_utils import RotationWeightParametrization @@ -122,7 +121,6 @@ def _map_origin_vars(self, vars: dict): def _module_attributes(self, module): attrs = vars(module) - # workaround since bias doesn't show up on vars of Linear if hasattr(module, 'bias'): attrs['bias'] = module.bias From 75d3d4567333053d75d80f22f35378460fca64cf Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 15 Jan 2025 18:17:12 +0000 Subject: [PATCH 6/6] Fix --- src/brevitas_examples/common/generative/quantize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 0e85917b2..d845f58a6 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -14,7 +14,6 @@ from brevitas.core.restrict_val import RoundSte from brevitas.core.stats import NegativeMinOrZero from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint -from brevitas.core.zero_point import RuntimeDynamicGroupZeroScaling from brevitas.graph.quantize import layerwise_quantize from brevitas.quant.base import ParameterFromRuntimeZeroPoint from brevitas.quant.experimental.float import Fp8e4m3Act