Faster MHA backwards pass #22820
Conversation
Looks great!!
del dv, dk
# Scan #2: dQ
Out of curiosity, is there an advantage to doing this in one kernel vs two kernels?
I think it comes down to there being more work to do per kernel, which leads to better warp occupancy.
Other factors that influence GPU utilization (a dense-math sketch of the two scans follows this list):
- There is some data locality between the two loops, though it matters more for smaller sequence lengths.
- The overhead of launching two kernels.
- Making sure the two kernels actually execute in parallel: they would need to be launched on separate CUDA streams, and even then I don't think it's guaranteed.
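For context, here is a rough dense-math sketch of what the two scans compute for a single head, written in plain JAX rather than Pallas. The function name `mha_bwd_reference` and the unblocked formulation are illustrative only; the actual kernel recomputes these quantities block by block.

```python
import jax
import jax.numpy as jnp


def mha_bwd_reference(q, k, v, do, sm_scale=1.0):
  """Dense single-head reference for the two-scan split (not the blocked kernel)."""
  # Recompute forward quantities instead of storing the attention matrix.
  s = (q @ k.T) * sm_scale                          # [Lq, Lk] logits
  p = jax.nn.softmax(s, axis=-1)                    # [Lq, Lk] attention weights
  o = p @ v                                         # [Lq, d]  forward output
  delta = jnp.sum(do * o, axis=-1, keepdims=True)   # [Lq, 1]  row-wise <dO, O>

  dp = do @ v.T                                     # [Lq, Lk]
  ds = p * (dp - delta) * sm_scale                  # [Lq, Lk] gradient of the logits

  # Scan #1 in the kernel: dK and dV reduce over the query dimension,
  # so they are accumulated while looping over blocks of q.
  dv = p.T @ do                                     # [Lk, d]
  dk = ds.T @ q                                     # [Lk, d]

  # Scan #2 in the kernel: dQ reduces over the key dimension,
  # so it is accumulated while looping over blocks of k.
  dq = ds @ k                                       # [Lq, d]
  return dq, dk, dv
```

Since the dK/dV reduction and the dQ reduction are independent, they could live in two kernels; fusing them means each program runs both scans back to back, which is the extra per-kernel work mentioned above. The result can be checked against `jax.vjp` of the dense attention forward.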
Here are the numbers compared to the previous kernel. Trying different values of num_warps and the block sizes can help further increase performance on your hardware. Note that the improvement is the relative speedup, not just the percentage increase.
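As a hedged illustration of that kind of tuning, the Pallas GPU attention op accepts block sizes and Triton compiler parameters as keyword arguments. The import path and keyword names below (`segment_ids`, `block_q`, `block_k`, `num_warps`, `num_stages`) are assumptions based on the op's signature at the time of writing and may differ between JAX versions:

```python
import functools

import jax
import jax.numpy as jnp
# NOTE: module path and keyword names below are assumptions; check your JAX version.
from jax.experimental.pallas.ops.gpu.attention import mha

q = k = v = jax.random.normal(
    jax.random.PRNGKey(0), (2, 1024, 4, 64), jnp.float16)  # [batch, seq, heads, head_dim]

# Pick block sizes / warp counts by sweeping a few values and keeping the fastest.
attn = functools.partial(
    mha,
    segment_ids=None,  # assumed keyword: no sequence packing
    block_q=64,        # assumed keyword: query block size
    block_k=64,        # assumed keyword: key/value block size
    num_warps=4,       # assumed keyword: Triton warps per program
    num_stages=2,      # assumed keyword: Triton pipelining stages
)

loss = lambda q, k, v: attn(q, k, v).sum()
dq, dk, dv = jax.jit(jax.grad(loss, argnums=(0, 1, 2)))(q, k, v)
```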
This PR implements a faster backwards pass for the multi-headed attention Pallas kernel.
The biggest contributors to the speedup are:
This builds on work from @tonywu95 in jax-ml/jax-triton#177 and is inspired by the Triton fused attention tutorial https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py.
Comparison against the XLA bwd pass across different configurations:
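A rough harness along these lines is one way to reproduce such a comparison, using plain `jnp` attention differentiated through XLA as the baseline. As above, the `mha` import path and its `segment_ids` keyword are assumptions, and the two functions may use slightly different softmax scales, which does not matter for timing:

```python
import functools
import time

import jax
import jax.numpy as jnp
# NOTE: import path is an assumption; check your JAX version.
from jax.experimental.pallas.ops.gpu.attention import mha


def xla_attention(q, k, v):
  # Baseline attention lowered through XLA (no Pallas kernel).
  scale = jnp.sqrt(q.shape[-1]).astype(q.dtype)
  logits = jnp.einsum('bqhd,bkhd->bhqk', q, k) / scale
  return jnp.einsum('bhqk,bkhd->bqhd', jax.nn.softmax(logits, axis=-1), v)


def bwd_time(attn_fn, q, k, v, iters=20):
  loss = lambda q, k, v: attn_fn(q, k, v).sum()
  grad_fn = jax.jit(jax.grad(loss, argnums=(0, 1, 2)))
  jax.block_until_ready(grad_fn(q, k, v))  # compile and warm up
  start = time.perf_counter()
  for _ in range(iters):
    out = grad_fn(q, k, v)
  jax.block_until_ready(out)
  return (time.perf_counter() - start) / iters


key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(subkey, (4, 2048, 8, 64), jnp.float16)
           for subkey in jax.random.split(key, 3))
print('pallas mha bwd:', bwd_time(functools.partial(mha, segment_ids=None), q, k, v))
print('xla attn bwd  :', bwd_time(xla_attention, q, k, v))
```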