Merge pull request #38 from iMountTai/main
add FlashAttention-2 support
Showing 3 changed files with 128 additions and 9 deletions.
@@ -0,0 +1,116 @@
# Below code is based on https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
from typing import Optional, Tuple
import torch

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
    from flash_attn.bert_padding import unpad_input, pad_input
except ImportError:
    raise ImportError(
        "FlashAttention-2 is not installed correctly. Please check the usage in https://github.com/Dao-AILab/flash-attention for more details."
    )


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    # [bsz, q_len, nh, hd] -> [bsz, nh, q_len, hd] after the transpose

    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )  # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention code from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    # We have disabled _prepare_decoder_attention_mask in LlamaModel,
    # so the attention_mask is the same as the key_padding_mask
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        output = flash_attn_varlen_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )
    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None


# Disable the transformation of the attention mask in LlamaModel, as flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
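
For reference, a minimal usage sketch (not part of this commit): the patch is typically applied before the model is loaded, so that the replaced forward and _prepare_decoder_attention_mask are the ones bound to LlamaAttention and LlamaModel. The module name flash_attn_patch and the checkpoint path below are placeholders, not names from this repository.

# Usage sketch under stated assumptions: the patched file is importable as
# flash_attn_patch.py, and "path/to/llama-checkpoint" is a local LLaMA checkpoint.
import torch
from transformers import LlamaForCausalLM

from flash_attn_patch import replace_llama_attn_with_flash_attn  # hypothetical module name

# Replace LlamaAttention.forward and LlamaModel._prepare_decoder_attention_mask
# before the model is created.
replace_llama_attn_with_flash_attn()

model = LlamaForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",   # placeholder path
    torch_dtype=torch.float16,    # FlashAttention expects fp16/bf16 inputs
).cuda()                          # and a CUDA device

# The patched forward asserts use_cache=False and past_key_value=None,
# so keep the KV cache disabled (training-style forward passes only).
model.config.use_cache = False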