Enabled running Pallas Flash Attention on CPU. #922

Open · wants to merge 1 commit into base: main
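This PR makes the Pallas-based flash-attention kernels runnable on CPU by passing interpret=(jax.default_backend() == "cpu") into the kernel wrappers and by relaxing the test skips to also allow the "cpu" backend. Pallas's interpreter mode evaluates a kernel with plain JAX semantics instead of compiling it for GPU/TPU, which is what makes CPU execution and debugging possible. Below is a minimal sketch of that mechanism; the toy kernel is illustrative and not part of this PR.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _double_kernel(x_ref, o_ref):
    # Toy Pallas kernel: write 2 * x into the output block.
    o_ref[...] = x_ref[...] * 2.0


def double(x: jax.Array) -> jax.Array:
    # interpret=True runs the kernel in Pallas interpreter mode, so it also works
    # on CPU (no Triton/Mosaic compilation). This mirrors the
    # interpret=(jax.default_backend() == "cpu") pattern used throughout the diff.
    return pl.pallas_call(
        _double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=(jax.default_backend() == "cpu"),
    )(x)


print(double(jnp.arange(8.0)))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
```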
44 changes: 40 additions & 4 deletions axlearn/common/flash_attention/gpu_attention_test.py
@@ -28,7 +28,7 @@
from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference
from axlearn.common.test_utils import TestCase

if jax.default_backend() != "gpu":
if jax.default_backend() not in ("gpu", "cpu"):
pytest.skip(reason="Incompatible hardware", allow_module_level=True)


@@ -69,6 +69,8 @@ def test_triton_fwd_only_against_ref(
kv_seq_len = seq_len
if kv_seq_len != seq_len and use_segment_ids:
pytest.skip()
if jax.default_backend() == "cpu" and kv_seq_len > 128:
pytest.skip(reason="CI got OOM.")
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(k2, (batch_size, kv_seq_len, num_heads, per_head_dim), dtype=input_dtype)
@@ -101,6 +103,7 @@ def test_triton_fwd_only_against_ref(
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=dropout_rate,
interpret=(jax.default_backend() == "cpu"),
)
o_ref = mha_reference(
q,
@@ -152,6 +155,8 @@ def test_decode_against_ref(
kv_head_factor: int,
window_len: int,
):
if jax.default_backend() == "cpu" and seq_len > 1024:
pytest.skip(reason="Too slow on CPU.")
self.assertEqual(num_heads % kv_head_factor, 0)
assert num_heads % kv_head_factor == 0
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
@@ -180,7 +185,14 @@ def test_decode_against_ref(
if window_len > 0:
mask_fn = sliding_window_causal_mask(window_len)
o = flash_decoding(
q, k, v, bias=bias, softmax_scale=softmax_scale, kv_seq_len=seq_len, mask_fn=mask_fn
q,
k,
v,
bias=bias,
softmax_scale=softmax_scale,
kv_seq_len=seq_len,
mask_fn=mask_fn,
interpret=(jax.default_backend() == "cpu"),
)
if bias is not None:
bias = bias[:, :, :, :seq_len]
@@ -269,6 +281,7 @@ def test_triton_against_xla_ref(
block_q=block_size,
block_k=block_size,
dropout_rate=dropout_rate,
interpret=(jax.default_backend() == "cpu"),
)
jax_out = call_flash(
q,
@@ -346,6 +359,9 @@ def test_cudnn_against_triton_ref(
causal: bool,
dtype: jnp.dtype,
):
if jax.default_backend() == "cpu":
Contributor:

Likewise, let's avoid assuming that the backend is either gpu or cpu in multiple places.

Suggested change:
    - if jax.default_backend() == "cpu":
    + if jax.default_backend() != "gpu":

Contributor Author:

I'll leave this code as-is, as you asked earlier:

> Nit: can we check it against "cpu" directly instead of != "gpu"?

In addition, at the beginning of the file only "gpu" and "cpu" are allowed:

    if jax.default_backend() not in ("gpu", "cpu"):

So == "cpu" is the same as != "gpu" in this code.

Contributor:

> In addition, at the beginning of the file only "gpu" and "cpu" are allowed. So == "cpu" is the same as != "gpu" in this code.

I know you are making this assumption, but such a dependency is fragile: what if we extend the supported backends in the future? In this case, requiring the backend to be "gpu" is both more robust and readable. What's the downside?

pytest.skip(reason="cudnn function needs GPU.")

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype)
@@ -357,7 +373,15 @@ def test_cudnn_against_triton_ref(
jax_out = cudnn_dot_product_attention(
q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
)
jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
jax_ref_out = flash_attention(
q,
k,
v,
bias=None,
causal=causal,
softmax_scale=softmax_scale,
interpret=(jax.default_backend() == "cpu"),
)
if dtype == jnp.bfloat16:
# We relax the atol to support bf16 in the unit test.
chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.02, rtol=1e-5)
@@ -372,7 +396,15 @@ def fn(q, k, v):
).sum()

def ref_fn(q, k, v):
return flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale).sum()
return flash_attention(
q,
k,
v,
bias=None,
causal=causal,
softmax_scale=softmax_scale,
interpret=(jax.default_backend() == "cpu"),
).sum()

# Compare gradients.
jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v)
@@ -414,6 +446,8 @@ def test_cudnn_dropout_against_xla_dropout(
by setting V to the identity matrix. However, this only works when seq_len == per_head_dim,
i.e. when the shape of output is the same as the shape of the dropout mask.
"""
if jax.default_backend() == "cpu":
pytest.skip(reason="cudnn function needs GPU.")
Comment on lines +449 to +450

Contributor:

And here and elsewhere.

Contributor Author:

As mentioned above, I'll keep using jax.default_backend() == "cpu".

qkv_shape = (batch_size, seq_len, num_heads, per_head_dim)
softmax_scale = 1.0
cudnn_attn = functools.partial(
@@ -481,6 +515,8 @@ def ref_fn(q, k, v):

def test_cudnn_dropout_determinism():
"""Tests that cuDNN dropout produces identical outputs across runs."""
if jax.default_backend() == "cpu":
pytest.skip(reason="cudnn function needs GPU.")
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(3), 3)
q = jax.random.normal(k1, (1, 128, 2, 64), dtype=jnp.float16)
k = jax.random.normal(k2, (1, 128, 2, 64), dtype=jnp.float16)
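The review thread above debates whether the cuDNN-only tests should skip on == "cpu" or on != "gpu". One hypothetical way to centralize that decision instead of repeating the backend check in each test (not part of this PR; the helper name is made up) would be:

```python
import jax
import pytest


def skip_unless_backend(*backends: str, reason: str = "Incompatible hardware.") -> None:
    """Hypothetical helper: skip the current test unless the default backend is allowed."""
    if jax.default_backend() not in backends:
        pytest.skip(reason=reason)


# Example usage inside a cuDNN-only test:
#   skip_unless_backend("gpu", reason="cudnn function needs GPU.")
```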
9 changes: 8 additions & 1 deletion axlearn/common/flash_attention/layer_test.py
@@ -20,7 +20,7 @@
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized
from absl.testing import absltest, parameterized
from jax.experimental import mesh_utils
from jax.sharding import Mesh

@@ -98,6 +98,7 @@ def _prepare_layers(
sliding_window_size,
inference=False,
set_layer_bias_recursively=False,
tpu_block_size=512,
dropout_rate=0.0,
):
hidden_dim = num_heads * per_head_dim
@@ -119,6 +120,7 @@
.set(
mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names),
output_dim_to_partition_spec=default_output_dim_to_partition_spec(mesh_axis_names),
tpu_block_size=tpu_block_size,
)
)
if inference:
@@ -432,6 +434,7 @@ def test_forward(
causal=causal,
sliding_window_size=sliding_window_size,
dropout_rate=dropout_rate,
tpu_block_size=128,
)

query_len = int(query_len_multiplier * seq_len)
@@ -816,3 +819,7 @@ def test_extend_step(
atol=2e-2,
)
jax.extend.backend.clear_backends()


if __name__ == "__main__":
absltest.main()
37 changes: 19 additions & 18 deletions axlearn/common/flash_attention/tpu_attention.py
@@ -2,7 +2,7 @@

"""Wrappers for FlashAttention on TPU in JAX with logit bias support."""
import functools
from typing import Optional, Union
from typing import Optional

import jax
import jax.numpy as jnp
@@ -40,6 +40,8 @@
)
from axlearn.common.utils import Tensor

MaskFnOrZero = MaskFnAttentionBias | ZeroAttentionBias


def tpu_flash_attention(
query: Tensor, # [batch_size, target_len, num_heads, head_dim]
@@ -48,7 +50,7 @@ def tpu_flash_attention(
bias: Tensor = None, # [batch_size, num_heads, target_len, source_len]
segment_ids: Tensor = None, # [batch_size, target_len]
*,
mask: Optional[MaskFnAttentionBias] = None,
mask: MaskFnOrZero,
softmax_scale: float = 1.0,
block_size: int = 128,
interpret: bool = False,
@@ -113,16 +115,17 @@ def tpu_flash_attention(
f"Source seq len {key.shape[1]} must be divisible by block size {block_size}."
)

mask: Union[MaskFnAttentionBias | ZeroAttentionBias] = as_attention_bias(mask)
mask: MaskFnOrZero = as_attention_bias(mask)

# Switch num_heads and seq_len axes.
query = jnp.einsum("btnh->bnth", query)
key = jnp.einsum("bsnh->bnsh", key)
value = jnp.einsum("bsnh->bnsh", value)
try:
check_tpu_splash_attention(
query=query,
key=key,
target_len=query.shape[2],
source_len=key.shape[2],
head_dim=query.shape[3],
mask=mask,
has_segment_ids=(segment_ids is not None),
has_bias=(bias is not None),
@@ -199,7 +202,7 @@ def _legacy_tpu_flash_attention(
bias: Tensor = None, # [batch_size, num_heads, target_len, source_len]
segment_ids: Tensor = None, # [batch_size, target_len]
*,
mask: MaskFnAttentionBias,
mask: MaskFnOrZero,
block_sizes: Optional[LegacyBlockSizes] = None,
interpret: bool = False,
) -> Tensor: # [batch_size, num_heads, target_len, head_dim].
@@ -253,17 +256,19 @@ class SplashAttentionUnsupportedError(NotImplementedError):

def check_tpu_splash_attention(
*,
query: Tensor, # [batch_size, num_heads, source_len, head_dim]
key: Tensor, # [batch_size, num_heads, target_len, head_dim]
mask: Union[MaskFnAttentionBias | ZeroAttentionBias],
target_len: int,
source_len: int,
head_dim: int,
mask: MaskFnOrZero,
has_segment_ids: bool = False,
has_bias: bool = False,
):
"""Checks if splash attention is supported on TPU for the given arguments.

Args:
query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim].
key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim].
target_len: The length of the target sequence.
source_len: The length of the source sequence.
head_dim: The dimension of each head.
mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
has_segment_ids: Whether segment_ids is None or not.
has_bias: Whether attention involves a bias.
@@ -272,12 +277,8 @@ def check_tpu_splash_attention(
SplashAttentionUnsupportedError: If splash attention is not supported for the given
arguments.
"""
target_len = query.shape[2]
source_len = key.shape[2]
head_dim = query.shape[3]

if has_bias:
return False # SplashAttention does not support specifying a bias.
raise SplashAttentionUnsupportedError("SplashAttention does not support specifying a bias.")
with jax.ensure_compile_time_eval():
if jnp.any(
jnp.asarray([target_len, source_len, head_dim]) % splash_attention_kernel.NUM_LANES != 0
@@ -305,7 +306,7 @@


def _to_splash_mask(
mask: Union[MaskFnAttentionBias | ZeroAttentionBias],
mask: MaskFnOrZero,
*,
mask_shape: tuple[int, int],
q_seq_shards: int = 1,
@@ -344,7 +345,7 @@ def _tpu_splash_attention(
key: Tensor, # [batch_size, num_heads, source_len, head_dim]
value: Tensor, # [batch_size, num_heads, source_len, head_dim]
*,
mask: Union[MaskFnAttentionBias | ZeroAttentionBias],
mask: MaskFnOrZero,
segment_ids: Optional[Tensor] = None, # [batch_size, target_len]
block_sizes: Optional[splash_attention_kernel.BlockSizes] = None,
interpret: bool = False,
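The hunks above change check_tpu_splash_attention() to take target_len, source_len, and head_dim directly instead of the query and key tensors, and to raise SplashAttentionUnsupportedError rather than returning False when a bias is provided. A small sketch of how a caller might use the new signature; the wrapper function below is hypothetical, and only the imported names come from the diff.

```python
from axlearn.common.flash_attention.tpu_attention import (
    SplashAttentionUnsupportedError,
    check_tpu_splash_attention,
)
from axlearn.common.utils import Tensor


def splash_attention_supported(query: Tensor, key: Tensor, *, mask, segment_ids=None, bias=None) -> bool:
    """Hypothetical helper: True iff splash attention supports these arguments."""
    try:
        # query/key are [batch_size, num_heads, seq_len, head_dim] at this point.
        check_tpu_splash_attention(
            target_len=query.shape[2],
            source_len=key.shape[2],
            head_dim=query.shape[3],
            mask=mask,
            has_segment_ids=segment_ids is not None,
            has_bias=bias is not None,
        )
    except SplashAttentionUnsupportedError:
        return False
    return True
```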
19 changes: 8 additions & 11 deletions axlearn/common/flash_attention/tpu_attention_test.py
@@ -5,7 +5,6 @@

import unittest

import chex
import jax
import jax.numpy as jnp
import numpy as np
@@ -29,14 +28,10 @@
from axlearn.common.test_utils import TestCase, is_supported_mesh_shape
from axlearn.common.utils import Tensor

# Comment out to test on CPU manually. Technically, this test runs on the CPU, albeit very slowly.
if jax.default_backend() != "tpu":
pytest.skip(reason="Incompatible hardware", allow_module_level=True)


def setUpModule():
# If on CPU, emulate 4 devices.
chex.set_n_cpu_devices(4)
if jax.default_backend() not in ("tpu", "cpu"):
pytest.skip(reason="Incompatible hardware", allow_module_level=True)


def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor:
@@ -102,7 +97,6 @@ def test_to_splash_mask(self, mask, expected):
sliding_window_size=[1024],
num_heads=[4],
per_head_dim=[256],
mesh=[(4, 1)],
mesh_axis_names=[("data", "model")],
)
def test_forward(
@@ -113,11 +107,12 @@ def test_forward(
per_head_dim,
mask_fn,
sliding_window_size,
mesh,
mesh_axis_names,
):
if not is_supported_mesh_shape(mesh):
pytest.skip(reason=f"Unsupported mesh {mesh}.")
if jax.default_backend() == "cpu" and seq_len > 1024:
pytest.skip(reason="Too slow on CPU.")
mesh = (1, 1) if jax.default_backend() == "cpu" else (4, 1)
self.assertTrue(is_supported_mesh_shape(mesh))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(
@@ -254,6 +249,8 @@ def ref_fn(q, k, v, bias, ids):

if mask is not None:
mask = MaskFnAttentionBias(mask, shape=(query_len, kv_len))
else:
mask = ZeroAttentionBias()

def fn(q, k, v, bias, ids):
record_legacy_call = unittest.mock.patch.object(
13 changes: 6 additions & 7 deletions axlearn/common/flash_attention/utils.py
@@ -17,7 +17,6 @@
MaskFnAttentionBias,
SegmentIdAttentionBias,
TensorAttentionBias,
ZeroAttentionBias,
split,
)
from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention
@@ -203,6 +202,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
mask_fn=mask_fn,
kv_seq_len=kv_seq_len,
softmax_scale=softmax_scale,
interpret=(backend == "cpu"),
)

key = _repeat_kv_heads(query.shape[2], key)
@@ -237,6 +237,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
softmax_scale=softmax_scale,
causal=causal.has_value(),
dropout_rate=dropout_rate,
interpret=(backend == "cpu"),
)
else:
explicit_bias += segment_ids
@@ -268,20 +269,18 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
value,
bias=explicit_bias.value(),
segment_ids=get_segment_ids(segment_ids),
# The `from_sequence()` function guarantees that if there is only one
# mask, it is returned without modification.
# This allows the `causal` path in `_legacy_tpu_flash_attention()` to work.
mask=mask if not isinstance(mask, ZeroAttentionBias) else None,
mask=mask,
softmax_scale=softmax_scale,
block_size=block_size,
interpret=(backend == "cpu"),
Contributor:

Given how often we do this across locations, I wonder if we can do the following:

  • Make interpret default to None (instead of False);
  • If it's None, assume interpret=True when the backend is "cpu".

WDYT?

Contributor Author:

Thank you for your suggestion. interpret=True applies only to the Pallas kernels, so having an interpret variable in the flash layer is not aligned with the appropriate level of abstraction: neither the JAX fallback nor the cudnn code path needs this variable.

Additionally, this line was added so contributors can easily debug the Pallas kernel on the CPU. For instance, changing the if statement to

    elif backend in ("cpu", "tpu"):

would allow debugging in layer_test.py.

)

elif backend in ("cpu", "xla"):
key = _repeat_kv_heads(query.shape[2], key)
value = _repeat_kv_heads(query.shape[2], value)
if backend == "cpu":
logging.warning("Flash attention CPU backend is for testing only.")
logging.warning("Flash attention falling back using plain MHA implementation")
logging.info("Flash attention CPU backend is for testing only.")
logging.info("Flash attention falling back using plain MHA implementation")

# `causal` is supported.
# `segment_ids` is supported.
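On the review suggestion in the utils.py thread above (default interpret to None and infer it from the backend), a minimal sketch of that pattern, using a hypothetical resolve_interpret() helper:

```python
from typing import Optional

import jax


def resolve_interpret(interpret: Optional[bool] = None) -> bool:
    """Hypothetical helper: run Pallas kernels in interpret mode by default on CPU."""
    if interpret is None:
        return jax.default_backend() == "cpu"
    return interpret


# A kernel wrapper could then accept interpret=None and call
# resolve_interpret(interpret) before invoking the Pallas kernel, instead of
# each call site computing (backend == "cpu") itself.
```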