[Feature] MiniMax-01 Lightning Attention #739

Open
yzh119 opened this issue Jan 16, 2025 · 4 comments

yzh119 (Collaborator) commented Jan 16, 2025

MiniMax-01 scales linear attention to a large-scale model (456B parameters), and FlashInfer should support it.

The prefill computation of the lightning attention (forward) can be summarized as:

[Images: per-chunk formulas for the intra-block output O_intra and the inter-block output O_inter of the lightning attention prefill]

The computation of O_intra is completely independent for each tile, so we can reuse our existing attention kernel by setting use_softmax=False in our attention variant class.

The computation of O_inter is basically a scan operation; we can either perform the entire loop per request within a CTA, or use split-K. In the split-K case, we split N into chunks: first compute the KV matrix of each chunk, then the cumulative sum of the KV matrices, and finally compute O_inter for all tiles independently. The split-K chunk size can be selected adaptively to strike a balance between the O_inter overhead (determined by the number of chunks) and the O_intra overhead (determined by the chunk size). KV should be kept in fp32, considering the accumulation precision required for long context. A sketch of this chunked decomposition follows below.
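For illustration, here is a minimal PyTorch sketch of the chunked prefill decomposition described above. It assumes plain (non-decayed) causal linear attention for a single head; the per-position decay factors that lightning attention applies are omitted for clarity, and the function name is ours, not a FlashInfer API.

```python
import torch

def chunked_linear_attention_prefill(q, k, v, chunk_size=256):
    """q, k, v: [N, d] for a single head; returns o: [N, d]."""
    N, d = q.shape
    o = torch.empty_like(v)
    # Running KV state, kept in fp32 for accumulation precision over long context.
    kv = torch.zeros(d, d, dtype=torch.float32, device=q.device)
    for start in range(0, N, chunk_size):
        end = min(start + chunk_size, N)
        qi, ki, vi = q[start:end], k[start:end], v[start:end]
        # O_inter: contribution of all previous chunks via the accumulated KV state.
        o_inter = qi.float() @ kv
        # O_intra: causal attention within the chunk, without softmax.
        scores = qi @ ki.t()
        causal = torch.tril(torch.ones(end - start, end - start,
                                       dtype=torch.bool, device=q.device))
        o_intra = (scores * causal) @ vi
        o[start:end] = (o_intra.float() + o_inter).to(v.dtype)
        # Fold this chunk into the running KV state with K_i^T V_i.
        kv += ki.float().t() @ vi.float()
    return o
```

In the split-K variant, the per-chunk K_i^T V_i matrices would be computed independently, prefix-summed, and then all O_inter terms computed in parallel; the sequential loop here is just the simplest form of that scan.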

For decode, there is no need to maintain a KV-cache in the page table; we just need to keep one KV (d×d) matrix per request and accumulate KV with K_i^T V_i at step i (a per-step sketch is below). It is still possible to maintain a unified page for the softmax attention layers' KV-cache and the linear attention layers' KV state; in that case, we can add gather GEMM operators to FlashInfer for the O_inter computation.
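A minimal sketch of the per-request decode step described above (a hypothetical helper, again without decay factors and not a FlashInfer API): the only recurrent state is a d×d KV matrix per request.

```python
import torch

def lightning_attention_decode_step(q_t, k_t, v_t, kv_state):
    """q_t, k_t, v_t: [d] vectors for the current step; kv_state: [d, d] in fp32."""
    # Accumulate KV with the outer product K_t^T V_t for this step.
    kv_state += torch.outer(k_t.float(), v_t.float())
    # The output is just q_t applied to the accumulated KV state.
    o_t = (q_t.float() @ kv_state).to(v_t.dtype)
    return o_t, kv_state
```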

@leifeng666

Hi @yzh119, I am interested in working on this issue if no one is currently working on it. Could you assign this issue to me?

yzh119 (Collaborator, Author) commented Feb 2, 2025

Hi @leifeng666, sure! I think a good starting point is to run benchmarks on a state-of-the-art Triton implementation (such as flash-linear-attention) and see how far it is from speed-of-light (a rough timing-harness sketch is below).
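For reference, a rough sketch of the kind of comparison meant here. `lightning_attn_fn` is a placeholder name for whatever Triton kernel is being measured (e.g. the fla implementation), `peak_tflops` is the GPU's spec-sheet peak, and the FLOP count is only an order-of-magnitude estimate; this is not FlashInfer or fla code.

```python
import torch

def bench_speed_of_light(lightning_attn_fn, q, k, v, peak_tflops, iters=100):
    """Time a candidate kernel and report the fraction of the GPU's peak it reaches."""
    for _ in range(10):  # warm-up
        lightning_attn_fn(q, k, v)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        lightning_attn_fn(q, k, v)
    end.record()
    torch.cuda.synchronize()
    ms_per_iter = start.elapsed_time(end) / iters
    # Very rough FLOP estimate for linear attention: the intra and inter terms are
    # each on the order of N * d * d multiply-adds per head; exact counts depend
    # on the chunk size.
    *_, N, d = q.shape
    heads = q.numel() // (N * d)
    flops = 4.0 * heads * N * d * d
    achieved = flops / (ms_per_iter * 1e-3) / 1e12
    print(f"{ms_per_iter:.3f} ms/iter, {achieved:.1f} TFLOP/s "
          f"({100 * achieved / peak_tflops:.1f}% of speed-of-light)")
```

Depending on sequence length and head dimension, the kernel may be memory-bound rather than compute-bound, in which case achieved bandwidth versus HBM peak is the more relevant speed-of-light number.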

@leifeng666

@yzh119 at first glance, it seems lightning attention hasn't been supported in fla yet. I just created an issue asking the fla community whether they plan to implement it in the near future: fla-org/flash-linear-attention#164. In the meantime, I will also try to find another Triton implementation and run a benchmark.

yzhangcs commented Feb 5, 2025

Check out fla-org/flash-linear-attention@f14178c.
Lightning Attention has been integrated into fla.
