Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate Habana flash attention to Llama2-70B finetune #596

Merged
merged 2 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,23 @@ class ModelArguments:
)
},
)
use_flash_attention: bool = field(
default=False,
metadata={
"help": (
"Whether to use Habana flash attention for fine-tuning. The current support is limited to Llama only.",
)
},
)
flash_attention_recompute: bool = field(
default=False,
metadata={
"help": (
"Whether to enable recompute in Habana flash attention for fine-tuning."
" It is applicable only when use_flash_attention is True.",
)
},
)
load_meta_device: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -519,6 +536,9 @@ def main():
model.generation_config.eos_token_id = 2
if model_args.attn_softmax_bf16:
model.generation_config.attn_softmax_bf16 = True
if model_args.use_flash_attention:
model.generation_config.use_flash_attention = True
model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute

if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None:
tokenizer.pad_token_id = model.generation_config.pad_token_id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class GaudiGenerationConfig(GenerationConfig):
Store kv-cache in float8 when kv-cache is used
use_flash_attention (`bool`, *optional*):
Whether to use flash attention optimization.
flash_attention_recompute (`bool`, *optional*):
Whether to enable recompute if use Habana flash attention.
"""

def __init__(self, **kwargs):
Expand All @@ -44,3 +46,4 @@ def __init__(self, **kwargs):
self.bucket_size = kwargs.get("bucket_size", -1)
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ def generate(

# determine whether flash attention needs to be used
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[-1]
Expand Down
16 changes: 15 additions & 1 deletion optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def pre_attn_forward(
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand All @@ -166,6 +167,7 @@ def pre_attn_forward(
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
"""
bsz, q_len, _ = hidden_states.size()

Expand Down Expand Up @@ -241,7 +243,7 @@ def pre_attn_forward(
)
else:
# first token
with ht.sdp_kernel(enable_recompute=False):
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)

else:
Expand Down Expand Up @@ -354,6 +356,7 @@ def forward(
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand All @@ -362,6 +365,7 @@ def forward(
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
"""
residual = hidden_states
output_pre_attn, self_attn_weights, present_key_value = self.pre_attn(
Expand All @@ -375,6 +379,7 @@ def forward(
attn_softmax_bf16,
reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
self.self_attn.attention_all_reduce(output_pre_attn)
output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual)
Expand Down Expand Up @@ -402,6 +407,7 @@ def pre_attn(
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
Expand All @@ -415,6 +421,7 @@ def pre_attn(
attn_softmax_bf16,
reuse_cache,
use_flash_attention,
flash_attention_recompute,
)
return output_attn, attn_weights, present_key_value

Expand Down Expand Up @@ -462,6 +469,7 @@ def forward(
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand All @@ -470,6 +478,7 @@ def forward(
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -550,6 +559,7 @@ def custom_forward(*inputs):
output_attentions,
attn_softmax_bf16=attn_softmax_bf16,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)

return custom_forward
Expand All @@ -569,6 +579,7 @@ def custom_forward(*inputs):
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -634,6 +645,7 @@ def forward(
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -655,6 +667,7 @@ def forward(
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
Expand Down Expand Up @@ -739,6 +752,7 @@ def prepare_inputs_for_generation(
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
}
)
return model_inputs
Expand Down
4 changes: 4 additions & 0 deletions optimum/habana/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,8 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
inputs["attn_softmax_bf16"] = True
if self.model.generation_config.use_flash_attention:
inputs["use_flash_attention"] = True
if self.model.generation_config.flash_attention_recompute:
inputs["flash_attention_recompute"] = True

# TODO: keep syncs for fast DDP?
with self.accelerator.accumulate(model):
Expand Down Expand Up @@ -1540,6 +1542,8 @@ def evaluation_loop(
inputs["attn_softmax_bf16"] = True
if self.model.generation_config.use_flash_attention:
inputs["use_flash_attention"] = True
if self.model.generation_config.flash_attention_recompute:
inputs["flash_attention_recompute"] = True

# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
Expand Down
Loading