Add ragged paged attention #8659

Open · wants to merge 8 commits into master from xiowei/add_ragged_paged_attention

Conversation

vanbasten23 (Collaborator)

Test plan:

LIBTPU_INIT_ARGS=--xla_tpu_scoped_vmem_limit_kib=65536  python /workspaces/persist/pytorch/xla/test/test_ragged_paged_attention_kernel.py 2>&1 | tee out.txt

test/test_ragged_paged_attention_kernel.py (outdated review thread, resolved)
last_time_seeing_cur_physical_q_blk = jnp.logical_or(is_last_logical_q_blk, physical_q_blk_will_change)
should_store_to_hbm = jnp.logical_and(is_last_kv_blk_idx, last_time_seeing_cur_physical_q_blk)
@pl.when(should_store_to_hbm)
def store_to_hbm(): # pylint: disable=unused-variable


This function actually only stores to VMEM, and all the refs it uses are VMEM refs. We rely on the pipeline emitter to send the VMEM block back to HBM, and we can't store vregs directly to HBM inside the kernel. So I think the original name store_to_output makes more sense here.
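For context, a minimal standalone Pallas sketch (not the PR's kernel; names and shapes are illustrative) of such a guarded store: the output ref the kernel writes is a VMEM block, and the pipeline emitter copies it back to HBM after the grid step.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_on_last_step_kernel(x_ref, o_ref):
  # Both refs are VMEM blocks; the pipeline emitter copies o_ref back to HBM
  # after the grid step, so the kernel body never writes HBM directly.
  is_last_step = pl.program_id(0) == pl.num_programs(0) - 1
  o_ref[...] = x_ref[...]

  @pl.when(is_last_step)
  def store_to_output():  # pylint: disable=unused-variable
    # Still a VMEM write; the kernel cannot store vregs to HBM itself.
    o_ref[...] = x_ref[...] * 2.0

x = jnp.ones((8, 128), jnp.float32)
y = pl.pallas_call(
    scale_on_last_step_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(1,),
)(x)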

Comment on lines 894 to 884
pages_per_sequence=pages_per_sequence,
num_tokens=num_tokens,
num_seqs=num_seqs,  # if they change, we need to recompile.
num_kv_pages_per_compute_block=num_kv_pages_per_compute_block,
mask_value=mask_value,


These are all calculated from either static shapes or static_argnames.

The overhead could be very large if some of them are too dynamic (num_seqs, num_tokens), because we would need to recompile many times. Maybe add a note or TODO here.
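For illustration, a minimal sketch (the wrapper name and body here are hypothetical, not the PR's API) of why very dynamic num_seqs/num_tokens are expensive when passed via static_argnames: each new value triggers a retrace and recompile.

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("num_seqs", "num_tokens"))
def ragged_attention_stub(q, num_seqs, num_tokens):
  # Stand-in for the kernel wrapper; only the static-argument behavior matters.
  return q[:num_tokens] * num_seqs

q = jnp.ones((16, 128), jnp.float32)
ragged_attention_stub(q, num_seqs=2, num_tokens=8)   # first call: compiles
ragged_attention_stub(q, num_seqs=2, num_tokens=8)   # same statics: cache hit
ragged_attention_stub(q, num_seqs=3, num_tokens=12)  # new statics: recompiles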

@bythew3i

Test plan:

LIBTPU_INIT_ARGS=--xla_tpu_scoped_vmem_limit_kib=65536  python /workspaces/persist/pytorch/xla/test/test_ragged_paged_attention_kernel.py 2>&1 | tee out.txt

How is 65536 calculated?

@vanbasten23 (Collaborator, Author)

I found a ticket where someone uses it. I remember the number is the VMEM limit (65536 KiB, i.e. 64 MiB) on a TPU generation.
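For reference, a minimal sketch of setting the same flag from Python instead of the shell; it assumes the environment variable is exported before the TPU runtime initializes (here, before the first torch_xla import).

import os

# Same value as in the test plan above: 65536 KiB == 64 MiB of scoped VMEM.
os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_scoped_vmem_limit_kib=65536"

import torch_xla  # noqa: E402  # must come after setting LIBTPU_INIT_ARGS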

vanbasten23 force-pushed the xiowei/add_ragged_paged_attention branch from ad2f87c to 9e4b227 (February 1, 2025 00:32)
vanbasten23 force-pushed the xiowei/add_ragged_paged_attention branch from 9e4b227 to 7fe5071 (February 3, 2025 05:41)