diff --git a/scripts/training/flash_attn_patch.py b/scripts/training/flash_attn_patch.py
new file mode 100644
index 0000000..68dc0ec
--- /dev/null
+++ b/scripts/training/flash_attn_patch.py
@@ -0,0 +1,116 @@
+# Below code is based on https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
+from typing import Optional, Tuple
+import torch
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from einops import rearrange
+try:
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+    from flash_attn.bert_padding import unpad_input, pad_input
+except ImportError:
+    raise ImportError(
+        "FlashAttention-2 is not installed correctly. Please check the usage in https://github.com/Dao-AILab/flash-attention for more details."
+    )
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel
+
+    attention_mask: [bsz, q_len]
+    """
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+        .transpose(1, 2)
+    )
+    # [bsz, q_len, nh, hd]
+    # [bsz, nh, q_len, hd]
+
+    kv_seq_len = key_states.shape[-2]
+    assert past_key_value is None, "past_key_value is not supported"
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = apply_rotary_pos_emb(
+        query_states, key_states, cos, sin, position_ids
+    )
+    # [bsz, nh, t, hd]
+    assert not output_attentions, "output_attentions is not supported"
+    assert not use_cache, "use_cache is not supported"
+
+    # Flash attention codes from
+    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+    # transform the data into the format required by flash attention
+    qkv = torch.stack(
+        [query_states, key_states, value_states], dim=2
+    )  # [bsz, nh, 3, q_len, hd]
+    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
+    # We have disabled _prepare_decoder_attention_mask in LlamaModel
+    # the attention_mask should be the same as the key_padding_mask
+    key_padding_mask = attention_mask
+
+    if key_padding_mask is None:
+        qkv = rearrange(qkv, "b s ... -> (b s) ...")
+        max_s = q_len
+        cu_q_lens = torch.arange(
+            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+        )
+        output = flash_attn_varlen_qkvpacked_func(
+            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+        )
+        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+    else:
+        nheads = qkv.shape[-2]
+        x = rearrange(qkv, "b s three h d -> b s (three h d)")
+        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+        x_unpad = rearrange(
+            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+        )
+        output_unpad = flash_attn_varlen_qkvpacked_func(
+            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+        )
+        output = rearrange(
+            pad_input(
+                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+            ),
+            "b s (h d) -> b s h d",
+            h=nheads,
+        )
+    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+    # [bsz, seq_len]
+    return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+        _prepare_decoder_attention_mask
+    )
+    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/scripts/training/run_clm_pt_with_peft.py b/scripts/training/run_clm_pt_with_peft.py
index 11f6213..7d6d701 100644
--- a/scripts/training/run_clm_pt_with_peft.py
+++ b/scripts/training/run_clm_pt_with_peft.py
@@ -312,6 +312,7 @@ class MyTrainingArguments(TrainingArguments):
     modules_to_save : Optional[str] = field(default=None)
     debug_mode : Optional[bool] = field(default=False)
     peft_path : Optional[str] = field(default=None)
+    flash_attn : Optional[bool] = field(default=False)
 
 logger = logging.getLogger(__name__)
 
@@ -326,6 +327,9 @@ def main():
         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    if training_args.flash_attn:
+        from flash_attn_patch import replace_llama_attn_with_flash_attn
+        replace_llama_attn_with_flash_attn()
 
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
@@ -336,7 +340,6 @@ def main():
         level=logging.INFO,  # if training_args.local_rank in [-1, 0] else logging.WARN,
         handlers=[logging.StreamHandler(sys.stdout)],)
 
-
     if training_args.should_log:
         # The default of training_args.log_level is passive, so we set log level at info here to have that default.
         transformers.utils.logging.set_verbosity_info()
@@ -508,7 +511,7 @@ def group_texts(examples):
             max_train_samples = min(len(train_dataset), data_args.max_train_samples)
             train_dataset = train_dataset.select(range(max_train_samples))
         logger.info(f"Num train_samples {len(train_dataset)}")
-        logger.info("training example:")
+        logger.info("Training example:")
         logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
     if training_args.do_eval:
         eval_dataset = lm_datasets["test"]
@@ -516,11 +519,8 @@ def group_texts(examples):
             max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
             eval_dataset = eval_dataset.select(range(max_eval_samples))
         logger.info(f"Num eval_samples {len(eval_dataset)}")
-        logger.info("training example:")
+        logger.info("Evaluation example:")
         logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
-
-
-
     if model_args.model_name_or_path:
         torch_dtype = (
             model_args.torch_dtype
diff --git a/scripts/training/run_clm_sft_with_peft.py b/scripts/training/run_clm_sft_with_peft.py
index cc4de85..4daf208 100644
--- a/scripts/training/run_clm_sft_with_peft.py
+++ b/scripts/training/run_clm_sft_with_peft.py
@@ -199,6 +199,7 @@ class MyTrainingArguments(TrainingArguments):
     lora_alpha : Optional[float] = field(default=32.)
     modules_to_save : Optional[str] = field(default=None)
     peft_path : Optional[str] = field(default=None)
+    flash_attn : Optional[bool] = field(default=False)
 
 logger = logging.getLogger(__name__)
 
@@ -213,6 +214,9 @@ def main():
         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    if training_args.flash_attn:
+        from flash_attn_patch import replace_llama_attn_with_flash_attn
+        replace_llama_attn_with_flash_attn()
 
     send_example_telemetry("run_clm", model_args, data_args)
 
@@ -311,7 +315,7 @@ def main():
                 data_cache_dir = None,
                 preprocessing_num_workers = data_args.preprocessing_num_workers)
             logger.info(f"Num train_samples {len(train_dataset)}")
-            logger.info("training example:")
+            logger.info("Training example:")
             logger.info(tokenizer.decode(train_dataset[0]['input_ids']))
     if training_args.do_eval:
         with training_args.main_process_first(desc="loading and tokenization"):
@@ -324,7 +328,7 @@ def main():
                 data_cache_dir = None,
                 preprocessing_num_workers = data_args.preprocessing_num_workers)
             logger.info(f"Num eval_samples {len(eval_dataset)}")
-            logger.info("eval example:")
+            logger.info("Evaluation example:")
             logger.info(tokenizer.decode(eval_dataset[0]['input_ids']))
 
     if model_args.model_name_or_path:
@@ -353,7 +357,6 @@ def main():
         logger.info(f"len(tokenizer):{len(tokenizer)}")
         if model_vocab_size != len(tokenizer):
             logger.info(f"Resize model vocab size to {len(tokenizer)}")
-            logger.info("resize the embedding size by the size of the tokenizer")
             model.resize_token_embeddings(len(tokenizer))
 
     if training_args.peft_path is not None:
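
Usage note (not part of the patch above): the new flash_attn field on MyTrainingArguments is picked up by HfArgumentParser, so passing --flash_attn True to either training script imports flash_attn_patch and applies the monkey patch before the model is loaded. The minimal Python sketch below shows the same call sequence in isolation; the checkpoint path is a placeholder, and loading in float16 is an assumption made here for illustration (the FlashAttention kernels operate on fp16/bf16 tensors), not something the patch itself enforces.

import torch
from transformers import LlamaForCausalLM

# Apply the monkey patch BEFORE instantiating the model, mirroring the
# `if training_args.flash_attn:` block added to both scripts: it overrides
# LlamaAttention.forward and disables LlamaModel._prepare_decoder_attention_mask.
from flash_attn_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()

# Placeholder checkpoint path; half precision assumed for the FlashAttention kernels.
model = LlamaForCausalLM.from_pretrained("/path/to/llama-hf", torch_dtype=torch.float16)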