Optimization for phi series models: support fp8 kv cache and reuse kv cache (huggingface#902)

Signed-off-by: yuwenzho <[email protected]>
yuwenzho authored May 8, 2024
1 parent 9f6eba3 commit aa175ea
Showing 9 changed files with 508 additions and 268 deletions.
32 changes: 31 additions & 1 deletion examples/text-generation/README.md
@@ -251,7 +251,7 @@ You will also need to add `--torch_compile` in your command.

### Running with FP8

Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, and Falcon-180B in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, Falcon-180B and phi-2 in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.

More information on enabling fp8 in SynapseAI is available here:
https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html
@@ -363,6 +363,36 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
--trim_logits \
--fp8
```

Here is an example to measure the tensor quantization statistics on phi-2 with 1 card:

```bash
QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_lm_eval.py \
-o acc_phi-2_bs1_measure.txt \
--model_name_or_path microsoft/phi-2 \
--use_hpu_graphs \
--use_kv_cache \
--max_new_tokens 100 \
--batch_size 1 \
--trim_logits \
--reuse_cache \
--bf16
```

Here is an example to quantize the model based on previous measurements for phi-2 with 1 card:
```bash
QUANT_CONFIG=./quantization_config/maxabs_quant_phi.json python run_generation.py \
--model_name_or_path microsoft/phi-2 \
--use_hpu_graphs \
--use_kv_cache \
--max_new_tokens 100 \
--batch_size 1 \
--bf16 \
--trim_logits \
--reuse_cache \
--fp8
```

`--fp8` is required to enable quantization in fp8.
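
For completeness, here is a rough sketch of what the `QUANT_CONFIG` convention amounts to: the environment variable simply points at one of the JSON files under `quantization_config/`, which the quantization toolkit parses internally. The `load_quant_config` helper below is hypothetical and only mirrors that convention; it is not the toolkit's actual loader.

```python
import json
import os


def load_quant_config(env_var: str = "QUANT_CONFIG") -> dict:
    """Hypothetical helper mirroring the QUANT_CONFIG=<path-to-json> convention.

    The real parsing is done inside the quantization toolkit; this sketch only
    shows what the commands above pass to it.
    """
    config_path = os.environ.get(env_var)
    if config_path is None:
        raise RuntimeError(f"{env_var} is not set; expected a path to a JSON config")
    with open(config_path) as f:
        return json.load(f)


if __name__ == "__main__":
    # Point at the quantization config used in the phi-2 example above.
    os.environ.setdefault("QUANT_CONFIG", "./quantization_config/maxabs_quant_phi.json")
    config = load_quant_config()
    print(config["mode"], config["scale_method"])
```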


14 changes: 14 additions & 0 deletions examples/text-generation/quantization_config/maxabs_quant_phi.json
@@ -0,0 +1,14 @@
{
"method": "HOOKS",
"mode": "QUANTIZE",
"observer": "maxabs",
"scale_method": "maxabs_hw",
"allowlist": {"types": [], "names": []},
"blocklist": {"types": [], "names": [
"matmul_qk",
"matmul_av",
"lm_head"
]},
"dump_stats_path": "./hqt_output/measure",
"dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx"
}
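
For readers unfamiliar with the fields above: the `maxabs` observer records the largest absolute value seen per tensor during the measurement run, the `maxabs_hw` scale method turns that statistic into a quantization scale, and the `blocklist` keeps `matmul_qk`, `matmul_av`, and `lm_head` out of fp8. The sketch below only illustrates the maxabs idea under the assumption of a standard E4M3 fp8 range (maximum 448; hardware variants may differ) and is not the toolkit's implementation.

```python
import torch

# Assumed maximum for a standard E4M3 fp8 format; hardware-specific variants
# may use a different value.
FP8_MAX = 448.0


def observe_maxabs(tensor: torch.Tensor) -> float:
    """Measurement step: record the largest absolute value in the tensor."""
    return tensor.abs().max().item()


def maxabs_scale(observed_max: float) -> float:
    """Quantization step: map the observed range onto the fp8 range."""
    return observed_max / FP8_MAX


if __name__ == "__main__":
    activations = torch.randn(4, 16) * 3.0
    stat = observe_maxabs(activations)   # what a maxabs observer collects
    scale = maxabs_scale(stat)           # what a maxabs_hw-style scale is derived from
    scaled = torch.clamp(activations / scale, -FP8_MAX, FP8_MAX)
    print(f"max |x| = {stat:.3f}, scale = {scale:.5f}, "
          f"scaled range = ({scaled.min().item():.1f}, {scaled.max().item():.1f})")
```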
2 changes: 1 addition & 1 deletion examples/text-generation/run_lm_eval.py
@@ -75,7 +75,7 @@ def __init__(self, tokenizer, model, args, options):
self.options = options
self._device = args.device
self.model_inputs = {"use_cache": self.options.use_cache}
if self.model.config.model_type in ["llama", "mistral", "falcon"]:
if self.model.config.model_type in ["llama", "mistral", "falcon", "phi"]:
self.model_inputs.update(
{
"reuse_cache": self.options.reuse_cache,
3 changes: 2 additions & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -609,7 +609,8 @@ def generate(
"mistral",
"falcon",
"mixtral",
], "reuse_cache only supported by llama, mistral, falcon and mixtral at the moment"
"phi",
], "reuse_cache only supported by llama, mistral, falcon, mixtral and phi at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
12 changes: 6 additions & 6 deletions optimum/habana/transformers/modeling_utils.py
@@ -62,7 +62,10 @@
GaudiOPTForCausalLM,
GaudiOPTLearnedPositionalEmbedding,
GaudiPersimmonForCausalLM,
GaudiPhiAttention,
GaudiPhiDecoderLayer,
GaudiPhiForCausalLM,
GaudiPhiModel,
GaudiQwen2DecoderLayer,
GaudiQwen2ForCausalLM,
GaudiStableLmForCausalLM,
@@ -132,9 +135,6 @@
gaudi_persimmon_attention_forward,
gaudi_persimmon_decoder_layer_forward,
gaudi_persimmon_model_forward,
gaudi_phi_attention_forward,
gaudi_phi_decoder_layer_forward,
gaudi_phi_model_forward,
gaudi_qwen2_attention_forward,
gaudi_qwen2_model_forward,
gaudi_rot_matmul,
@@ -366,9 +366,9 @@ def adapt_transformers_to_gaudi():

# Optimization for phi on Gaudi
transformers.models.phi.modeling_phi.PhiForCausalLM = GaudiPhiForCausalLM
transformers.models.phi.modeling_phi.PhiAttention.forward = gaudi_phi_attention_forward
transformers.models.phi.modeling_phi.PhiDecoderLayer.forward = gaudi_phi_decoder_layer_forward
transformers.models.phi.modeling_phi.PhiModel.forward = gaudi_phi_model_forward
transformers.models.phi.modeling_phi.PhiAttention = GaudiPhiAttention
transformers.models.phi.modeling_phi.PhiDecoderLayer = GaudiPhiDecoderLayer
transformers.models.phi.modeling_phi.PhiModel = GaudiPhiModel
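
The change above moves phi from patching individual `forward` functions to replacing the `PhiAttention`, `PhiDecoderLayer`, and `PhiModel` classes outright, which lets the Gaudi variants carry extra state such as a pre-allocated kv cache. Below is a toy, self-contained sketch of that class-replacement style of monkey patching; the module and class names in it are invented for illustration and are not the transformers API.

```python
# Toy illustration of class-level monkey patching (all names hypothetical):
# swapping the class itself, rather than just its forward method, lets the
# replacement define new attributes and methods, e.g. a kv cache that is
# allocated once and reused across decoding steps.
import types

upstream = types.ModuleType("upstream_model")


class Attention:
    def forward(self, x):
        return x


upstream.Attention = Attention


class PatchedAttention(Attention):
    def __init__(self):
        self.kv_cache = None  # extra state a forward-only patch could not add

    def allocate_kv_cache(self, size):
        # Pre-allocate once so later decoding steps reuse the same buffer.
        self.kv_cache = [0.0] * size

    def forward(self, x):
        return x


# Same pattern as adapt_transformers_to_gaudi(): swap the class in place.
upstream.Attention = PatchedAttention

layer = upstream.Attention()
layer.allocate_kv_cache(8)
print(type(layer).__name__, len(layer.kv_cache))
```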

# Optimization for gemma on Gaudi
transformers.models.gemma.modeling_gemma.GemmaForCausalLM = GaudiGemmaForCausalLM
6 changes: 3 additions & 3 deletions optimum/habana/transformers/models/__init__.py
@@ -134,10 +134,10 @@
gaudi_persimmon_model_forward,
)
from .phi import (
GaudiPhiAttention,
GaudiPhiDecoderLayer,
GaudiPhiForCausalLM,
gaudi_phi_attention_forward,
gaudi_phi_decoder_layer_forward,
gaudi_phi_model_forward,
GaudiPhiModel,
)
from .qwen2 import (
GaudiQwen2DecoderLayer,
6 changes: 3 additions & 3 deletions optimum/habana/transformers/models/phi/__init__.py
@@ -1,6 +1,6 @@
from .modeling_phi import (
GaudiPhiAttention,
GaudiPhiDecoderLayer,
GaudiPhiForCausalLM,
gaudi_phi_attention_forward,
gaudi_phi_decoder_layer_forward,
gaudi_phi_model_forward,
GaudiPhiModel,
)