
looking for a test to verify cache correctness in flash_attn_with_kvcache #1414

Open
chakpongchung opened this issue Dec 26, 2024 · 2 comments

Comments


chakpongchung commented Dec 26, 2024

Correct me if I am wrong: it seems there is no such existing test. FWIW, I looked at the test
test_flash_attn_kvcache in tests/test_flash_attn_ck.py and tests/test_flash_attn.py.

Here is my failed attempt. Could you provide a test that demonstrates how to use the KV cache updated in place inside flash_attn_with_kvcache, and that verifies the cache's correctness by comparing the output with and without the cache?

from flash_attn import flash_attn_with_kvcache
from flash_attn import flash_attn_func

import torch

n_layers = 2
dim = 3
num_kv_heads = 1
batch_size = 1
head_size = dim // num_kv_heads
block_size = 256

device = 'cuda'
dtype = torch.float16

# One preallocated (k_cache, v_cache) pair per layer, filled with zeros.
kv_cache = [(torch.zeros((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device),
             torch.zeros((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device))
            for _ in range(n_layers)]

layer = 0

k_cache, v_cache = kv_cache[layer]

k = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)
v = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)
q = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)

max_num_blocks_per_seq = 2
block_tables = torch.randint(0, batch_size,
                             (batch_size, max_num_blocks_per_seq),
                             dtype=torch.int32, device=device)

cache_seqlens = torch.arange(batch_size, dtype=torch.int32, device=device)

# First call: k/v are provided, so they are written into k_cache/v_cache in place.
y = flash_attn_with_kvcache(q, k_cache=k_cache, v_cache=v_cache, k=k, v=v, block_table=block_tables,
                            cache_seqlens=cache_seqlens,
                            causal=True)

to_print_v = v.view(batch_size * block_size, num_kv_heads * head_size).float().cpu().numpy()
to_print_v_cache = v_cache.view(batch_size * block_size, num_kv_heads * head_size).float().cpu().numpy()

# Second call: k/v are None, so attention should read only from the updated cache.
y_with_cached_KV = flash_attn_with_kvcache(q, k_cache=k_cache, v_cache=v_cache, k=None, v=None,
                                           block_table=block_tables,
                                           cache_seqlens=cache_seqlens,
                                           causal=True)



# Reference: plain attention without any KV cache.
yy = flash_attn_func(q, k, v)

assert y_with_cached_KV == yy  # RuntimeError: Boolean value of Tensor with more than one value is ambiguous


chakpongchung changed the title from "looking for a test to compare the result with and without the KV cache" to "looking for a test to compare the result with the KV cache updated in place and without the KV cache" on Dec 27, 2024
chakpongchung changed the title from "looking for a test to compare the result with the KV cache updated in place and without the KV cache" to "looking for a test to verify cache correctness in flash_attn_with_kvcache" on Jan 8, 2025

chakpongchung commented Jan 8, 2025

Hi @tridao, could you shed some light on this? I am trying to prototype a memory manager on top of this, so I need a correctness check.


tridao commented Jan 9, 2025

This line (in test_flash_attn_kvcache) checks that the cache update is correct:

assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3)
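
For reference, here is a minimal sketch of such a check. It is not a test from the repository: the shapes, the head dimension of 64, the prefill-then-decode flow, and the tolerances are assumptions chosen for illustration, and it uses the non-paged cache layout (no block_table) to keep things simple. It updates the cache in place through flash_attn_with_kvcache, checks the written cache slots, and compares the attention outputs against flash_attn_func run without any cache.

import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache

torch.manual_seed(0)
device, dtype = "cuda", torch.float16
batch_size, nheads, head_dim = 2, 4, 64  # head_dim chosen as a multiple of 8 (assumption)
seqlen_prefill, seqlen_new, cache_capacity = 128, 1, 256

# Preallocated cache, non-paged layout: (batch, seqlen_cache, nheads_k, head_dim).
k_cache = torch.zeros(batch_size, cache_capacity, nheads, head_dim, dtype=dtype, device=device)
v_cache = torch.zeros(batch_size, cache_capacity, nheads, head_dim, dtype=dtype, device=device)

# Prefill: pass k/v so they are written into the cache in place, starting at cache_seqlens.
q1 = torch.randn(batch_size, seqlen_prefill, nheads, head_dim, dtype=dtype, device=device)
k1 = torch.randn(batch_size, seqlen_prefill, nheads, head_dim, dtype=dtype, device=device)
v1 = torch.randn(batch_size, seqlen_prefill, nheads, head_dim, dtype=dtype, device=device)
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)  # cache is empty before prefill
out_prefill = flash_attn_with_kvcache(q1, k_cache, v_cache, k=k1, v=v1,
                                      cache_seqlens=cache_seqlens, causal=True)

# 1) The cache slots that were written should hold exactly k1/v1.
assert torch.allclose(k_cache[:, :seqlen_prefill], k1, rtol=1e-3, atol=1e-3)
assert torch.allclose(v_cache[:, :seqlen_prefill], v1, rtol=1e-3, atol=1e-3)

# 2) The prefill output should match plain attention without a cache.
out_ref = flash_attn_func(q1, k1, v1, causal=True)
assert torch.allclose(out_prefill, out_ref, rtol=1e-3, atol=1e-3)

# Decode one new token: cache_seqlens is now the number of tokens already in the cache,
# so the query attends to the cached keys plus the new one.
q2 = torch.randn(batch_size, seqlen_new, nheads, head_dim, dtype=dtype, device=device)
k2 = torch.randn(batch_size, seqlen_new, nheads, head_dim, dtype=dtype, device=device)
v2 = torch.randn(batch_size, seqlen_new, nheads, head_dim, dtype=dtype, device=device)
cache_seqlens += seqlen_prefill
out_decode = flash_attn_with_kvcache(q2, k_cache, v_cache, k=k2, v=v2,
                                     cache_seqlens=cache_seqlens, causal=True)

# 3) Reference: rebuild the full K/V without the cache and take the output of the last query.
k_full = torch.cat([k1, k2], dim=1)
v_full = torch.cat([v1, v2], dim=1)
out_full_ref = flash_attn_func(torch.cat([q1, q2], dim=1), k_full, v_full, causal=True)
assert torch.allclose(out_decode, out_full_ref[:, -seqlen_new:], rtol=1e-3, atol=1e-3)

The main differences from the snippet above are that the decode call passes the updated cache_seqlens (the number of tokens already held in the cache), so the query actually attends to the cached keys, and that tensors are compared with torch.allclose rather than ==.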
