Correct me if I am wrong: it seems there is no such existing test. FWIW, I looked at test_flash_attn_kvcache in tests/test_flash_attn_ck.py and tests/test_flash_attn.py.
Here is my failed attempt. Could you provide a test that demonstrates how to use the KV cache updated in place inside flash_attn_with_kvcache, and verifies the correctness of the cache by comparing the output with and without using the cache?
import torch

from flash_attn import flash_attn_func, flash_attn_with_kvcache

n_layers = 2
dim = 3
num_kv_heads = 1
batch_size = 1
head_size = dim // num_kv_heads
block_size = 256
device = 'cuda'
dtype = torch.float16

# One (k_cache, v_cache) pair per layer; flash_attn_with_kvcache updates these in place.
kv_cache = [
    (torch.zeros((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device),
     torch.zeros((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device))
    for _ in range(n_layers)
]
layer = 0
k_cache, v_cache = kv_cache[layer]

# New tokens to append to the cache, and the queries attending to them.
k = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)
v = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)
q = torch.ones((batch_size, block_size, num_kv_heads, head_size), dtype=dtype, device=device)

# Paged-KV block table and the number of valid cached tokens per sequence.
max_num_blocks_per_seq = 2
block_tables = torch.randint(
    0, batch_size, (batch_size, max_num_blocks_per_seq), dtype=torch.int32, device=device
)
cache_seqlens = torch.arange(batch_size, dtype=torch.int32, device=device)

# First call: write k/v into the cache in place and attend to them.
y = flash_attn_with_kvcache(
    q, k_cache=k_cache, v_cache=v_cache, k=k, v=v,
    block_table=block_tables, cache_seqlens=cache_seqlens, causal=True,
)

# Inspect the in-place cache update.
to_print_v = v.view(batch_size * block_size, num_kv_heads * head_size).float().cpu().numpy()
to_print_v_cache = v_cache.view(batch_size * block_size, num_kv_heads * head_size).float().cpu().numpy()

# Second call: rely only on the previously written cache (no new k/v).
y_with_cached_KV = flash_attn_with_kvcache(
    q, k_cache=k_cache, v_cache=v_cache, k=None, v=None,
    block_table=block_tables, cache_seqlens=cache_seqlens, causal=True,
)

# Reference without the KV cache.
yy = flash_attn_func(q, k, v, causal=True)

# Intended check: the cached path should reproduce the uncached path.
# Note: `y_with_cached_KV == yy` raises
#   RuntimeError: Boolean value of Tensor with more than one value is ambiguous
# so an elementwise comparison such as torch.allclose is needed instead.
assert torch.allclose(y_with_cached_KV, yy)
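For reference, here is a minimal sketch of the kind of check I have in mind, using a contiguous (non-paged) cache instead of a block_table; the shapes, sequence lengths, and tolerances are my own guesses, not an official test. As I understand it, when block_table is used the cache would instead be laid out as (num_blocks, page_block_size, nheads_k, headdim), but the contiguous layout keeps the sketch simple.

import torch

from flash_attn import flash_attn_func, flash_attn_with_kvcache

torch.manual_seed(0)
batch_size, seqlen, nheads, headdim = 2, 64, 4, 64
device, dtype = "cuda", torch.float16

q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)

# Reference: plain causal attention without any cache.
out_ref = flash_attn_func(q, k, v, causal=True)

# Contiguous (non-paged) cache with room for more tokens than we write.
max_cache_len = 128
k_cache = torch.zeros(batch_size, max_cache_len, nheads, headdim, device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)

# Pass 1: cache starts empty (cache_seqlens == 0); k/v are written into the
# cache in place at positions [0, seqlen) and attended to in the same call.
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
out_append = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=True
)

# The in-place update can be checked directly against the appended tensors.
assert torch.equal(k_cache[:, :seqlen], k)
assert torch.equal(v_cache[:, :seqlen], v)

# Pass 2: reuse the populated cache without passing k/v. The kernel does not
# advance cache_seqlens for us, so it now has to say "seqlen tokens are valid".
cache_seqlens_full = torch.full((batch_size,), seqlen, dtype=torch.int32, device=device)
out_cached = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens_full, causal=True
)

# Both paths should agree with the uncached reference (fp16 tolerances are a guess).
torch.testing.assert_close(out_append, out_ref, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(out_cached, out_ref, atol=1e-3, rtol=1e-3)

Is something along these lines the intended usage, or is there a subtlety with cache_seqlens or the causal mask alignment that I am missing?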