
[RFC] Multi-queries paged attention Pallas kernel #8597

Open
vanbasten23 opened this issue Jan 22, 2025 · 2 comments

@vanbasten23
Collaborator

🚀 Feature

We need a paged attention kernel capable of handling multiple query tokens per sequence.

Motivation

The [existing paged attention kernel](https://github.com/jax-ml/jax/blob/3aa55992fe374987ff3701b69d6814c007c37bb3/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py#L374) requires the input to have a single query token per sequence. This limitation prevents vLLM from supporting features such as speculative decoding, prefix caching, and chunked prefill on TPU. In speculative decoding, for example, we need to attend over multiple query tokens of a given sequence in parallel. A kernel that removes this restriction is therefore a hard blocker for vLLM to shine on TPU, so a new paged attention kernel is needed.

Pitch

We need a new Pallas kernel:

def paged_attention(
    q: jax.Array,                 # [batch_size, query_len, num_heads, head_dim]
    k_pages: jax.Array,           # [num_kv_heads, total_num_pages, page_size, head_dim]
    v_pages: jax.Array,           # [num_kv_heads, total_num_pages, page_size, head_dim]
    lengths: jax.Array,           # i32[batch_size], kv length of each sequence
    page_indices: jax.Array,      # i32[batch_size, pages_per_sequence]
    effective_q_lens: jax.Array,  # i32[batch_size], number of effective query tokens per sequence
) -> jax.Array:                   # [batch_size, query_len, num_heads, head_dim]
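
To make the intended semantics concrete, here is a plain-JAX (non-Pallas) reference sketch of what the kernel is expected to compute. The masking convention (query token i of sequence b sits at kv position lengths[b] - effective_q_lens[b] + i) and the grouping of query heads under their kv head are assumptions for illustration, not part of the proposed signature:

import jax
import jax.numpy as jnp

def paged_attention_reference(q, k_pages, v_pages, lengths, page_indices, effective_q_lens):
  """Plain-JAX reference for the assumed semantics of the proposed kernel."""
  batch_size, query_len, num_heads, head_dim = q.shape
  num_kv_heads, _, page_size, _ = k_pages.shape
  num_q_per_kv = num_heads // num_kv_heads

  outs = []
  for b in range(batch_size):
    # Gather this sequence's pages: [num_kv_heads, pages_per_sequence * page_size, head_dim].
    k = k_pages[:, page_indices[b]].reshape(num_kv_heads, -1, head_dim)
    v = v_pages[:, page_indices[b]].reshape(num_kv_heads, -1, head_dim)
    kv_len = k.shape[1]

    # Group query heads under their kv head (the grouping order is an assumption):
    # [num_kv_heads, num_q_per_kv, query_len, head_dim].
    qb = q[b].transpose(1, 0, 2).reshape(num_kv_heads, num_q_per_kv, query_len, head_dim)
    attn = jnp.einsum("hgqd,hkd->hgqk", qb.astype(jnp.float32), k.astype(jnp.float32))
    attn *= 1.0 / jnp.sqrt(head_dim)

    # Causal mask: query token i attends to kv positions <= lengths[b] - effective_q_lens[b] + i;
    # positions at or beyond lengths[b] are padding.
    q_pos = lengths[b] - effective_q_lens[b] + jnp.arange(query_len)
    kv_pos = jnp.arange(kv_len)
    mask = (kv_pos[None, :] <= q_pos[:, None]) & (kv_pos[None, :] < lengths[b])
    attn = jnp.where(mask[None, None], attn, -1e30)

    out = jnp.einsum("hgqk,hkd->hgqd", jax.nn.softmax(attn, axis=-1), v.astype(jnp.float32))
    outs.append(out.reshape(num_heads, query_len, head_dim).transpose(1, 0, 2))
  return jnp.stack(outs).astype(q.dtype)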

The rough logic maps to:

q = jnp.transpose(q, (0, 2, 1, 3))  # put the num_heads dim before the query_len dim
for b_idx in range(batch_size):
  for kv_head_idx in range(num_kv_heads):
    for q_blk_idx in range(num_queries_len_blocks):
      for kv_blk_idx in range(num_kv_len_blocks):
        # Within the kernel:
        # q.shape = [num_q_heads_per_kv_head, query_len_per_q_len_block, head_size]
        # Load the kv pages for the current batch element from HBM to VMEM.
        for q_head_idx in range(num_q_heads_per_kv_head):
          # Within the flash-attention kernel:
          # q.shape    = [query_len_per_q_len_block, head_size]
          # k.shape    = [kv_len_per_kv_len_block, head_size]
          # attn.shape = [query_len_per_q_len_block, kv_len_per_kv_len_block]
          # v.shape    = [kv_len_per_kv_len_block, head_size]
          # out.shape  = [query_len_per_q_len_block, head_size]
          # Save out to slice q_head_idx of final_out.
        # final_out.shape = [num_q_heads_per_kv_head, query_len_per_q_len_block, head_size]
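
As a quick sanity check of the shape bookkeeping above, the reference sketch from the Pitch section can be exercised on dummy inputs; all sizes below are placeholders picked for illustration only:

batch_size, query_len, num_heads, head_dim = 2, 4, 8, 128
num_kv_heads, total_num_pages, page_size, pages_per_sequence = 2, 64, 16, 8

key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (batch_size, query_len, num_heads, head_dim), jnp.float32)
k_pages = jax.random.normal(key, (num_kv_heads, total_num_pages, page_size, head_dim), jnp.float32)
v_pages = jax.random.normal(key, (num_kv_heads, total_num_pages, page_size, head_dim), jnp.float32)
lengths = jnp.array([40, 17], dtype=jnp.int32)         # kv tokens per sequence, incl. the new query tokens
page_indices = jnp.arange(batch_size * pages_per_sequence, dtype=jnp.int32).reshape(
    batch_size, pages_per_sequence)
effective_q_lens = jnp.array([4, 3], dtype=jnp.int32)  # e.g. number of draft tokens per sequence

out = paged_attention_reference(q, k_pages, v_pages, lengths, page_indices, effective_q_lens)
assert out.shape == (batch_size, query_len, num_heads, head_dim)

A Pallas implementation would presumably replace the Python loops over batch and kv heads with the grid in the loop nest above and stream kv pages from HBM into VMEM block by block.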

Alternatives

Additional context

@vanbasten23
Collaborator Author

cc @miladm

@vanbasten23
Collaborator Author
