Enabled running Pallas Flash Attention on CPU. #922
base: main
@@ -28,7 +28,7 @@
 from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference
 from axlearn.common.test_utils import TestCase

-if jax.default_backend() != "gpu":
+if jax.default_backend() not in ("gpu", "cpu"):
     pytest.skip(reason="Incompatible hardware", allow_module_level=True)
@@ -69,6 +69,8 @@ def test_triton_fwd_only_against_ref(
 kv_seq_len = seq_len
 if kv_seq_len != seq_len and use_segment_ids:
     pytest.skip()
+if jax.default_backend() == "cpu" and kv_seq_len > 128:
+    pytest.skip(reason="CI got OOM.")
 k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
 q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
 k = jax.random.normal(k2, (batch_size, kv_seq_len, num_heads, per_head_dim), dtype=input_dtype)
@@ -101,6 +103,7 @@ def test_triton_fwd_only_against_ref(
 causal=causal,
 softmax_scale=softmax_scale,
 dropout_rate=dropout_rate,
+interpret=(jax.default_backend() == "cpu"),
 )
 o_ref = mha_reference(
     q,
@@ -152,6 +155,8 @@ def test_decode_against_ref(
 kv_head_factor: int,
 window_len: int,
 ):
+if jax.default_backend() == "cpu" and seq_len > 1024:
+    pytest.skip(reason="Too slow on CPU.")
 self.assertEqual(num_heads % kv_head_factor, 0)
 assert num_heads % kv_head_factor == 0
 k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
@@ -180,7 +185,14 @@ def test_decode_against_ref(
 if window_len > 0:
     mask_fn = sliding_window_causal_mask(window_len)
 o = flash_decoding(
-    q, k, v, bias=bias, softmax_scale=softmax_scale, kv_seq_len=seq_len, mask_fn=mask_fn
+    q,
+    k,
+    v,
+    bias=bias,
+    softmax_scale=softmax_scale,
+    kv_seq_len=seq_len,
+    mask_fn=mask_fn,
+    interpret=(jax.default_backend() == "cpu"),
 )
 if bias is not None:
     bias = bias[:, :, :, :seq_len]
@@ -269,6 +281,7 @@ def test_triton_against_xla_ref(
 block_q=block_size,
 block_k=block_size,
 dropout_rate=dropout_rate,
+interpret=(jax.default_backend() == "cpu"),
 )
 jax_out = call_flash(
     q,
@@ -346,6 +359,9 @@ def test_cudnn_against_triton_ref(
 causal: bool,
 dtype: jnp.dtype,
 ):
+if jax.default_backend() == "cpu":
+    pytest.skip(reason="cudnn function needs GPU.")
+
 k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
 q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype)
 k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype)
@@ -357,7 +373,15 @@ def test_cudnn_against_triton_ref(
 jax_out = cudnn_dot_product_attention(
     q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
 )
-jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
+jax_ref_out = flash_attention(
+    q,
+    k,
+    v,
+    bias=None,
+    causal=causal,
+    softmax_scale=softmax_scale,
+    interpret=(jax.default_backend() == "cpu"),
+)
 if dtype == jnp.bfloat16:
     # We relax the atol to support bf16 in the unit test.
     chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.02, rtol=1e-5)
@@ -372,7 +396,15 @@ def fn(q, k, v):
 ).sum()

 def ref_fn(q, k, v):
-    return flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale).sum()
+    return flash_attention(
+        q,
+        k,
+        v,
+        bias=None,
+        causal=causal,
+        softmax_scale=softmax_scale,
+        interpret=(jax.default_backend() == "cpu"),
+    ).sum()

 # Compare gradients.
 jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v)
@@ -414,6 +446,8 @@ def test_cudnn_dropout_against_xla_dropout(
 by setting V to the identity matrix. However, this only works when seq_len == per_head_dim,
 i.e. when the shape of output is the same as the shape of the dropout mask.
 """
+if jax.default_backend() == "cpu":
+    pytest.skip(reason="cudnn function needs GPU.")
 qkv_shape = (batch_size, seq_len, num_heads, per_head_dim)
 softmax_scale = 1.0
 cudnn_attn = functools.partial(

Review comment on lines +449 to +450: And here and elsewhere.

Reply: As mentioned above, keep using […]
@@ -481,6 +515,8 @@ def ref_fn(q, k, v):

 def test_cudnn_dropout_determinism():
     """Tests that cuDNN dropout produces identical outputs across runs."""
+    if jax.default_backend() == "cpu":
+        pytest.skip(reason="cudnn function needs GPU.")
     k1, k2, k3 = jax.random.split(jax.random.PRNGKey(3), 3)
     q = jax.random.normal(k1, (1, 128, 2, 64), dtype=jnp.float16)
     k = jax.random.normal(k2, (1, 128, 2, 64), dtype=jnp.float16)
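For context on the mechanism the hunks above rely on: Pallas kernels take an interpret flag, and interpret=True runs the kernel through the Pallas interpreter, which executes on CPU instead of requiring a Triton or Mosaic lowering. The snippet below is a minimal standalone sketch of that flag, not code from this PR; the kernel and names are invented for illustration.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# A trivial Pallas kernel: read the input block, add one, write the output block.
def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8.0)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # run via the interpreter so this also works on a CPU-only machine
)(x)
assert jnp.allclose(y, x + 1.0)

This is the same flag the tests thread through as interpret=(jax.default_backend() == "cpu"), so the Triton-path kernels can still be exercised when CI runs on CPU.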
(The hunks below are from a second changed file in this PR.)
@@ -17,7 +17,6 @@
 MaskFnAttentionBias,
 SegmentIdAttentionBias,
 TensorAttentionBias,
-ZeroAttentionBias,
 split,
 )
 from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention
@@ -203,6 +202,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
 mask_fn=mask_fn,
 kv_seq_len=kv_seq_len,
 softmax_scale=softmax_scale,
+interpret=(backend == "cpu"),
 )

 key = _repeat_kv_heads(query.shape[2], key)
@@ -237,6 +237,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
 softmax_scale=softmax_scale,
 causal=causal.has_value(),
 dropout_rate=dropout_rate,
+interpret=(backend == "cpu"),
 )
 else:
     explicit_bias += segment_ids
@@ -268,20 +269,18 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
 value,
 bias=explicit_bias.value(),
 segment_ids=get_segment_ids(segment_ids),
-# The `from_sequence()` function guarantees that if there is only one
-# mask, it is returned without modification.
-# This allows the `causal` path in `_legacy_tpu_flash_attention()` to work.
-mask=mask if not isinstance(mask, ZeroAttentionBias) else None,
+mask=mask,
 softmax_scale=softmax_scale,
 block_size=block_size,
+interpret=(backend == "cpu"),
Reviewer: Given how often we do this across locations, I wonder if we can do the following: […] WDYT?

Author: Thank you for your suggestion. Additionally, this line was added so contributors can easily debug the Pallas kernel on the CPU. For instance, changing the […] would allow debugging in […].
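The reviewer's suggested snippet and the author's exact debugging recipe are not preserved in this excerpt. Purely as an illustration of the kind of refactor being discussed, a small helper could centralize the backend check so call sites stop repeating interpret=(backend == "cpu"); the name default_interpret below is hypothetical and not defined in axlearn.

from typing import Optional

import jax

def default_interpret(backend: Optional[str] = None) -> bool:
    # Run Pallas kernels in interpret mode only when the effective backend is CPU.
    backend = backend or jax.default_backend()
    return backend == "cpu"

Call sites could then pass interpret=default_interpret(backend) instead of repeating the string comparison at every flash-attention call.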
 )

 elif backend in ("cpu", "xla"):
     key = _repeat_kv_heads(query.shape[2], key)
     value = _repeat_kv_heads(query.shape[2], value)
     if backend == "cpu":
-        logging.warning("Flash attention CPU backend is for testing only.")
-        logging.warning("Flash attention falling back using plain MHA implementation")
+        logging.info("Flash attention CPU backend is for testing only.")
+        logging.info("Flash attention falling back using plain MHA implementation")

 # `causal` is supported.
 # `segment_ids` is supported.
Reviewer: Likewise, let's avoid assuming that the backend is either gpu or cpu in multiple places.

Author: I'll leave this code as-is, as you asked. In addition, at the beginning of the file only "gpu" and "cpu" are allowed, so == "cpu" is equivalent to != "gpu" in this code.

Reviewer: I know you are making this assumption, but such a dependency is fragile: what if we extend the supported backends in the future? In this case, requiring the backend to be "gpu" is both more robust and readable. What's the downside?
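To make the robustness point concrete, here is a small illustrative snippet (not code from the PR); the variable name is invented, and the second form is just one way to express the reviewer's preference for requiring "gpu" explicitly.

import jax

backend = jax.default_backend()

# Check used in the PR's test skips: only CPU is treated as lacking cuDNN.
# A hypothetical future backend (say "tpu") would not be skipped, even though
# the cuDNN kernels still need a GPU.
skip_cudnn_test = backend == "cpu"

# Explicit form: require GPU, so every non-GPU backend is skipped.
skip_cudnn_test = backend != "gpu"

The two conditions agree only while the module-level guard restricts the backend to "gpu" or "cpu"; allowing any other backend breaks the equivalence, which is the reviewer's concern.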