From 479b4d91d9b9116b43edfd3d89a1f3a7ac2c8a83 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 20 Aug 2024 12:28:21 +0100 Subject: [PATCH 01/25] feat nn: add `ScaledDotProductAttention` and `QuantScaledDotProductAttention` classes --- src/brevitas/nn/quant_sdpa.py | 166 ++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 src/brevitas/nn/quant_sdpa.py diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py new file mode 100644 index 000000000..368eaa565 --- /dev/null +++ b/src/brevitas/nn/quant_sdpa.py @@ -0,0 +1,166 @@ +""" +Copyright (C) 2024, Advanced Micro Devices, Inc. +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of AMD, Facebook, Deepmind Technologies, NYU, + NEC Laboratories America and IDIAP Research Institute nor the names + of its contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. +""" + +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor +from torch.nn import Module +from torch.nn import Parameter +import torch.nn.functional as F + +from brevitas.quant.scaled_int import Int8ActPerTensorFloat +from brevitas.quant.scaled_int import Uint8ActPerTensorFloat + + +class ScaledDotProductAttention(Module): + def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False): + r""" + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. 
+ value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. + attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, + which is :math:`(N,..., L, S)`. Two types of masks are supported. + A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a + square matrix. The attention masking has the form of the upper left causal bias due to the alignment + (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. + An error is thrown if both attn_mask and is_causal are set. + scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. + + Shape legend: + - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + - :math:`Hq: \text{Number of heads of query}` + - :math:`H: \text{Number of heads of key and value}` + """ + return F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) + + +class QuantScaledDotProductAttention(Module): + def __init__(self, query_quant=Int8ActPerTensorFloat, key_quant=Int8ActPerTensorFloat, value_quant=Int8ActPerTensorFloat, softmax_input_quant=Int8ActPerTensorFloat, softmax_output_quant=Uint8ActPerTensorFloat, attn_output_quant=None, **kwargs) -> None: + super(QuantScaledDotProductAttention, self).__init__() + + def filter_kwargs(prefix): + return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)} + + self.query_quant = QuantIdentity( + act_quant=query_quant, **filter_kwargs('query_')) + self.key_quant = QuantIdentity( + act_quant=key_quant, **filter_kwargs('key_')) + self.value_quant = QuantIdentity( + act_quant=value_quant, **filter_kwargs('value_')) + self.softmax_input_quant = QuantIdentity( + act_quant=softmax_input_quant, **filter_kwargs('softmax_input_')) + self.softmax_output_quant = QuantIdentity( + act_quant=softmax_output_quant, **filter_kwargs('softmax_output_')) + self.attn_output_quant = QuantIdentity( + act_quant=attn_output_quant, **filter_kwargs('attn_output_')) + + def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False): + r""" + Args: + query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. + key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. + value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. + attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, + which is :math:`(N,..., L, S)`. Two types of masks are supported. 
+ A boolean mask where a value of True indicates that the element *should* take part in attention. + A float mask of the same type as query, key, value that is added to the attention score. + dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied + is_causal (bool): If set to true, the attention masking is a lower triangular matrix when the mask is a + square matrix. The attention masking has the form of the upper left causal bias due to the alignment + (see :class:`torch.nn.attention.bias.CausalBias`) when the mask is a non-square matrix. + An error is thrown if both attn_mask and is_causal are set. + scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set + to :math:`\frac{1}{\sqrt{E}}`. + enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False. + + Returns: + output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. + + Shape legend: + - :math:`N: \text{Batch size} ... : \text{Any number of other batch dimensions (optional)}` + - :math:`S: \text{Source sequence length}` + - :math:`L: \text{Target sequence length}` + - :math:`E: \text{Embedding dimension of the query and key}` + - :math:`Ev: \text{Embedding dimension of the value}` + - :math:`Hq: \text{Number of heads of query}` + - :math:`H: \text{Number of heads of key and value}` + """ + query = self.query_quant(query) + key = self.key_quant(key) + value = self.value_quant(value) + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = self.softmax_input_quant(attn_weight) + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + attn_weight = self.softmax_output_quant(attn_weight) + attn_output = attn_weight @ value + attn_output = self.attn_output_quant(attn_output) + return attn_output From 8f69b691e92fa54c72a87120299b9b56e303c53c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 20 Aug 2024 12:31:06 +0100 Subject: [PATCH 02/25] Feat (graph/standardize): Add SDPA conversion to modular version --- src/brevitas/graph/standardize.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index 93e99eac9..12692eddf 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -10,6 +10,7 @@ from brevitas.fx import GraphModule from brevitas.fx import immutable_dict from brevitas.fx import Node +from brevitas.nn.quant_sdpa import ScaledDotProductAttention from .base import FnToModule from .base import GraphTransform @@ -109,7 +110,8 @@ class TorchFunctionalToModule(GraphTransform): (F.avg_pool3d, nn.AvgPool3d), (F.adaptive_avg_pool1d, nn.AdaptiveAvgPool1d), (F.adaptive_avg_pool2d, nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, - nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout)) + nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout), + (F.scaled_dot_product_attention, 
ScaledDotProductAttention)) def __init__(self, fn_to_module_map=FN_TO_MODULE_MAP): super().__init__() From 4a93631220d2cf76332fdd746ae152f862e05ac8 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 20 Aug 2024 12:32:49 +0100 Subject: [PATCH 03/25] Feat (example/llm): Specify LLMs to use SDPA for their attn implementation --- src/brevitas_examples/llm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index c8c76d4a1..2e0d302c8 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -186,7 +186,7 @@ def main(args): kwargs['torchscript'] = True print("Model loading...") - model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + model = AutoModelForCausalLM.from_pretrained(args.model, attn_implementation="sdpa", **kwargs) print("Model loaded.") model.eval() tokenizer = AutoTokenizer.from_pretrained(args.model) From 0ac0db8ee0cc1098469997d4a02d070d47b3a5d0 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 8 Nov 2024 11:13:10 +0000 Subject: [PATCH 04/25] Fix (nn/sdpa): formatting --- src/brevitas/graph/standardize.py | 2 +- src/brevitas/nn/quant_sdpa.py | 60 +++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 15 deletions(-) diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index 12692eddf..4c9233e08 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -10,7 +10,7 @@ from brevitas.fx import GraphModule from brevitas.fx import immutable_dict from brevitas.fx import Node -from brevitas.nn.quant_sdpa import ScaledDotProductAttention +from brevitas.nn.quant_sdpa import ScaledDotProductAttention from .base import FnToModule from .base import GraphTransform diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index 368eaa565..9927ca41a 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -53,7 +53,17 @@ class ScaledDotProductAttention(Module): - def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False): + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False): r""" Args: query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. @@ -71,10 +81,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set to :math:`\frac{1}{\sqrt{E}}`. enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False. - + Returns: output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. - + Shape legend: - :math:`N: \text{Batch size} ... 
: \text{Any number of other batch dimensions (optional)}` - :math:`S: \text{Source sequence length}` @@ -84,22 +94,35 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional - :math:`Hq: \text{Number of heads of query}` - :math:`H: \text{Number of heads of key and value}` """ - return F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale) + return F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale) class QuantScaledDotProductAttention(Module): - def __init__(self, query_quant=Int8ActPerTensorFloat, key_quant=Int8ActPerTensorFloat, value_quant=Int8ActPerTensorFloat, softmax_input_quant=Int8ActPerTensorFloat, softmax_output_quant=Uint8ActPerTensorFloat, attn_output_quant=None, **kwargs) -> None: + + def __init__( + self, + query_quant=Int8ActPerTensorFloat, + key_quant=Int8ActPerTensorFloat, + value_quant=Int8ActPerTensorFloat, + softmax_input_quant=Int8ActPerTensorFloat, + softmax_output_quant=Uint8ActPerTensorFloat, + attn_output_quant=None, + **kwargs) -> None: super(QuantScaledDotProductAttention, self).__init__() def filter_kwargs(prefix): return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)} - self.query_quant = QuantIdentity( - act_quant=query_quant, **filter_kwargs('query_')) - self.key_quant = QuantIdentity( - act_quant=key_quant, **filter_kwargs('key_')) - self.value_quant = QuantIdentity( - act_quant=value_quant, **filter_kwargs('value_')) + self.query_quant = QuantIdentity(act_quant=query_quant, **filter_kwargs('query_')) + self.key_quant = QuantIdentity(act_quant=key_quant, **filter_kwargs('key_')) + self.value_quant = QuantIdentity(act_quant=value_quant, **filter_kwargs('value_')) self.softmax_input_quant = QuantIdentity( act_quant=softmax_input_quant, **filter_kwargs('softmax_input_')) self.softmax_output_quant = QuantIdentity( @@ -107,7 +130,16 @@ def filter_kwargs(prefix): self.attn_output_quant = QuantIdentity( act_quant=attn_output_quant, **filter_kwargs('attn_output_')) - def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False): + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + attn_mask: Optional[Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False): r""" Args: query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`. @@ -125,10 +157,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional scale (optional float, keyword-only): Scaling factor applied prior to softmax. If None, the default value is set to :math:`\frac{1}{\sqrt{E}}`. enable_gqa (bool): Ignored to make calling interface compatible with PyTorch >v2.5. Always set to False. - + Returns: output (Tensor): Attention output; shape :math:`(N, ..., Hq, L, Ev)`. - + Shape legend: - :math:`N: \text{Batch size} ... 
: \text{Any number of other batch dimensions (optional)}` - :math:`S: \text{Source sequence length}` From d0b9f501782f228d86a81553912d79d576ddb324 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 18 Nov 2024 15:58:44 +0000 Subject: [PATCH 05/25] Fix (nn/sdpa): Updated argument to match qsdpa --- src/brevitas/nn/quant_sdpa.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index 9927ca41a..a7a61d56d 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -108,11 +108,11 @@ class QuantScaledDotProductAttention(Module): def __init__( self, - query_quant=Int8ActPerTensorFloat, - key_quant=Int8ActPerTensorFloat, - value_quant=Int8ActPerTensorFloat, - softmax_input_quant=Int8ActPerTensorFloat, - softmax_output_quant=Uint8ActPerTensorFloat, + softmax_input_quant=None, + attn_output_weights_quant=Uint8ActPerTensorFloat, + q_scaled_quant=Int8ActPerTensorFloat, + k_transposed_quant=Int8ActPerTensorFloat, + v_quant=Int8ActPerTensorFloat, attn_output_quant=None, **kwargs) -> None: super(QuantScaledDotProductAttention, self).__init__() @@ -120,13 +120,14 @@ def __init__( def filter_kwargs(prefix): return {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)} - self.query_quant = QuantIdentity(act_quant=query_quant, **filter_kwargs('query_')) - self.key_quant = QuantIdentity(act_quant=key_quant, **filter_kwargs('key_')) - self.value_quant = QuantIdentity(act_quant=value_quant, **filter_kwargs('value_')) + self.q_scaled_quant = QuantIdentity(act_quant=q_scaled_quant, **filter_kwargs('q_scaled_')) + self.k_transposed_quant = QuantIdentity( + act_quant=k_transposed_quant, **filter_kwargs('k_transposed_')) + self.v_quant = QuantIdentity(act_quant=v_quant, **filter_kwargs('v_')) self.softmax_input_quant = QuantIdentity( act_quant=softmax_input_quant, **filter_kwargs('softmax_input_')) - self.softmax_output_quant = QuantIdentity( - act_quant=softmax_output_quant, **filter_kwargs('softmax_output_')) + self.attn_output_weights_quant = QuantIdentity( + act_quant=attn_output_weights_quant, **filter_kwargs('attn_output_weights_')) self.attn_output_quant = QuantIdentity( act_quant=attn_output_quant, **filter_kwargs('attn_output_')) @@ -187,12 +188,14 @@ def forward( attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask - attn_weight = query @ key.transpose(-2, -1) * scale_factor + q_scaled = self.q_scaled_quant(query * scale_factor) + k_transpose = self.k_transpose_quant(key.transpose(-2, -1)) + attn_weight = q_scaled @ k_transpose attn_weight += attn_bias attn_weight = self.softmax_input_quant(attn_weight) attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) - attn_weight = self.softmax_output_quant(attn_weight) - attn_output = attn_weight @ value + attn_weight = self.attn_output_weights_quant(attn_weight) + attn_output = attn_weight @ self.v_quant(value) attn_output = self.attn_output_quant(attn_output) return attn_output From 28f108fc13f820cb839a4cb37e2a60887546d97d Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 20 Nov 2024 17:13:47 +0000 Subject: [PATCH 06/25] feat (nn): bugfixes in QuantSDPA --- src/brevitas/nn/__init__.py | 2 ++ src/brevitas/nn/quant_sdpa.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/brevitas/nn/__init__.py b/src/brevitas/nn/__init__.py index 4176e30cf..e58a5fa0c 100644 --- 
a/src/brevitas/nn/__init__.py +++ b/src/brevitas/nn/__init__.py @@ -28,6 +28,8 @@ from .quant_rnn import QuantRNN from .quant_scale_bias import QuantScaleBias from .quant_scale_bias import ScaleBias +from .quant_sdpa import QuantScaledDotProductAttention +from .quant_sdpa import ScaledDotProductAttention from .quant_upsample import QuantUpsample from .quant_upsample import QuantUpsamplingBilinear2d from .quant_upsample import QuantUpsamplingNearest2d diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index a7a61d56d..053928bb1 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -40,6 +40,7 @@ POSSIBILITY OF SUCH DAMAGE. """ +import math from typing import Optional, Tuple, Union import torch @@ -51,6 +52,8 @@ from brevitas.quant.scaled_int import Int8ActPerTensorFloat from brevitas.quant.scaled_int import Uint8ActPerTensorFloat +from .quant_activation import QuantIdentity + class ScaledDotProductAttention(Module): @@ -171,25 +174,22 @@ def forward( - :math:`Hq: \text{Number of heads of query}` - :math:`H: \text{Number of heads of key and value}` """ - query = self.query_quant(query) - key = self.key_quant(key) - value = self.value_quant(value) L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype) + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) + attn_bias.to(dtype=query.dtype, device=query.device) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: - attn_bias += attn_mask + attn_bias = attn_bias + attn_mask q_scaled = self.q_scaled_quant(query * scale_factor) - k_transpose = self.k_transpose_quant(key.transpose(-2, -1)) + k_transpose = self.k_transposed_quant(key.transpose(-2, -1)) attn_weight = q_scaled @ k_transpose attn_weight += attn_bias attn_weight = self.softmax_input_quant(attn_weight) From a28bf1e4f45111ed2b62c0127ab111563d19c310 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 20 Nov 2024 17:16:02 +0000 Subject: [PATCH 07/25] Feat (example/llm): Adding functions to replace SDPA --- .../llm/llm_quant/prepare_for_quantize.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index ee2f0b3e6..10fe8325a 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -5,10 +5,14 @@ from packaging import version import torch +import torch.nn.functional as F import transformers from transformers.models.opt.modeling_opt import OPTAttention from brevitas.graph import ModuleToModuleByClass +from brevitas.graph import TorchFunctionalToModule +from brevitas.nn import QuantScaledDotProductAttention +from brevitas.nn import ScaledDotProductAttention from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention QUANTIZABLE_MHA_MAP = { @@ -35,6 +39,12 @@ def replace_mha_with_quantizable_layers(model, dtype): return model +def replace_sdpa_with_quantizable_layers(graph_model): + fn_to_module_map = ((F.scaled_dot_product_attention, 
ScaledDotProductAttention),) + graph_model = TorchFunctionalToModule(fn_to_module_map=fn_to_module_map).apply(graph_model) + return graph_model + + @torch.no_grad() def add_zero_bias_to_linear(model: torch.nn.Module) -> torch.nn.Module: for name, module in model.named_modules(): From 308fcbaa30a92ad550d0c5bcc5a50f3e0cfde362 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 20 Nov 2024 17:22:04 +0000 Subject: [PATCH 08/25] feat (example/generative): Replace SDPA with QuantSDPA --- .../common/generative/quantize.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 7e9b9c897..6c156ec1a 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -463,13 +463,25 @@ def generate_quant_maps( 'dtype': dtype, 'device': device} + quant_sdpa_kwargs = { + 'softmax_input_quant': None, + 'attn_output_weights_quant': attn_output_weights_quant, + 'attn_output_weights_signed': 'float' in input_quant_format, + 'q_scaled_quant': q_scaled_quant, + 'k_transposed_quant': k_transposed_quant, + 'v_quant': v_quant, + 'attn_output_quant': None, + 'dtype': dtype, + 'device': device} + layer_map = { nn.Linear: (qnn.QuantLinear, quant_linear_kwargs), nn.Conv2d: (qnn.QuantConv2d, quant_conv_kwargs), 'diffusers.models.lora.LoRACompatibleLinear': (LoRACompatibleQuantLinear, quant_linear_kwargs), 'diffusers.models.lora.LoRACompatibleConv': (LoRACompatibleQuantConv2d, quant_conv_kwargs), - nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs)} + nn.MultiheadAttention: (qnn.QuantMultiheadAttention, quant_mha_kwargs), + qnn.ScaledDotProductAttention: (qnn.QuantScaledDotProductAttention, quant_sdpa_kwargs)} if quantize_embedding: quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype, 'device': device} From fd49b28f7119c742c11d30bd5282b9a40390241d Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 20 Nov 2024 17:24:29 +0000 Subject: [PATCH 09/25] feat (example/llm): Add argument to quantize SDPA --- src/brevitas_examples/llm/main.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 2e0d302c8..3a678bdf8 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -40,6 +40,8 @@ from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers +from brevitas_examples.llm.llm_quant.prepare_for_quantize import \ + replace_sdpa_with_quantizable_layers from brevitas_examples.llm.llm_quant.run_utils import CastFloat16ToFloat32 from brevitas_examples.llm.llm_quant.run_utils import fix_rewriter from brevitas_examples.llm.llm_quant.run_utils import get_fx @@ -180,13 +182,18 @@ def main(args): else: dtype = torch.float16 + # Whether to quantize SDPA with FX + quant_sdpa_fx = args.quant_sdpa and not args.replace_mha + kwargs = {"torch_dtype": dtype} + if quant_sdpa_fx: + kwargs["attn_implementation"] = "sdpa" if args.export_target == 'torch_qcdq': kwargs['torchscript'] = True print("Model loading...") - model = AutoModelForCausalLM.from_pretrained(args.model, attn_implementation="sdpa", **kwargs) + model = 
AutoModelForCausalLM.from_pretrained(args.model, **kwargs) print("Model loaded.") model.eval() tokenizer = AutoTokenizer.from_pretrained(args.model) @@ -199,7 +206,7 @@ def main(args): with CastFloat16ToFloat32(): apply_awq(model, awq_results) - require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm else False + require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm or quant_sdpa_fx else False # Load the data for calibration and evaluation. calibration_loader = get_dataset_for_model( @@ -280,6 +287,10 @@ def main(args): print("Replace HF MHA with quantizable variants...") model = replace_mha_with_quantizable_layers(model, dtype) print("Replacing done.") + elif quant_sdpa_fx: + print("Replace `F.scaled_dot_product_attention` with QuantSDPA...") + model = replace_sdpa_with_quantizable_layers(model) + print("Replacing done.") if args.weight_equalization: print("Apply weight equalization...") @@ -636,6 +647,10 @@ def parse_args(args): type=float, default=1e-4, help='Minimum value to clamp scale to when using bf16 or fp16 quantization.') + parser.add_argument( + '--quant-sdpa', + action='store_true', + help='Quantize `F.scaled_dot_product_attention` (default: %(default)s)') parser.add_argument( '--replace-mha', action='store_true', From 02d64ff5009a506112334cf6c293f733d0874a18 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 20 Nov 2024 17:50:32 +0000 Subject: [PATCH 10/25] test (llm/sdpa): Added basic tests for SDPA --- tests/brevitas_examples/test_llm.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 60dd33ac2..61cfae010 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -245,7 +245,8 @@ def test_small_models_acc(caplog, acc_args_and_acc): @pytest_cases.fixture( ids=[ - "opt-replace-mha",], + "opt-replace-mha", + "opt-quant-sdpa",], params=[ { "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run @@ -253,6 +254,13 @@ def test_small_models_acc(caplog, acc_args_and_acc): "ln_affine_merge": True, "replace_mha": True, "float_ppl": 50016.0, + "quant_ppl": 50016.0}, + { + "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run + "weight_equalization": True, + "ln_affine_merge": True, + "quant_sdpa": True, + "float_ppl": 50016.0, "quant_ppl": 50016.0},]) def acc_args_and_acc_pt_ge_2_4(default_run_args, request): args = default_run_args @@ -430,7 +438,8 @@ def test_small_models_quant_layer(caplog, layer_args): @pytest_cases.fixture( ids=[ - "opt-replace-mha",], + "opt-replace-mha", + "opt-quant-sdpa",], params=[ { "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run @@ -439,7 +448,13 @@ def test_small_models_quant_layer(caplog, layer_args): "model.decoder.layers.0.self_attn": "", "model.decoder.layers.0.self_attn.mha": - "",}},]) + "",}}, + { + "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run + "quant_sdpa": True, + "exp_layer_types": { + "scaled_dot_product_attention": + "",}},]) def layer_args_pt_ge_2_4(default_run_args, request): args = default_run_args layer_dict = request.param From 229bc31de400bd0d63308dd3114292e320f17e4f Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 20 Nov 2024 17:51:00 +0000 Subject: [PATCH 11/25] Fix 
(example/llm): workaround for new OPT default attention --- src/brevitas_examples/llm/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 3a678bdf8..9304977af 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -188,6 +188,8 @@ def main(args): kwargs = {"torch_dtype": dtype} if quant_sdpa_fx: kwargs["attn_implementation"] = "sdpa" + elif args.replace_mha: + kwargs["attn_implementation"] = "eager" if args.export_target == 'torch_qcdq': kwargs['torchscript'] = True From 434ddc16a28979cca462306c364df792328272db Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 20 Nov 2024 17:53:25 +0000 Subject: [PATCH 12/25] Fix (graph): Changed SDPA import --- src/brevitas/graph/standardize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index 4c9233e08..7d15b12ed 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -10,7 +10,7 @@ from brevitas.fx import GraphModule from brevitas.fx import immutable_dict from brevitas.fx import Node -from brevitas.nn.quant_sdpa import ScaledDotProductAttention +from brevitas.nn import ScaledDotProductAttention from .base import FnToModule from .base import GraphTransform From b1b168c223b1d822ec4cf4618f62886b3df990d0 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 20 Nov 2024 18:04:13 +0000 Subject: [PATCH 13/25] Fix (graph): Removed SDPA from default standardize script. --- src/brevitas/graph/standardize.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index 7d15b12ed..93e99eac9 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -10,7 +10,6 @@ from brevitas.fx import GraphModule from brevitas.fx import immutable_dict from brevitas.fx import Node -from brevitas.nn import ScaledDotProductAttention from .base import FnToModule from .base import GraphTransform @@ -110,8 +109,7 @@ class TorchFunctionalToModule(GraphTransform): (F.avg_pool3d, nn.AvgPool3d), (F.adaptive_avg_pool1d, nn.AdaptiveAvgPool1d), (F.adaptive_avg_pool2d, nn.AdaptiveAvgPool2d), (F.adaptive_avg_pool3d, - nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout), - (F.scaled_dot_product_attention, ScaledDotProductAttention)) + nn.AdaptiveAvgPool3d), (F.dropout, nn.Dropout)) def __init__(self, fn_to_module_map=FN_TO_MODULE_MAP): super().__init__() From b28ee5ea6d46bb280140af42cd62b8808c040251 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 21 Nov 2024 11:05:15 +0000 Subject: [PATCH 14/25] test (nn/sdpa): Added basic sanity check test --- tests/brevitas/nn/test_sdpa.py | 58 ++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/brevitas/nn/test_sdpa.py diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py new file mode 100644 index 000000000..7e0de670d --- /dev/null +++ b/tests/brevitas/nn/test_sdpa.py @@ -0,0 +1,58 @@ +# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +from packaging import version +import pytest +import torch +from torch.nn.functional import scaled_dot_product_attention + +from brevitas import torch_version +from brevitas.nn import QuantScaledDotProductAttention +from brevitas.nn import ScaledDotProductAttention +from tests.marker import requires_pt_ge + +ATOL = 1e-6 +EMBED_DIM = 9 +HEAD_DIM = 3 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 4 +PAST_SEQUENCE_LENGTH = 5 +DROPOUT_SEED = 42 + + +class TestScaledDotProductAttention: + + @requires_pt_ge('2.0') + @pytest.mark.parametrize("dropout_p", [0.0, 0.5]) + @pytest.mark.parametrize("is_causal", [True, False]) + @pytest.mark.parametrize("scale", [None, 0.3]) + @pytest.mark.parametrize("enable_gqa", [False, True]) + @pytest.mark.parametrize("rand_attn_mask", [False, True]) + # Sanity check, since `ScaledDotProductAttention` just called `F.scaled_dot_product_attention` in its forward function + def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask): + extra_kwargs = { + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa,} + if torch_version < version.parse('2.5.0'): + del extra_kwargs["enable_gqa"] + + kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH + m = ScaledDotProductAttention() + q = torch.randn(BATCH_SIZE, HEAD_DIM, SEQUENCE_LENGTH, EMBED_DIM) + k = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM) + v = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM) + if rand_attn_mask and not is_causal: + attn_mask = torch.randint( + low=0, high=2, size=(BATCH_SIZE, 1, SEQUENCE_LENGTH, kv_length), dtype=torch.bool) + else: + attn_mask = None + if dropout_p > 0.0: + torch.manual_seed(DROPOUT_SEED) + ref_out = scaled_dot_product_attention(q, k, v, attn_mask, **extra_kwargs) + if dropout_p > 0.0: + torch.manual_seed(DROPOUT_SEED) + out = m(q, k, v, attn_mask, **extra_kwargs) + assert torch.isclose(out, ref_out, atol=ATOL).all() + assert torch.isclose(out, ref_out, atol=ATOL).all() From 57803ce38d128206f1bf0fefc6400aacc782cf64 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 21 Nov 2024 11:30:23 +0000 Subject: [PATCH 15/25] Fix (nn): Fix in QSPDA --- src/brevitas/nn/quant_sdpa.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index 053928bb1..754c48bac 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -176,7 +176,10 @@ def forward( """ L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale - attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if attn_mask is None: + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + else: + attn_bias = torch.zeros(size=attn_mask.shape, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) @@ -187,7 +190,7 @@ def forward( if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: - attn_bias = attn_bias + attn_mask + attn_bias += attn_mask q_scaled = self.q_scaled_quant(query * scale_factor) k_transpose = self.k_transposed_quant(key.transpose(-2, -1)) attn_weight = q_scaled @ k_transpose From 638a82be0322320b25a222b4817054ec5fc6b39d Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 21 Nov 2024 11:30:58 +0000 Subject: [PATCH 16/25] test (nn): Added quant_disabled QSPDA test --- 
tests/brevitas/nn/test_sdpa.py | 44 +++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py index 7e0de670d..9d735a997 100644 --- a/tests/brevitas/nn/test_sdpa.py +++ b/tests/brevitas/nn/test_sdpa.py @@ -28,7 +28,7 @@ class TestScaledDotProductAttention: @pytest.mark.parametrize("scale", [None, 0.3]) @pytest.mark.parametrize("enable_gqa", [False, True]) @pytest.mark.parametrize("rand_attn_mask", [False, True]) - # Sanity check, since `ScaledDotProductAttention` just called `F.scaled_dot_product_attention` in its forward function + # Sanity check, since `ScaledDotProductAttention` just calls `F.scaled_dot_product_attention` in its forward function def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask): extra_kwargs = { "dropout_p": dropout_p, @@ -56,3 +56,45 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask) out = m(q, k, v, attn_mask, **extra_kwargs) assert torch.isclose(out, ref_out, atol=ATOL).all() assert torch.isclose(out, ref_out, atol=ATOL).all() + + @requires_pt_ge('2.0') + @pytest.mark.parametrize("dropout_p", [0.0, 0.5]) + @pytest.mark.parametrize("is_causal", [True, False]) + @pytest.mark.parametrize("scale", [None, 0.3]) + @pytest.mark.parametrize("enable_gqa", [False, True]) + @pytest.mark.parametrize("rand_attn_mask", [False, True]) + def test_sdpa_quant_disabled_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask): + extra_kwargs = { + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa,} + if torch_version < version.parse('2.5.0'): + del extra_kwargs["enable_gqa"] + + kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH + m = ScaledDotProductAttention() + qm = QuantScaledDotProductAttention( + softmax_input_quant=None, + attn_output_weights_quant=None, + q_scaled_quant=None, + k_transposed_quant=None, + v_quant=None, + attn_output_quant=None, + ) + q = torch.randn(BATCH_SIZE, HEAD_DIM, SEQUENCE_LENGTH, EMBED_DIM) + k = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM) + v = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM) + if rand_attn_mask and not is_causal: + attn_mask = torch.randint( + low=0, high=2, size=(BATCH_SIZE, 1, SEQUENCE_LENGTH, kv_length), dtype=torch.bool) + else: + attn_mask = None + if dropout_p > 0.0: + torch.manual_seed(DROPOUT_SEED) + ref_out = m(q, k, v, attn_mask, **extra_kwargs) + if dropout_p > 0.0: + torch.manual_seed(DROPOUT_SEED) + out = qm(q, k, v, attn_mask, **extra_kwargs) + assert torch.isclose(out, ref_out, atol=ATOL).all() + assert torch.isclose(out, ref_out, atol=ATOL).all() From bf499e2d71eaa542b04eee76bbb2d87e40e43970 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 21 Nov 2024 14:30:11 +0000 Subject: [PATCH 17/25] test (fix): sdpa import --- tests/brevitas/nn/test_sdpa.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py index 9d735a997..856f1ef01 100644 --- a/tests/brevitas/nn/test_sdpa.py +++ b/tests/brevitas/nn/test_sdpa.py @@ -4,7 +4,7 @@ from packaging import version import pytest import torch -from torch.nn.functional import scaled_dot_product_attention +import torch.nn.functional as F from brevitas import torch_version from brevitas.nn import QuantScaledDotProductAttention @@ -50,7 +50,7 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask) attn_mask = None if dropout_p > 0.0: 
torch.manual_seed(DROPOUT_SEED) - ref_out = scaled_dot_product_attention(q, k, v, attn_mask, **extra_kwargs) + ref_out = F.scaled_dot_product_attention(q, k, v, attn_mask, **extra_kwargs) if dropout_p > 0.0: torch.manual_seed(DROPOUT_SEED) out = m(q, k, v, attn_mask, **extra_kwargs) From 14b3ea40c9887d60ab611e920b77b2faff77e8ba Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 28 Nov 2024 16:46:58 +0000 Subject: [PATCH 18/25] test (nn/sdpa): Fix when PT<2.1 --- tests/brevitas/nn/test_sdpa.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py index 856f1ef01..efb6e4aa5 100644 --- a/tests/brevitas/nn/test_sdpa.py +++ b/tests/brevitas/nn/test_sdpa.py @@ -37,6 +37,8 @@ def test_sdpa_fwd(self, dropout_p, is_causal, scale, enable_gqa, rand_attn_mask) "enable_gqa": enable_gqa,} if torch_version < version.parse('2.5.0'): del extra_kwargs["enable_gqa"] + if torch_version < version.parse('2.1.0'): + del extra_kwargs["scale"] kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH m = ScaledDotProductAttention() From 2071511404048789c0b9a99941f3d092f8ff971a Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 28 Nov 2024 16:55:46 +0000 Subject: [PATCH 19/25] Fix (example/llm): Removed unnecessary intantiation kwarg --- src/brevitas_examples/llm/main.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 9304977af..3a678bdf8 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -188,8 +188,6 @@ def main(args): kwargs = {"torch_dtype": dtype} if quant_sdpa_fx: kwargs["attn_implementation"] = "sdpa" - elif args.replace_mha: - kwargs["attn_implementation"] = "eager" if args.export_target == 'torch_qcdq': kwargs['torchscript'] = True From 5368351741e08c0af0584b0851a19bb4a6d50367 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 28 Nov 2024 17:59:35 +0000 Subject: [PATCH 20/25] fix (nn/sdpa): Add backwards compatibility with older pytorch versions --- src/brevitas/nn/quant_sdpa.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index 754c48bac..e9e6d89c9 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -97,6 +97,11 @@ def forward( - :math:`Hq: \text{Number of heads of query}` - :math:`H: \text{Number of heads of key and value}` """ + kwargs = {} + if scale is not None: + kwargs["scale"] = scale + if not enable_gqa: + kwargs["enable_gqa"] = enable_gqa return F.scaled_dot_product_attention( query=query, key=key, @@ -104,7 +109,7 @@ def forward( attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, - scale=scale) + **kwargs) class QuantScaledDotProductAttention(Module): From c50418724892fc6711dfec603d603a1c4f1a1df4 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 28 Nov 2024 18:08:13 +0000 Subject: [PATCH 21/25] Fix (test/nn/sdpa): Fix for older pytorch versions. 
--- tests/brevitas/nn/test_sdpa.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py index efb6e4aa5..0ba6a837e 100644 --- a/tests/brevitas/nn/test_sdpa.py +++ b/tests/brevitas/nn/test_sdpa.py @@ -73,6 +73,8 @@ def test_sdpa_quant_disabled_fwd(self, dropout_p, is_causal, scale, enable_gqa, "enable_gqa": enable_gqa,} if torch_version < version.parse('2.5.0'): del extra_kwargs["enable_gqa"] + if torch_version < version.parse('2.1.0'): + del extra_kwargs["scale"] kv_length = PAST_SEQUENCE_LENGTH + SEQUENCE_LENGTH m = ScaledDotProductAttention() From ac533dd4d00225389731756c7515b8775d94ee1b Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 29 Nov 2024 10:09:25 +0000 Subject: [PATCH 22/25] test (nn/sdpa): bugfix --- src/brevitas/nn/quant_sdpa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index e9e6d89c9..96e85d489 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -100,7 +100,7 @@ def forward( kwargs = {} if scale is not None: kwargs["scale"] = scale - if not enable_gqa: + if enable_gqa: kwargs["enable_gqa"] = enable_gqa return F.scaled_dot_product_attention( query=query, From 41e4fd57106a220dcab2b23c4a273df0f58415fe Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Wed, 4 Dec 2024 16:43:38 +0000 Subject: [PATCH 23/25] docs (examples/llm): Update README --- src/brevitas_examples/llm/README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index 64c82c80a..8471c84ee 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -10,7 +10,7 @@ ## Run -Set the env variable BREVITAS_JIT=1 to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. +Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Currently unsupported whenever export is also toggled or with MSE based scales/zero-points. ```bash usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] @@ -46,8 +46,9 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES] [--act-calibration] [--bias-corr] [--ln-affine-merge] [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm] [--no-quantize] [--no-float16] - [--scaling-min-val SCALING_MIN_VAL] [--replace-mha] - [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}] + [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa] + [--replace-mha] [--weight-equalization] + [--rotation {fx,layerwise,fused_no_fx}] [--rotation-mode {had,ort}] [--rotation-orphan-sink] [--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ] [--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}] @@ -160,6 +161,8 @@ options: --scaling-min-val SCALING_MIN_VAL Minimum value to clamp scale to when using bf16 or fp16 quantization. + --quant-sdpa Quantize `F.scaled_dot_product_attention` (default: + False) --replace-mha Replace HuggingFace Attention with a quantizable version --weight-equalization @@ -200,5 +203,4 @@ options: --learned-round-fast-update Whether to use fast update with learned round. 
Prototype (default: False) - ``` From 69fbafbbe28c9332d1bf610b384ec125f83a6665 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 5 Dec 2024 13:15:32 +0000 Subject: [PATCH 24/25] test (nn/sdpa): test filter_kwargs works correctly --- tests/brevitas/nn/test_sdpa.py | 47 ++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py index 0ba6a837e..281429f0f 100644 --- a/tests/brevitas/nn/test_sdpa.py +++ b/tests/brevitas/nn/test_sdpa.py @@ -9,6 +9,8 @@ from brevitas import torch_version from brevitas.nn import QuantScaledDotProductAttention from brevitas.nn import ScaledDotProductAttention +from brevitas.quant import Int8ActPerTensorFloat +from brevitas.quant import Uint8ActPerTensorFloat from tests.marker import requires_pt_ge ATOL = 1e-6 @@ -22,6 +24,51 @@ class TestScaledDotProductAttention: + @requires_pt_ge('2.0') + # Check what kwargs are properly filtered and override defaults + def test_sdpa_init(self): + extra_kwargs = { + 'softmax_input_bit_width': 2, + 'attn_output_weights_bit_width': 3, + 'q_scaled_bit_width': 4, + 'k_transposed_bit_width': 5, + 'v_bit_width': 6, + 'attn_output_bit_width': 7,} + qm = QuantScaledDotProductAttention( + softmax_input_quant=Int8ActPerTensorFloat, + attn_output_weights_quant=Uint8ActPerTensorFloat, + q_scaled_quant=Int8ActPerTensorFloat, + k_transposed_quant=Int8ActPerTensorFloat, + v_quant=Int8ActPerTensorFloat, + attn_output_quant=Int8ActPerTensorFloat, + **extra_kwargs, + ) + + # Check that the `kwargs` have been applied correctly + prefixes = ["softmax_input", "attn_output", "q_scaled", "v", "attn_output"] + for k in extra_kwargs.keys(): + checked = False + if "softmax_input_" in k: + assert int(qm.softmax_input_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "attn_output_weights_" in k: + assert int( + qm.attn_output_weights_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "q_scaled_" in k: + assert int(qm.q_scaled_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "k_transposed_" in k: + assert int(qm.k_transposed_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "v_" in k: + assert int(qm.v_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + elif "attn_output_" in k: + assert int(qm.attn_output_quant.act_quant.bit_width().item()) == extra_kwargs[k] + checked = True + assert checked, f"Unmatched kwarg: {k}" + @requires_pt_ge('2.0') @pytest.mark.parametrize("dropout_p", [0.0, 0.5]) @pytest.mark.parametrize("is_causal", [True, False]) From f0701bae61cb88e960d68d8a4cb3c7663df80f2c Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 5 Dec 2024 13:19:40 +0000 Subject: [PATCH 25/25] fix (nn/sdpa): Rename output quantizer to sdpa_output_quant to avoid name clashes --- src/brevitas/nn/quant_sdpa.py | 8 ++++---- tests/brevitas/nn/test_sdpa.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py index 96e85d489..43f99e827 100644 --- a/src/brevitas/nn/quant_sdpa.py +++ b/src/brevitas/nn/quant_sdpa.py @@ -121,7 +121,7 @@ def __init__( q_scaled_quant=Int8ActPerTensorFloat, k_transposed_quant=Int8ActPerTensorFloat, v_quant=Int8ActPerTensorFloat, - attn_output_quant=None, + sdpa_output_quant=None, **kwargs) -> None: super(QuantScaledDotProductAttention, self).__init__() @@ -136,8 +136,8 @@ def filter_kwargs(prefix): act_quant=softmax_input_quant, 
**filter_kwargs('softmax_input_')) self.attn_output_weights_quant = QuantIdentity( act_quant=attn_output_weights_quant, **filter_kwargs('attn_output_weights_')) - self.attn_output_quant = QuantIdentity( - act_quant=attn_output_quant, **filter_kwargs('attn_output_')) + self.sdpa_output_quant = QuantIdentity( + act_quant=sdpa_output_quant, **filter_kwargs('sdpa_output_')) def forward( self, @@ -205,5 +205,5 @@ def forward( attn_weight = torch.dropout(attn_weight, dropout_p, train=True) attn_weight = self.attn_output_weights_quant(attn_weight) attn_output = attn_weight @ self.v_quant(value) - attn_output = self.attn_output_quant(attn_output) + attn_output = self.sdpa_output_quant(attn_output) return attn_output diff --git a/tests/brevitas/nn/test_sdpa.py b/tests/brevitas/nn/test_sdpa.py index 281429f0f..b38415ea8 100644 --- a/tests/brevitas/nn/test_sdpa.py +++ b/tests/brevitas/nn/test_sdpa.py @@ -33,19 +33,19 @@ def test_sdpa_init(self): 'q_scaled_bit_width': 4, 'k_transposed_bit_width': 5, 'v_bit_width': 6, - 'attn_output_bit_width': 7,} + 'sdpa_output_bit_width': 7,} qm = QuantScaledDotProductAttention( softmax_input_quant=Int8ActPerTensorFloat, attn_output_weights_quant=Uint8ActPerTensorFloat, q_scaled_quant=Int8ActPerTensorFloat, k_transposed_quant=Int8ActPerTensorFloat, v_quant=Int8ActPerTensorFloat, - attn_output_quant=Int8ActPerTensorFloat, + sdpa_output_quant=Int8ActPerTensorFloat, **extra_kwargs, ) # Check that the `kwargs` have been applied correctly - prefixes = ["softmax_input", "attn_output", "q_scaled", "v", "attn_output"] + prefixes = ["softmax_input", "attn_output_weights", "q_scaled", "v", "sdpa_output"] for k in extra_kwargs.keys(): checked = False if "softmax_input_" in k: @@ -64,8 +64,8 @@ def test_sdpa_init(self): elif "v_" in k: assert int(qm.v_quant.act_quant.bit_width().item()) == extra_kwargs[k] checked = True - elif "attn_output_" in k: - assert int(qm.attn_output_quant.act_quant.bit_width().item()) == extra_kwargs[k] + elif "sdpa_output_" in k: + assert int(qm.sdpa_output_quant.act_quant.bit_width().item()) == extra_kwargs[k] checked = True assert checked, f"Unmatched kwarg: {k}" @@ -131,7 +131,7 @@ def test_sdpa_quant_disabled_fwd(self, dropout_p, is_causal, scale, enable_gqa, q_scaled_quant=None, k_transposed_quant=None, v_quant=None, - attn_output_quant=None, + sdpa_output_quant=None, ) q = torch.randn(BATCH_SIZE, HEAD_DIM, SEQUENCE_LENGTH, EMBED_DIM) k = torch.randn(BATCH_SIZE, HEAD_DIM, kv_length, EMBED_DIM)
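
For reference, the following is a minimal usage sketch of the QuantScaledDotProductAttention module introduced by this patch series. It assumes the constructor argument names as of PATCH 25 (including the renamed `sdpa_output_quant`) and the quantizers imported in the tests; the tensor shapes, head/batch sizes, and the bit-width override shown here are illustrative only and are not taken from the patches themselves.

    # Minimal usage sketch (post-PATCH-25 API; shapes and the bit-width
    # override are illustrative, not drawn from the patch series).
    import torch

    from brevitas.nn import QuantScaledDotProductAttention
    from brevitas.quant import Int8ActPerTensorFloat
    from brevitas.quant import Uint8ActPerTensorFloat

    qsdpa = QuantScaledDotProductAttention(
        softmax_input_quant=None,
        attn_output_weights_quant=Uint8ActPerTensorFloat,
        q_scaled_quant=Int8ActPerTensorFloat,
        k_transposed_quant=Int8ActPerTensorFloat,
        v_quant=Int8ActPerTensorFloat,
        sdpa_output_quant=None,
        # Prefixed kwargs are routed by filter_kwargs() to the matching internal
        # QuantIdentity, e.g. this overrides the attention-weights bit width.
        attn_output_weights_bit_width=8)

    # Same calling convention as F.scaled_dot_product_attention,
    # with (batch, heads, sequence length, embed dim) inputs.
    q = torch.randn(2, 4, 16, 64)
    k = torch.randn(2, 4, 16, 64)
    v = torch.randn(2, 4, 16, 64)
    out = qsdpa(q, k, v, is_causal=True)

The prefix-based kwargs routing mirrors how the per-quantizer overrides are exercised in tests/brevitas/nn/test_sdpa.py (PATCH 24), and the quant-disabled configuration (all quantizers set to None) reduces the module to the reference attention computation, as checked in test_sdpa_quant_disabled_fwd.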