Enabled running Pallas Flash Attention on CPU. #922

Open · wants to merge 1 commit into base: main

Conversation

ds-hwang
Contributor

Enabled running Pallas Flash Attention on CPU.

Pallas supports CPU simulation (interpret=True), so we can use the same
TPU Pallas kernel on CPU, which makes debugging the kernel easier.

This change lets the following unittests run on CPU as if they were on TPU,
enabling easier testing and debugging:

  • axlearn/common/flash_attention/tpu_attention_test.py

Similarly, gpu_attention_test.py can also be run on CPU as if it were on GPU:

  • axlearn/common/flash_attention/gpu_attention_test.py

CI now covers those tests on CPU as well.
On an M3 Max MacBook Pro, the test results and run times are as follows:

  • axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20)
  • axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s
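For context, a minimal self-contained sketch (not the axlearn kernel itself) of how interpret=True lets the same Pallas kernel body run on CPU; the add_one example is purely illustrative:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # Element-wise kernel body; the same code compiles to a real kernel on TPU.
    o_ref[...] = x_ref[...] + 1.0

def add_one(x, *, interpret: bool = False):
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,  # True: simulate the kernel in Python on any backend
    )(x)

x = jnp.arange(8, dtype=jnp.float32)
# On a CPU-only machine, interpret=True lets the Pallas kernel run (slowly) for debugging.
print(add_one(x, interpret=(jax.default_backend() == "cpu")))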

@ds-hwang ds-hwang requested review from ruomingp, markblee and a team as code owners January 14, 2025 04:59
@ds-hwang
Contributor Author

@ruomingp Could you take a look? From 975

Contributor

@ruomingp ruomingp left a comment


A few thoughts missed in earlier reviews...

@@ -152,6 +153,8 @@ def test_decode_against_ref(
kv_head_factor: int,
window_len: int,
):
if jax.default_backend() != "gpu" and seq_len > 1024:
Contributor


Nit: can we check it against "cpu" directly instead of != "gpu"?

Contributor Author


Yes, done.

@@ -346,6 +357,9 @@ def test_cudnn_against_triton_ref(
causal: bool,
dtype: jnp.dtype,
):
if jax.default_backend() == "cpu":
Contributor


Likewise, let's avoid assuming that the backend is either gpu or cpu in multiple places.

Suggested change
if jax.default_backend() == "cpu":
if jax.default_backend() != "gpu":

Contributor Author


I'll leave this code as-is, as you asked:

Nit: can we check it against "cpu" directly instead of != "gpu"?

In addition, at the beginning of the file, only "gpu" and "cpu" are allowed, so == "cpu" is equivalent to != "gpu" in this code:

if jax.default_backend() not in ("gpu", "cpu"):

Contributor


In addition, at the beginning of the file, only "gpu" and "cpu" are allowed, so == "cpu" is equivalent to != "gpu" in this code.

I know you are making this assumption, but such a dependency is fragile: what if we extend the supported backends in the future?

In this case, requiring the backend to be "gpu" is both more robust and more readable. What's the downside?

Comment on lines +447 to +450
if jax.default_backend() == "cpu":
pytest.skip(reason="cudnn function needs GPU.")
Contributor


And here and elsewhere.

Contributor Author


As mentioned above, I'll keep using jax.default_backend() == "cpu".

Comment on lines 100 to 95
seq_len=[1024, 32768],
seq_len=[1024],
Contributor


Since the sliding window size is 1024, it will be useful to keep a test case for seq_len > 1024. We can enable the test only on TPU if it's too slow on CPU. We can also use a seq_len such as 2048 for cpu if it's fast enough.
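A rough sketch of the kind of guard being suggested; the seq_len values and the 2048 cutoff below are illustrative, not the actual test parameters:

import jax
import pytest

@pytest.mark.parametrize("seq_len", [1024, 2048, 32768])
def test_sliding_window(seq_len):
    # Keep at least one case with seq_len > the 1024 sliding window on every backend,
    # and run the very long case only where it is fast enough (e.g. TPU).
    if jax.default_backend() == "cpu" and seq_len > 2048:
        pytest.skip(reason="Long sequences are too slow in interpret mode on CPU.")
    ...  # actual attention test body goes here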

Contributor Author

@ds-hwang ds-hwang Jan 14, 2025


Done. I changed it back to restore the code from the first PR.

We had this thread in 975

@ruomingp Do we need to support seq_len up to 1024? If the block size is 128, supporting <= 256 should be enough?

@ds-hwang Agreed. I removed the 32k test with this if-statement.

softmax_scale=softmax_scale,
block_size=block_size,
interpret=(backend == "cpu"),
Contributor


Given how often we do this across locations, I wonder if we can do the following:

  • Make interpret default to None (instead of False);
  • If it's None, assume interpret=True if the backend is "cpu";

WDYT?
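Roughly what the suggestion would look like, sketched with a hypothetical flash_attention wrapper (not the actual axlearn signature):

from typing import Optional

import jax

def flash_attention(q, k, v, *, interpret: Optional[bool] = None, **kwargs):
    # If the caller did not specify, interpret only when running on CPU,
    # so call sites no longer need to pass interpret=(backend == "cpu") themselves.
    if interpret is None:
        interpret = jax.default_backend() == "cpu"
    ...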

Contributor Author


Thank you for your suggestion. interpret=True applies only to the Pallas kernel, so having an interpret variable in the flash layer is not the right level of abstraction: neither the JAX fallback nor the cudnn code path needs this variable.

Additionally, this line was added so contributors can easily debug the Pallas kernel on the CPU. For instance, changing the if statement to:

elif backend in ("cpu", "tpu"):

would allow debugging in layer_test.py.
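To make the abstraction point concrete, here is a self-contained toy sketch (hypothetical names, not the axlearn layer code): only the Pallas branch consumes interpret, while the plain-JAX fallback never sees it.

import functools

import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    # Plain-JAX fallback: it has no notion of `interpret`.
    logits = jnp.einsum("...qd,...kd->...qk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("...qk,...kd->...qd", jax.nn.softmax(logits, axis=-1), v)

def pallas_attention(q, k, v, *, interpret: bool = False):
    # Stand-in for the Pallas kernel; only this path accepts `interpret`.
    del interpret  # in the real kernel this flag would select CPU simulation
    return reference_attention(q, k, v)

def get_attention_fn(backend: str):
    # On GPU the real code would dispatch to the cudnn path instead.
    if backend in ("cpu", "tpu"):
        # Same kernel on both; on CPU it runs in interpret mode for debugging.
        return functools.partial(pallas_attention, interpret=(backend == "cpu"))
    return reference_attention

q = k = v = jnp.ones((2, 4, 8))
print(get_attention_fn(jax.default_backend())(q, k, v).shape)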

Contributor Author

@ds-hwang ds-hwang left a comment


Thank you for the review. I responded to all comments. Could you check again?


@ds-hwang ds-hwang requested a review from ruomingp January 14, 2025 16:25