How to use low bit KV Cache #721

Open · sitabulaixizawaluduo opened this issue Jan 7, 2025 · 14 comments

@sitabulaixizawaluduo

Does flashinfer currently support per-head quantized KV cache, including fp8_e4m3 and int8?

@yzh119 (Collaborator) commented Jan 7, 2025

Per-channel quantization is not supported yet; it could be a good feature to have. Added to #675.

@sitabulaixizawaluduo (Author)

> Per-channel quantization is not supported yet; it could be a good feature to have. Added to #675.

I'm looking forward to this feature. In high-performance engines such as LMDeploy, an int8 KV cache delivers both high throughput and high accuracy, better than the current fp8_e5m2 and fp8_e4m3.

@zhyncs (Member) commented Jan 7, 2025

> Per-channel quantization is not supported yet; it could be a good feature to have. Added to #675.
>
> I'm looking forward to this feature. In high-performance engines such as LMDeploy, an int8 KV cache delivers both high throughput and high accuracy, better than the current fp8_e5m2 and fp8_e4m3.

You are right. It's a SOTA implementation.

@sitabulaixizawaluduo (Author)

> Per-channel quantization is not supported yet; it could be a good feature to have. Added to #675.
>
> I'm looking forward to this feature. In high-performance engines such as LMDeploy, an int8 KV cache delivers both high throughput and high accuracy, better than the current fp8_e5m2 and fp8_e4m3.
>
> You are right. It's a SOTA implementation.

At present, fp8_e5m2 seriously degrades accuracy, and fp8_e4m3 is even slower than BF16.

@yzh119 (Collaborator) commented Jan 7, 2025

The current flashinfer fp8 kernel should only be used for decode/append; otherwise it's worse than first converting the data to f16 and then using the f16 kernels, because it internally uses f16 tensor cores.

@sitabulaixizawaluduo can you tell us what your GPU architecture is? My bandwidth is currently limited and I'm prioritizing Hopper/Blackwell.
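
To make the prefill recommendation above concrete, here is a minimal sketch (not FlashInfer's internal code) of dequantizing an fp8 KV cache to f16 up front and then using the regular f16 kernels; the 5-D paged layout [num_pages, 2, page_size, num_kv_heads, head_dim] and simple per-tensor scales are assumptions for illustration.

```python
import torch

def dequantize_paged_kv_to_f16(paged_kv_fp8: torch.Tensor,
                               k_scale: float, v_scale: float) -> torch.Tensor:
    """Dequantize an fp8_e4m3 paged KV cache to f16 for the prefill path.

    paged_kv_fp8: [num_pages, 2, page_size, num_kv_heads, head_dim], float8_e4m3fn.
    """
    kv_f16 = paged_kv_fp8.to(torch.float16)
    kv_f16[:, 0].mul_(k_scale)  # dequantize keys
    kv_f16[:, 1].mul_(v_scale)  # dequantize values
    return kv_f16  # pass this to the f16 prefill wrapper instead of the fp8 cache
```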

@sitabulaixizawaluduo (Author)

> The current flashinfer fp8 kernel should only be used for decode/append; otherwise it's worse than first converting the data to f16 and then using the f16 kernels, because it internally uses f16 tensor cores.
>
> @sitabulaixizawaluduo can you tell us what your GPU architecture is? My bandwidth is currently limited and I'm prioritizing Hopper/Blackwell.

L40 GPU, Ada architecture.

@sitabulaixizawaluduo (Author)

> The current flashinfer fp8 kernel should only be used for decode/append; otherwise it's worse than first converting the data to f16 and then using the f16 kernels, because it internally uses f16 tensor cores.
>
> @sitabulaixizawaluduo can you tell us what your GPU architecture is? My bandwidth is currently limited and I'm prioritizing Hopper/Blackwell.

In addition, LMDeploy's KV-cache computation should also be using fp16 tensor cores, with a dequantization step in between, which is why it can maintain high accuracy.
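
For illustration only (this is not LMDeploy's actual kernel), a sketch of the dequantize-then-fp16-matmul idea described above, with hypothetical per-head int8 scales and zero points:

```python
import torch

def qk_scores_with_int8_k(q_f16: torch.Tensor, k_int8: torch.Tensor,
                          k_scale: torch.Tensor, k_zero: torch.Tensor) -> torch.Tensor:
    """q_f16: [num_heads, q_len, head_dim]; k_int8: [num_heads, kv_len, head_dim];
    k_scale, k_zero: per-head quantization parameters, shape [num_heads]."""
    # Dequantize K to f16 first ...
    k_f16 = (k_int8.to(torch.float16) - k_zero[:, None, None]) * k_scale[:, None, None]
    # ... then the QK^T matmul runs in f16 (mapped to fp16 tensor cores on GPU).
    return torch.einsum("hqd,hkd->hqk", q_f16, k_f16)
```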

@yzh119 (Collaborator) commented Jan 8, 2025

> L40 GPU, Ada architecture.

Got it. I think the hyperparameters for Ada are not well tuned; I'll fix that.

@yzh119 (Collaborator) commented Jan 8, 2025

> At present, fp8_e5m2 seriously degrades accuracy, and fp8_e4m3 is even slower than BF16.

By the way, did you observe this slowdown for prefill or decode? If you are talking about decode, you need to enable `use_tensor_cores=True`.
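
A minimal sketch of what enabling the flag looks like, assuming the paged decode wrapper; the plan()/run() arguments are elided and unchanged from the default path:

```python
import torch
import flashinfer

# 128 MB workspace buffer for the attention kernels
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace,
    kv_layout="NHD",
    use_tensor_cores=True,  # tensor-core decode path; needed for good fp8 decode speed
)
# decode_wrapper.plan(...)                     # same planning arguments as before
# out = decode_wrapper.run(q, paged_kv_cache)  # fp8 KV cache, f16/bf16 query
```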

@sitabulaixizawaluduo (Author)

> At present, fp8_e5m2 seriously degrades accuracy, and fp8_e4m3 is even slower than BF16.
>
> By the way, did you observe this slowdown for prefill or decode? If you are talking about decode, you need to enable `use_tensor_cores=True`.

Thanks! I will try it.

@sitabulaixizawaluduo (Author)

> At present, fp8_e5m2 seriously degrades accuracy, and fp8_e4m3 is even slower than BF16.
>
> By the way, did you observe this slowdown for prefill or decode? If you are talking about decode, you need to enable `use_tensor_cores=True`.

When computing with an fp8 KV cache and tensor cores enabled, does it compute with the fp8 tensor cores first and then use sm_scale to dequantize to fp16/bf16, or does it first use sm_scale to dequantize to fp16/bf16 and then compute with the fp16 tensor cores?

@yzh119 (Collaborator) commented Jan 12, 2025

Hi @sitabulaixizawaluduo, it uses f16 tensor cores for both QK and PV.

I'm working on head-wise scales for fp8/int8 KV caches. I wonder whether the following API would work for you:

  • K/V are stored in int8/fp8, and two additional float16/float32 tensors are provided: qk_scale: Tensor[num_qo_heads] and v_scale: Tensor[num_qo_heads] (for GQA, broadcast them).
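
A hypothetical usage sketch of the proposed interface (nothing here exists in FlashInfer yet; the wrapper.run() keyword names are assumptions purely for illustration):

```python
import torch

num_qo_heads, num_kv_heads = 32, 8

# Per-head scales as proposed above; for GQA, broadcast the per-KV-head scales
# up to num_qo_heads.
qk_scale_kv = torch.rand(num_kv_heads, dtype=torch.float32, device="cuda")
v_scale_kv = torch.rand(num_kv_heads, dtype=torch.float32, device="cuda")
qk_scale = qk_scale_kv.repeat_interleave(num_qo_heads // num_kv_heads)  # [num_qo_heads]
v_scale = v_scale_kv.repeat_interleave(num_qo_heads // num_kv_heads)    # [num_qo_heads]

# Hypothetical call: K/V stored in int8/fp8, per-head scales passed alongside.
# out = wrapper.run(q, paged_kv_cache_int8, qk_scale=qk_scale, v_scale=v_scale)
```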

@sitabulaixizawaluduo (Author)

> Hi @sitabulaixizawaluduo, it uses f16 tensor cores for both QK and PV.
>
> I'm working on head-wise scales for fp8/int8 KV caches. I wonder whether the following API would work for you:
>
> • K/V are stored in int8/fp8, and two additional float16/float32 tensors are provided: qk_scale: Tensor[num_qo_heads] and v_scale: Tensor[num_qo_heads] (for GQA, broadcast them).

I think this would be useful. BTW, if online quantization is used, will the resulting overhead have a significant impact on performance?

@yzh119 (Collaborator) commented Jan 13, 2025

You mean online quantization of the KV cache?

I think it would be good to provide an API like https://docs.flashinfer.ai/generated/flashinfer.page.append_paged_kv_cache.html, but fused with online quantization.
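
A sketch of the unfused version being discussed: quantize the incoming K/V online (a hypothetical per-head int8 helper below), then append with the existing append_paged_kv_cache API (arguments elided); a fused kernel would combine the two steps in one pass.

```python
import torch

def quantize_per_head_int8(x: torch.Tensor):
    """x: [num_tokens, num_kv_heads, head_dim] in fp16/bf16 -> (int8 tensor, per-head scales)."""
    amax = x.abs().amax(dim=(0, 2)).clamp(min=1e-6)        # per-head absmax, [num_kv_heads]
    scale = (amax / 127.0).to(torch.float32)
    q = (x / scale[None, :, None]).round().clamp(-128, 127).to(torch.int8)
    return q, scale

num_tokens, num_kv_heads, head_dim = 16, 8, 128
append_key = torch.randn(num_tokens, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
append_value = torch.randn_like(append_key)

k_int8, k_scale = quantize_per_head_int8(append_key)
v_int8, v_scale = quantize_per_head_int8(append_value)
# flashinfer.page.append_paged_kv_cache(k_int8, v_int8, ..., paged_kv_cache, ...)  # as in the docs link
```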
