Fix +/-inf in LSE returned by forward #978

Open
wants to merge 1 commit into
base: main
from


sgrigory
Contributor

@sgrigory sgrigory commented Jun 3, 2024

The forward op was returning +inf in the LSE for queries that have no keys to attend to, e.g. when the K/V length happens to be 0. This diverges from the definition LSE = log(exp(L1) + ... + exp(Ln)), which for an empty sum gives log(0) = -inf.
This PR fixes it, which allows feeding the output LSE directly into ops like merge_attentions without postprocessing.
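To illustrate why -inf is the convenient value here, the following is a minimal sketch of merging two partial attention results by their LSEs. The helpers `logsumexp` and `merge_two` are hypothetical names written for this example (scalar outputs, standard library only); they are not the library's `merge_attentions` implementation.

```python
import math

def logsumexp(scores):
    """log(sum(exp(s))) with max-subtraction; -inf for an empty list of scores."""
    if not scores:
        return -math.inf
    m = max(scores)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(s - m) for s in scores))

def merge_two(out_a, lse_a, out_b, lse_b):
    """Combine two partial attention outputs (scalars here), weighted by their LSEs."""
    m = max(lse_a, lse_b)
    if m == -math.inf:
        # Neither partial attended to any key.
        return 0.0, -math.inf
    w_a = math.exp(lse_a - m)  # exp(-inf) == 0.0, so an empty partial drops out
    w_b = math.exp(lse_b - m)
    total = w_a + w_b
    return (w_a * out_a + w_b * out_b) / total, m + math.log(total)

# Partial A saw no keys (output 0, LSE -inf); partial B saw some keys.
out, lse = merge_two(0.0, logsumexp([]), 2.5, 1.2)
```

In this sketch an empty partial with LSE = -inf contributes weight 0, so the merged result equals the other partial. With LSE = +inf instead, the weight becomes exp(inf - inf), which is NaN, so the caller would have to post-process the LSE first.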

pytest tests/test_flash_attn.py
...
======================================================================================== 268004 passed, 152064 skipped in 4404.00s (1:13:23) =========================================================================================

@tridao
Contributor

tridao commented Jun 27, 2024

One issue I can see is that in the backward pass, if lse = +inf then exp(qk - lse) returns 0, which is what we want. If lse = -inf then exp would blow up.
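The concern can be reproduced with plain floating-point arithmetic. This is a sketch only: `unguarded_prob` is a hypothetical helper that recomputes a probability as exp(qk - lse) with no special-casing, which is an assumption made for illustration and not the kernel's actual code path.

```python
import math

def unguarded_prob(qk, lse):
    """Recompute a softmax probability as exp(qk - lse), with no -inf guard."""
    return math.exp(qk - lse)

p_pos = unguarded_prob(0.3, math.inf)          # exp(-inf): probability is 0
p_neg = unguarded_prob(0.3, -math.inf)         # exp(+inf): blows up to inf
p_nan = unguarded_prob(-math.inf, -math.inf)   # exp(-inf + inf): NaN
```

So without a guard, lse = -inf yields inf for a finite score and NaN for a masked (-inf) score, while lse = +inf yields 0 for any finite score.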

@GD06
Contributor

GD06 commented Jan 3, 2025

QQ: do we plan to merge this PR? It has been pending for months.

@sgrigory
Contributor Author

> QQ: do we plan to merge this PR as it has been pending for months.

Sorry, I didn't follow up on @tridao's comment above. Basically, I think there should be no NaNs after this change, because the code actually checks for -inf before computing exp(score - lse) in the backward pass:

const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
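A rough Python model of what this guard does, assuming the row value `max(mi)` holds the LSE in the backward pass, that `Scale_max` is false, and that the guarded value is then used as `exp2(score * scale - max_scaled)`. That last step is not shown in the quoted line and is an assumption here; `guarded_prob` is a hypothetical name.

```python
import math

LOG2E = math.log2(math.e)  # stands in for M_LOG2E

def guarded_prob(score, lse, scale=LOG2E):
    """Model of the guard: an LSE of -inf is replaced by 0 before the subtraction."""
    max_scaled = 0.0 if lse == -math.inf else lse * LOG2E
    return 2.0 ** (score * scale - max_scaled)

p_masked = guarded_prob(-math.inf, -math.inf)  # masked score, empty row
p_normal = guarded_prob(0.3, 1.0)              # ordinary case: exp(0.3 - 1.0)
```

Under these assumptions a row with no keys has only masked (-inf) scores, so the exponent is -inf rather than -inf + inf, and the recomputed probability is 0 instead of NaN.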

Also, in the Hopper kernel we write -inf for out-of-bounds positions:

if (row < seqlen_o) { mLSE(row) = -INFINITY; }
} else {
if (row < seqlen_o * qhead_per_khead) {
int m_idx, h_idx;
m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
// mLSE shape shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;

// Write 0 to gO and -inf to gLSE.

If that makes sense and the FA2 code is still relevant, I can add a test that covers the backward behaviour in this situation to make the PR mergeable.


3 participants