From 0d4cba5d4279f5bee8229af59c1ff2ca9e8f255b Mon Sep 17 00:00:00 2001
From: Rishikesh Magar <43094762+RishikeshMagar@users.noreply.github.com>
Date: Wed, 10 Jan 2024 23:06:16 -0800
Subject: [PATCH] GQA Attention (#59)

---
 .../modeling/models/apt/model_pytorch.py | 100 +++++++++++++++++-
 protein_lm/tests/test_attention.py       |  89 ++++++++++++++++
 2 files changed, 185 insertions(+), 4 deletions(-)
 create mode 100644 protein_lm/tests/test_attention.py

diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py
index f814519..f450d4a 100644
--- a/protein_lm/modeling/models/apt/model_pytorch.py
+++ b/protein_lm/modeling/models/apt/model_pytorch.py
@@ -8,6 +8,7 @@ from transformers.pytorch_utils import Conv1D
 from transformers.activations import ACT2FN
 from transformers.utils import logging
+
 from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
 from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
 from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
@@ -34,6 +35,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         self.max_sequence_length = config.max_sequence_length
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
+        self.attn_type = config.attn_type
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
@@ -48,7 +50,15 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         # Layer-wise attention scaling, reordering, and upcasting
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
         self.layer_idx = layer_idx
-        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+        # config.attn_type now selects the attention implementation; default all
+        # flags to False so forward() can branch on them without hitting an
+        # AttributeError for the unselected paths.
+        self.gqa_attn = False
+        self.reorder_and_upcast_attn = False
+        self.standard_attn = False
+        if self.attn_type == "gqa":
+            self.gqa_attn = True
+        elif self.attn_type == "reorder_and_upcast_attn":
+            self.reorder_and_upcast_attn = True
+        elif self.attn_type == "standard":
+            self.standard_attn = True
 
         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
@@ -116,6 +126,87 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia
 
         return attn_output, attn_weights
 
+    def _gqa_attn(self, query, key, value, attention_mask=None, alibi_bias=None):
+        """Grouped-query attention: several query heads share each key/value head.
+
+        Expected shapes, as in _upcast_and_reordered_attn:
+            query:      (batch_size, num_heads, query_len, query_dim)
+            key/value:  (batch_size, num_kv_heads, key_len, query_dim)
+        """
+        # Validate the inputs before moving on
+        if not query.ndim == key.ndim == value.ndim == 4:
+            raise ValueError(f"Expected query, key, and value to be 4-dimensional, but got shapes "
+                             f"{query.shape}, {key.shape}, and {value.shape}.")
+
+        batch_size, num_heads, query_len, query_dim = query.shape
+
+        scale_factor = 1.0
+        if self.scale_attn_weights:
+            scale_factor /= float(value.size(-1)) ** 0.5
+        query = query * scale_factor
+
+        # Determine the group size, i.e. how many query heads share each
+        # key/value head: with 4 query heads and 2 key heads, 2 consecutive
+        # query heads form a group around each key head. The query is reshaped
+        # to
+        #     (batch_size, num_kv_heads, n_groups, query_len, query_dim)
+        # so that one key head broadcasts across its group, giving grouped
+        # weights of shape
+        #     (batch_size, num_kv_heads, n_groups, query_len, key_len),
+        # which are reshaped back to the usual
+        #     (batch_size, num_heads, query_len, key_len).
+        n_groups = query.size(1) // key.size(1)
+
+        if n_groups > 1:
+            query_grouped = query.reshape(
+                batch_size, key.size(1), n_groups, query_len, query_dim
+            )
+            # unsqueeze(2) lets each key head broadcast over its group of query heads
+            attn_weights_grouped = torch.matmul(
+                query_grouped, key.transpose(-2, -1).unsqueeze(2)
+            )
+            attn_weights = attn_weights_grouped.reshape(
+                batch_size, num_heads, query_len, key.size(-2)
+            )
+        else:
+            # With a single group this reduces to standard attention
+            attn_weights = torch.matmul(query, key.transpose(-2, -1))
+
+        if self.scale_attn_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)
+
+        if attention_mask is not None:
+            # attention_mask shape: (batch_size, query_len, key_len);
+            # unsqueeze to add the head dimension before broadcasting
+            attn_weights = attn_weights + attention_mask.unsqueeze(1)
+
+        # Causal masking ensures the attention mechanism does not attend to
+        # "future" tokens, mirroring the vanilla attention path
+        if not self.is_cross_attention:
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+            mask_value = torch.finfo(attn_weights.dtype).min
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)
+
+        if alibi_bias is not None:
+            # ALiBi biases the logits, so it must be added before the softmax
+            attn_weights = attn_weights + alibi_bias[:, :, :attn_weights.size(-1)]
+
+        # Softmax normalization to get the attention scores
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_weights = attn_weights.type(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Compute the output, broadcasting each value head across its group of
+        # query heads, exactly as was done for the keys
+        if n_groups > 1:
+            weights_grouped = attn_weights.reshape(
+                batch_size, key.size(1), n_groups, query_len, -1
+            )
+            attn_output = torch.matmul(weights_grouped, value.unsqueeze(2)).reshape(
+                batch_size, num_heads, query_len, value.size(-1)
+            )
+        else:
+            attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
         bsz, num_heads, q_seq_len, dk = query.size()
@@ -233,9 +324,10 @@ def forward(
 
         if self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
-        else:
+        elif self.standard_attn:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
-
+        elif self.gqa_attn:
+            attn_output, attn_weights = self._gqa_attn(query, key, value, attention_mask, alibi_bias=alibi_bias)
+
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
@@ -244,7 +336,7 @@ def forward(
         if output_attentions:
             outputs += (attn_weights,)
 
-        return outputs # a, present, (attentions)
+        return outputs  # a, present, (attentions)
 
 
 class APTMLP(nn.Module):
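For reference, here is a minimal self-contained sketch (an editor's aside, not part of the patch) of the grouped matmul that `_gqa_attn` implements, assuming 8 query heads sharing 2 key/value heads, so that 4 consecutive query heads form each group; all names below are illustrative:

    import torch

    batch, q_heads, kv_heads, q_len, dim = 2, 8, 2, 5, 16
    query = torch.randn(batch, q_heads, q_len, dim)
    key = torch.randn(batch, kv_heads, q_len, dim)

    n_groups = q_heads // kv_heads  # 4 query heads per key/value head
    q_grouped = query.reshape(batch, kv_heads, n_groups, q_len, dim)
    # unsqueeze(2) lets one key head broadcast across its whole group
    w = torch.matmul(q_grouped, key.transpose(-2, -1).unsqueeze(2))
    assert w.shape == (batch, kv_heads, n_groups, q_len, q_len)
    # back to one weight matrix per query head
    w = w.reshape(batch, q_heads, q_len, q_len)

The same reshape/broadcast/reshape pattern is applied to the values, which is what lets GQA store only `kv_heads` key/value projections while keeping `q_heads` distinct attention patterns.
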
diff --git a/protein_lm/tests/test_attention.py b/protein_lm/tests/test_attention.py
new file mode 100644
index 0000000..1e85ffe
--- /dev/null
+++ b/protein_lm/tests/test_attention.py
@@ -0,0 +1,89 @@
+import pytest
+import torch
+
+from protein_lm.modeling.models.apt.model_pytorch import APTAttention
+
+
+class ParameterConfig:
+    def __init__(self):
+        self.max_position_embeddings = 512
+        self.position_embedding = 'rope'
+        self.max_sequence_length = 512
+        self.hidden_size = 768
+        self.num_attention_heads = 12
+        self.scale_attn_weights = True
+        self.scale_attn_by_inverse_layer_idx = True
+        self.reorder_and_upcast_attn = True
+        self.attn_pdrop = 0.1
+        self.resid_pdrop = 0.1
+        self.rope_scaling_factor = 1
+        self.rope_theta = 1
+        self.attn_type = 'gqa'
+
+
+def test_vanilla_attn():
+    # Initialize with a mock config
+    config = ParameterConfig()
+    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
+
+    # Generate random input tensors
+    batch_size = 4
+    seq_length = 100
+    num_heads = config.num_attention_heads
+    query_dim = config.hidden_size // num_heads
+    query = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    key = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    value = torch.randn(batch_size, num_heads, seq_length, query_dim)
+
+    # Additive attention mask: 0 keeps a position, -inf masks it out. Mask the
+    # last `padding_positions` key positions, as with right-padded sequences.
+    attention_mask = torch.zeros(batch_size, seq_length, seq_length)
+    padding_positions = 10
+    attention_mask[:, :, -padding_positions:] = float('-inf')
+    attention_mask = attention_mask.unsqueeze(1)  # add the head dimension
+
+    # Pass them through the _attn method
+    attn_output, attn_weights = attention._attn(query, key, value, attention_mask=attention_mask)
+
+    # Check the shapes and types of the output
+    assert isinstance(attn_output, torch.Tensor)
+    assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
+    assert isinstance(attn_weights, torch.Tensor)
+    assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
+    print("Test passed!")
+
+
+def test_gqa_attn():
+    # Initialize with a mock config
+    config = ParameterConfig()
+    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
+
+    # Generate random input tensors; use fewer key/value heads than query
+    # heads so that the grouped path is actually exercised
+    batch_size = 4
+    seq_length = 100
+    num_heads = config.num_attention_heads
+    num_kv_heads = num_heads // 4
+    query_dim = config.hidden_size // num_heads
+    query = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    key = torch.randn(batch_size, num_kv_heads, seq_length, query_dim)
+    value = torch.randn(batch_size, num_kv_heads, seq_length, query_dim)
+
+    # Additive attention mask: 0 keeps a position, -inf masks it out
+    attention_mask = torch.zeros(batch_size, seq_length, seq_length)
+    padding_positions = 10
+    attention_mask[:, :, -padding_positions:] = float('-inf')
+
+    # Pass them through the _gqa_attn method
+    attn_output, attn_weights = attention._gqa_attn(query, key, value, attention_mask=attention_mask)
+
+    # Check the shapes and types of the output
+    assert isinstance(attn_output, torch.Tensor)
+    assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
+    assert isinstance(attn_weights, torch.Tensor)
+    assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
+    print("Test passed!")
+
+
+if __name__ == "__main__":
+    test_gqa_attn()
+    test_vanilla_attn()
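As a possible follow-up check (a sketch, not part of this patch): with as many key/value heads as query heads there is a single group, so `_gqa_attn` should reduce to plain causal scaled-dot-product attention. Assuming the ParameterConfig class above is in scope and layer_idx=0 (so the inverse-layer-index scaling divides by 1), something like:

    import torch
    from protein_lm.modeling.models.apt.model_pytorch import APTAttention

    def test_gqa_matches_vanilla_for_one_group():
        config = ParameterConfig()
        attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
        attention.eval()  # disable dropout so the comparison is deterministic

        b, h, t = 2, config.num_attention_heads, 16
        d = config.hidden_size // h
        query, key, value = (torch.randn(b, h, t, d) for _ in range(3))

        out_gqa, _ = attention._gqa_attn(query, key, value)

        # Reference: causal softmax attention computed directly
        logits = torch.matmul(query, key.transpose(-2, -1)) / d ** 0.5
        causal = torch.tril(torch.ones(t, t, dtype=torch.bool))
        logits = logits.masked_fill(~causal, torch.finfo(logits.dtype).min)
        out_ref = torch.matmul(torch.softmax(logits, dim=-1), value)

        assert torch.allclose(out_gqa, out_ref, atol=1e-5)

This would catch regressions in the scaling and causal-masking logic of the GQA path, since any deviation from the vanilla computation in the single-group case indicates a bug rather than an intended GQA behavior.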