🚀 Feature

We need a paged attention kernel capable of handling multiple query tokens per sequence.

Motivation

The [existing paged attention](https://github.com/jax-ml/jax/blob/3aa55992fe374987ff3701b69d6814c007c37bb3/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L374) only supports inputs with a single query token per sequence. This limitation prevents vLLM from supporting features such as speculative decoding, prefix caching, and chunked prefill; in speculative decoding, for example, multiple query tokens of a given sequence must be decoded in parallel. A multi-query-token paged attention kernel is therefore a hard blocker for vLLM to perform well on TPU.

Pitch

We need a new Pallas kernel. The rough logic maps to:
```python
# q.shape = [batch_size, query_len, num_q_heads, head_size]
# Put the num_heads dim before the length dim.
q = jnp.permute_dims(q, (0, 2, 1, 3))

for b_idx in range(batch_size):
  for kv_head_idx in range(num_kv_heads):
    for q_blk_idx in range(num_queries_len_blocks):
      for kv_blk_idx in range(num_kv_len_blocks):
        # Within the kernel:
        # q.shape = [num_q_heads_per_kv_head, query_len_per_q_len_block, head_size]
        # Load the kv pages corresponding to the current batch from HBM to VMEM.
        for q_head_idx in range(num_q_heads_per_kv_head):
          # Within the flash attention kernel:
          # q.shape    = [query_len_per_q_len_block, head_size]
          # k.shape    = [kv_len_per_kv_len_block, head_size]
          # attn.shape = [query_len_per_q_len_block, kv_len_per_kv_len_block]
          # v.shape    = [kv_len_per_kv_len_block, head_size]
          # out.shape  = [query_len_per_q_len_block, head_size]
          # Save out to index q_head_idx of final_out.
          # final_out.shape = [num_q_heads_per_kv_head, query_len_per_q_len_block, head_size]
```
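For reference, below is a minimal plain-JAX sketch of the semantics such a kernel would need to compute, with no Pallas, no blocking, and no HBM-to-VMEM DMA. The argument names (`k_pages`, `page_indices`, `kv_lens`), the shapes, and the assumption that `kv_lens` already counts the new query tokens are illustrative assumptions for this sketch, not the actual kernel interface.

```python
# Illustrative only: a plain-JAX reference of what the kernel must compute.
# All argument names and shapes below are assumptions for this sketch.
import jax
import jax.numpy as jnp

def paged_attention_reference(
    q,             # [batch_size, query_len, num_q_heads, head_size]
    k_pages,       # [num_kv_heads, total_num_pages, page_size, head_size]
    v_pages,       # [num_kv_heads, total_num_pages, page_size, head_size]
    page_indices,  # [batch_size, pages_per_sequence], int32
    kv_lens,       # [batch_size], int32; cached kv tokens per sequence,
                   # assumed to already include the query_len new tokens
):
    batch_size, query_len, num_q_heads, head_size = q.shape
    num_kv_heads = k_pages.shape[0]
    num_q_heads_per_kv_head = num_q_heads // num_kv_heads

    outputs = []
    for b in range(batch_size):
        # Gather this sequence's pages: [num_kv_heads, kv_len_padded, head_size].
        k = k_pages[:, page_indices[b]].reshape(num_kv_heads, -1, head_size)
        v = v_pages[:, page_indices[b]].reshape(num_kv_heads, -1, head_size)
        kv_len_padded = k.shape[1]

        # Group query heads by the kv head they share (grouped-query attention).
        q_b = q[b].reshape(query_len, num_kv_heads, num_q_heads_per_kv_head, head_size)

        # attn[h, g, i, j] = <q_b[i, h, g], k[h, j]> / sqrt(head_size)
        attn = jnp.einsum("ihgd,hjd->hgij", q_b, k) / jnp.sqrt(head_size)

        # Causal mask over the multiple new query tokens: query token i may only
        # attend to kv positions up to kv_lens[b] - query_len + i; positions past
        # kv_lens[b] are padding from the page gather and are masked out as well.
        kv_pos = jnp.arange(kv_len_padded)
        q_pos = kv_lens[b] - query_len + jnp.arange(query_len)
        mask = kv_pos[None, :] <= q_pos[:, None]   # [query_len, kv_len_padded]
        attn = jnp.where(mask[None, None], attn, -jnp.inf)

        weights = jax.nn.softmax(attn, axis=-1)
        out = jnp.einsum("hgij,hjd->ihgd", weights, v)
        outputs.append(out.reshape(query_len, num_q_heads, head_size))

    # [batch_size, query_len, num_q_heads, head_size]
    return jnp.stack(outputs)
```

A Pallas implementation would replace the Python loops with the grid sketched above, stream the relevant KV pages from HBM into VMEM per (batch, kv_head, kv_block) grid step, and accumulate across kv blocks with a flash-attention-style online softmax rather than materializing the full attention matrix.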
Alternatives
Additional context