diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 82f45e44c..ef9ea93d1 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -10,7 +10,6 @@
 from typing import Callable, Dict, List, Optional, Set, Tuple, Union
 import warnings
 
-from brevitas.nn import ScaledDotProductAttention
 import packaging
 import packaging.version
 import torch
@@ -31,6 +30,7 @@
 from brevitas.graph.hadamard import matmul_hadU_cuda
 from brevitas.graph.utils import get_module
 from brevitas.graph.utils import get_node
+from brevitas.nn import ScaledDotProductAttention
 from brevitas.nn.equalized_layer import EqualizedModule
 from brevitas.nn.equalized_layer import functional_rotate_input
 from brevitas.nn.equalized_layer import INPUT_NAMES
@@ -1509,7 +1509,6 @@ def find_sink(node):
             if isinstance(m, ScaledDotProductAttention):
                 m.pre_process_q = functional_rotate_input
                 m.pre_process_k = functional_rotate_input
-                # m.pre_process_v = partial(functional_rotate_input, transpose=True)
         return regions
 
     def apply(self,
diff --git a/src/brevitas/nn/quant_sdpa.py b/src/brevitas/nn/quant_sdpa.py
index 728d924e9..9f82ccc0c 100644
--- a/src/brevitas/nn/quant_sdpa.py
+++ b/src/brevitas/nn/quant_sdpa.py
@@ -43,14 +43,14 @@
 import math
 from typing import Optional, Tuple, Union
 
-from brevitas.core.function_wrapper.misc import Identity
-from brevitas.function import identity
 import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.nn import Parameter
 import torch.nn.functional as F
 
+from brevitas.core.function_wrapper.misc import Identity
+from brevitas.function import identity
 from brevitas.quant.scaled_int import Int8ActPerTensorFloat
 from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
 
@@ -59,7 +59,7 @@
 
 class ScaledDotProductAttention(Module):
 
-    def __init__(self, pre_process_q = identity, pre_process_k = identity, pre_process_v = identity):
+    def __init__(self, pre_process_q=identity, pre_process_k=identity, pre_process_v=identity):
         super().__init__()
         self.pre_process_q = pre_process_q
         self.pre_process_k = pre_process_k
@@ -113,7 +113,7 @@ def forward(
         return F.scaled_dot_product_attention(
             query=self.pre_process_q(query),
             key=self.pre_process_k(key),
-            value=value,#self.pre_process_v(value),
+            value=self.pre_process_v(value),
             attn_mask=attn_mask,
             dropout_p=dropout_p,
             is_causal=is_causal,
@@ -124,7 +124,9 @@ class QuantScaledDotProductAttention(Module):
 
     def __init__(
            self,
-            pre_process_q = Identity(), pre_process_k = Identity(), pre_process_v = Identity(),
+            pre_process_q=identity,
+            pre_process_k=identity,
+            pre_process_v=identity,
             softmax_input_quant=None,
             attn_output_weights_quant=Uint8ActPerTensorFloat,
             q_scaled_quant=Int8ActPerTensorFloat,
@@ -211,15 +213,14 @@ def forward(
             else:
                 attn_bias += attn_mask
         query, key, value = self.pre_process_q(query), self.pre_process_k(key), self.pre_process_v(value)
-        q_scaled = query * scale_factor#self.q_scaled_quant(query * scale_factor)
+        q_scaled = self.q_scaled_quant(query * scale_factor)
         k_transpose = self.k_transposed_quant(key.transpose(-2, -1))
         attn_weight = q_scaled @ k_transpose
         attn_weight += attn_bias
         attn_weight = self.softmax_input_quant(attn_weight)
         attn_weight = torch.softmax(attn_weight, dim=-1)
         attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
-        # attn_weight = self.pre_process_q(attn_weight)
-        # attn_weight = self.attn_output_weights_quant(attn_weight)
+        attn_weight = self.attn_output_weights_quant(attn_weight)
         attn_output = attn_weight @ self.v_quant(value)
         attn_output = self.sdpa_output_quant(attn_output)
         return attn_output
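
Usage note (not part of the patch): a minimal sketch of how the restored pre_process_v hook is exercised, assuming ScaledDotProductAttention.forward accepts the same query/key/value keywords it forwards to F.scaled_dot_product_attention, as shown in the hunk above. The scale_by_two hook is a hypothetical stand-in; rotation equalization assigns functional_rotate_input to these attributes instead.

    import torch

    from brevitas.nn import ScaledDotProductAttention


    def scale_by_two(x):
        # Hypothetical pre-processing hook, used only to make the effect visible.
        return 2.0 * x


    sdpa = ScaledDotProductAttention(
        pre_process_q=scale_by_two, pre_process_k=scale_by_two, pre_process_v=scale_by_two)
    q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
    # With this change, value is pre-processed as well, not just query and key.
    out = sdpa(query=q, key=k, value=v)
    print(out.shape)  # torch.Size([1, 4, 8, 16])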