
Commit

Merge pull request #38 from iMountTai/main
add FlashAttention-2 support
ymcui authored Aug 1, 2023
2 parents a5777e8 + 1649777 commit 8ef2788
Showing 3 changed files with 128 additions and 9 deletions.
116 changes: 116 additions & 0 deletions scripts/training/flash_attn_patch.py
@@ -0,0 +1,116 @@
# The code below is based on https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
from typing import Optional, Tuple
import torch

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
    from flash_attn.bert_padding import unpad_input, pad_input
except ImportError:
    raise ImportError(
        "FlashAttention-2 is not installed correctly. Please check the usage in https://github.com/Dao-AILab/flash-attention for more details."
    )

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    assert past_key_value is None, "past_key_value is not supported"

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = rearrange(qkv, "b s ... -> (b s) ...")
        max_s = q_len
        cu_q_lens = torch.arange(
            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
        )
        output = flash_attn_varlen_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )
    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
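
For context, here is a minimal usage sketch of the new patch (the checkpoint path and dtype are placeholders, not part of this commit); the training scripts changed below call replace_llama_attn_with_flash_attn() right after argument parsing, before the model is loaded.

import torch
from transformers import LlamaForCausalLM
from flash_attn_patch import replace_llama_attn_with_flash_attn

# Apply the monkey patch, then load the model as usual.
replace_llama_attn_with_flash_attn()
model = LlamaForCausalLM.from_pretrained(
    "path/to/llama",            # placeholder checkpoint path
    torch_dtype=torch.float16,  # flash-attn kernels run in fp16/bf16 on GPU
)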
12 changes: 6 additions & 6 deletions scripts/training/run_clm_pt_with_peft.py
@@ -312,6 +312,7 @@ class MyTrainingArguments(TrainingArguments):
    modules_to_save : Optional[str] = field(default=None)
    debug_mode : Optional[bool] = field(default=False)
    peft_path : Optional[str] = field(default=None)
    flash_attn : Optional[bool] = field(default=False)


logger = logging.getLogger(__name__)
@@ -326,6 +327,9 @@ def main():
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if training_args.flash_attn:
        from flash_attn_patch import replace_llama_attn_with_flash_attn
        replace_llama_attn_with_flash_attn()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -336,7 +340,6 @@ def main():
        level=logging.INFO, # if training_args.local_rank in [-1, 0] else logging.WARN,
        handlers=[logging.StreamHandler(sys.stdout)],)


    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()
@@ -508,19 +511,16 @@ def group_texts(examples):
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        logger.info(f"Num train_samples {len(train_dataset)}")
        logger.info("training example:")
        logger.info("Training example:")
        logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
    if training_args.do_eval:
        eval_dataset = lm_datasets["test"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        logger.info(f"Num eval_samples {len(eval_dataset)}")
        logger.info("training example:")
        logger.info("Evaluation example:")
        logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))



    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
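Since MyTrainingArguments extends TrainingArguments, the new flash_attn field also surfaces as a command-line flag through HfArgumentParser. A hedged sketch of that flow follows; the reduced dataclass and the argument values are illustrative, not the full class from the script.

from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class MyTrainingArguments(TrainingArguments):
    # Only the field relevant to this commit is shown here.
    flash_attn: Optional[bool] = field(default=False)

parser = HfArgumentParser(MyTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "output", "--flash_attn", "True"]  # illustrative CLI arguments
)
if training_args.flash_attn:
    # Same conditional patching as in the training scripts above.
    from flash_attn_patch import replace_llama_attn_with_flash_attn
    replace_llama_attn_with_flash_attn()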
9 changes: 6 additions & 3 deletions scripts/training/run_clm_sft_with_peft.py
@@ -199,6 +199,7 @@ class MyTrainingArguments(TrainingArguments):
    lora_alpha : Optional[float] = field(default=32.)
    modules_to_save : Optional[str] = field(default=None)
    peft_path : Optional[str] = field(default=None)
    flash_attn : Optional[bool] = field(default=False)


logger = logging.getLogger(__name__)
@@ -213,6 +214,9 @@ def main():
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if training_args.flash_attn:
        from flash_attn_patch import replace_llama_attn_with_flash_attn
        replace_llama_attn_with_flash_attn()

    send_example_telemetry("run_clm", model_args, data_args)

@@ -311,7 +315,7 @@ def main():
                data_cache_dir = None,
                preprocessing_num_workers = data_args.preprocessing_num_workers)
        logger.info(f"Num train_samples {len(train_dataset)}")
        logger.info("training example:")
        logger.info("Training example:")
        logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
    if training_args.do_eval:
        with training_args.main_process_first(desc="loading and tokenization"):
@@ -324,7 +328,7 @@ def main():
                data_cache_dir = None,
                preprocessing_num_workers = data_args.preprocessing_num_workers)
        logger.info(f"Num eval_samples {len(eval_dataset)}")
        logger.info("eval example:")
        logger.info("Evaluation example:")
        logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))

    if model_args.model_name_or_path:
@@ -353,7 +357,6 @@ def main():
logger.info(f"len(tokenizer):{len(tokenizer)}")
if model_vocab_size != len(tokenizer):
logger.info(f"Resize model vocab size to {len(tokenizer)}")
logger.info("resize the embedding size by the size of the tokenizer")
model.resize_token_embeddings(len(tokenizer))

if training_args.peft_path is not None:
