Optimize TPU Flash Attention (20x XLA compilation speed-up on 32k long context) (#908)

Use a lazy Splash Attention mask instead of a `jnp` mask, which is O(T^2).

The host memory usage for the `jnp` mask is O(T^2). Currently, a `jnp` mask is
created and then wrapped with `NumpyMask` for use in Splash Attention,
resulting in O(T^2) temporary memory usage (although XLA appears to avoid
actually allocating it in HBM). This PR proposes using `CausalMask` or
`LocalMask` instead, allowing each Splash Attention block to lazily create and
use the required mask inside the Pallas kernel (see the sketch below).

The runtime performance of Splash Attention remains nearly the same. However,
the JIT compilation time of the function using Splash Attention improves
significantly. It appears that allocating the O(T^2) mask in HBM and then
wrapping it with `NumpyMask` consumes a large share of XLA compilation time.
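
To make the difference concrete, here is a minimal sketch (not the axlearn code itself) of the two ways of constructing the splash-attention mask, assuming the mask classes in JAX's `jax.experimental.pallas.ops.tpu.splash_attention` package; the sequence length and window size are illustrative.

```python
import numpy as np
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_mask as mask_lib,
)

seq_len = 8192  # illustrative

# Before: materialize a dense O(T^2) boolean mask on the host and wrap it.
dense = np.tril(np.ones((seq_len, seq_len), dtype=np.bool_))  # O(T^2) memory
eager_mask = mask_lib.NumpyMask(dense)

# After (this PR): describe the mask lazily; each Pallas block derives only the
# sub-mask it needs, so no O(T^2) buffer is ever built.
lazy_causal = mask_lib.CausalMask((seq_len, seq_len))
lazy_local = mask_lib.LocalMask(
    (seq_len, seq_len),
    window_size=(1024, 0),  # assumed (left, right) semantics: 1k left context, none to the right
    offset=0,
)
```

Either lazy mask is then combined per head (e.g. via `MultiHeadMask`) and handed to the splash-attention kernel in place of the `NumpyMask`.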

* Benchmark: on TPU v5p, parameters reported as model_dim/heads/kv_heads/seq_len, run with tools/attention_benchmark.py. A rough sketch of what the two timing modes measure follows the tables.

1) Time including XLA compilation

NumpyMask (as-is)
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096        3609 ms         1645 ms            1
FlashAttentionBenchmark/4096/16/2/8192        7828 ms         5696 ms            1
FlashAttentionBenchmark/4096/16/2/32768      94368 ms        91442 ms            1

CausalMask (Proposed PR): significant XLA compilation speed-up.
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096        22.9 ms         1.60 ms          127
FlashAttentionBenchmark/4096/16/2/8192        40.8 ms         2.12 ms           88
FlashAttentionBenchmark/4096/16/2/32768       7641 ms         5458 ms            1

2) Time excluding XLA compilation (pure jit computation)

NumpyMask (as-is)
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096        9.97 ms        0.757 ms          918
FlashAttentionBenchmark/4096/16/2/8192        19.4 ms        0.934 ms          832
FlashAttentionBenchmark/4096/16/2/32768        116 ms         1.03 ms          100

CausalMask (Proposed PR): slight step time speed-up.
----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096        9.82 ms        0.690 ms          964
FlashAttentionBenchmark/4096/16/2/8192        19.2 ms        0.822 ms          837
FlashAttentionBenchmark/4096/16/2/32768        116 ms        0.997 ms          100
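
For reference, a rough timing sketch of what the two measurements above capture (this is not `tools/attention_benchmark.py`; `fn` and `args` stand for any jitted attention function and its inputs):

```python
import time
import jax

def time_with_compilation(fn, *args):
    """Times the first call: tracing + XLA compilation + one step."""
    jitted = jax.jit(fn)
    start = time.perf_counter()
    jax.block_until_ready(jitted(*args))
    return time.perf_counter() - start

def time_compiled_step(fn, *args):
    """Times a steady-state step after a warm-up call has compiled the function."""
    jitted = jax.jit(fn)
    jax.block_until_ready(jitted(*args))  # warm-up: trace and compile once
    start = time.perf_counter()
    jax.block_until_ready(jitted(*args))  # compiled step only
    return time.perf_counter() - start
```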

In addition, tpu_attention_benchmark.py is changed to use an 8k sequence length
(instead of 2k) with a 1k sliding window; a short usage sketch of the
sliding-window mask follows.
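
For illustration, a small sketch of the 1k sliding-window mask used by the benchmark, via `sliding_window_causal_mask` from `axlearn/common/attention_bias.py` (positions and the expected output are illustrative):

```python
import jax.numpy as jnp
from axlearn.common.attention_bias import sliding_window_causal_mask

mask_fn = sliding_window_causal_mask(sliding_window_size=1024)

# A query at position 4096 may attend to keys within the last 1k positions,
# and never to future positions.
q = jnp.array([4096])
k = jnp.array([2048, 3500, 4096, 5000])
print(mask_fn(q[:, None], k[None, :]))  # expected: [[False  True  True False]]
```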

NumpyMask (as-is)
Benchmarking attention representative of 1.2b model layer on TPU v5.
ref_fwd:0.2288s, flash_fwd:0.0014s
ref_bwd:0.0218s, flash_bwd:0.0058s
Benchmarking attention representative of 12.6b model layer on TPU v5.
ref_fwd:0.5700s, flash_fwd:0.0032s
ref_bwd:0.0527s, flash_bwd:0.0149s
Benchmarking attention representative of 29.6b model layer on TPU v5.
ref_fwd:0.7958s, flash_fwd:0.0044s
ref_bwd:0.0730s, flash_bwd:0.0205s
Benchmarking attention representative of 65.2b model layer on TPU v5.
ref_fwd:1.0222s, flash_fwd:0.0055s
ref_bwd:0.0949s, flash_bwd:0.0262s
Benchmarking attention representative of 134b model layer on TPU v5.
ref_fwd:1.2486s, flash_fwd:0.0067s
ref_bwd:0.1161s, flash_bwd:0.0318s
Benchmarking attention representative of 261.7b model layer on TPU v5.
ref_fwd:1.5577s, flash_fwd:0.0072s
ref_bwd:0.1348s, flash_bwd:0.0375s

LocalMask (Proposed PR): slight fwd/bwd time speed-up.
Benchmarking attention representative of 1.2b model layer on TPU v5.
ref_fwd:0.2291s, flash_fwd:0.0014s
ref_bwd:0.0217s, flash_bwd:0.0058s
Benchmarking attention representative of 12.6b model layer on TPU v5.
ref_fwd:0.5699s, flash_fwd:0.0032s
ref_bwd:0.0524s, flash_bwd:0.0152s
Benchmarking attention representative of 29.6b model layer on TPU v5.
ref_fwd:0.7957s, flash_fwd:0.0043s
ref_bwd:0.0731s, flash_bwd:0.0204s
Benchmarking attention representative of 65.2b model layer on TPU v5.
ref_fwd:1.0225s, flash_fwd:0.0055s
ref_bwd:0.0948s, flash_bwd:0.0262s
Benchmarking attention representative of 134b model layer on TPU v5.
ref_fwd:1.2485s, flash_fwd:0.0067s
ref_bwd:0.1159s, flash_bwd:0.0313s
Benchmarking attention representative of 261.7b model layer on TPU v5.
ref_fwd:1.5577s, flash_fwd:0.0072s
ref_bwd:0.1349s, flash_bwd:0.0373s
ds-hwang authored Jan 7, 2025
1 parent d83b450 commit c40b39a
Showing 5 changed files with 257 additions and 122 deletions.
6 changes: 5 additions & 1 deletion axlearn/common/attention_bias.py
@@ -687,7 +687,11 @@ def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn:
     def mask(query_position: Tensor, key_position: Tensor):
         return query_position - key_position <= sliding_window_size
 
-    return and_masks(causal_mask, mask)
+    fun = and_masks(causal_mask, mask)
+    # Flash attention needs to recognize sliding window size in _to_splash_mask().
+    # pylint: disable-next=protected-access
+    fun._sliding_window_size = sliding_window_size
+    return fun
 
 
 def make_causal_biases(seq_len: int) -> Tensor:
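
For context, a hypothetical sketch of how `_to_splash_mask()` (not part of this hunk) could consume the `_sliding_window_size` annotation; everything except the JAX splash-attention mask classes is an assumption for illustration:

```python
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_mask as mask_lib,
)

def _to_splash_mask(mask_fn, *, seq_len: int) -> mask_lib.Mask:
    """Maps a mask function to a lazy splash-attention mask (sketch only)."""
    window = getattr(mask_fn, "_sliding_window_size", None)
    if window is not None:
        # Sliding-window causal attention: bounded left context, no right context
        # (assuming window_size=(left, right) semantics).
        return mask_lib.LocalMask((seq_len, seq_len), window_size=(window, 0), offset=0)
    # Otherwise fall back to a lazy causal mask instead of a dense NumpyMask.
    return mask_lib.CausalMask((seq_len, seq_len))
```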
