Commit
Optimize TPU Flash Attention (20x XLA compilation speed-up on 32k long context) (#908)

Use splash attention lazy mask instead of `jnp` mask, which is O(T^2).

The host memory usage for the `jnp` mask is O(T^2). Currently, a `jnp` mask is created and then wrapped with `NumpyMask` for use in Splash Attention, resulting in O(T^2) temporary HBM usage (though XLA somehow avoids allocating it).

This PR proposes using `CausalMask` or `LocalMask`, allowing each Splash Attention block to lazily create and use the required mask in the Pallas kernel.

The runtime performance of Splash Attention remains nearly the same. However, the JIT compilation time for the function using Splash Attention has improved significantly. It appears that allocating the O(T^2) mask in HBM and then wrapping it with `NumpyMask` consumes a lot of XLA compilation time.

* Benchmark: on TPUv5p, (model_dim/heads/kv_heads/seq_len), tools/attention_benchmark.py

1) measure time with XLA compilation

NumpyMask (ASIS)
----------------------------------------------------------------------------------------
Benchmark                                           Time             CPU      Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096           3609 ms         1645 ms               1
FlashAttentionBenchmark/4096/16/2/8192           7828 ms         5696 ms               1
FlashAttentionBenchmark/4096/16/2/32768         94368 ms        91442 ms               1

CausalMask (Proposed PR): significant XLA compilation speed-up.
----------------------------------------------------------------------------------------
Benchmark                                           Time             CPU      Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096           22.9 ms         1.60 ms             127
FlashAttentionBenchmark/4096/16/2/8192           40.8 ms         2.12 ms              88
FlashAttentionBenchmark/4096/16/2/32768          7641 ms         5458 ms               1

2) measure time without XLA compilation (pure jit computation)

NumpyMask (ASIS)
----------------------------------------------------------------------------------------
Benchmark                                           Time             CPU      Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096           9.97 ms        0.757 ms             918
FlashAttentionBenchmark/4096/16/2/8192           19.4 ms        0.934 ms             832
FlashAttentionBenchmark/4096/16/2/32768           116 ms         1.03 ms             100

CausalMask (Proposed PR): slight step time speed-up.
----------------------------------------------------------------------------------------
Benchmark                                           Time             CPU      Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/4096/16/2/4096           9.82 ms        0.690 ms             964
FlashAttentionBenchmark/4096/16/2/8192           19.2 ms        0.822 ms             837
FlashAttentionBenchmark/4096/16/2/32768           116 ms        0.997 ms             100

In addition, tpu_attention_benchmark.py is changed to use a sequence length of 8k rather than 2k, with sliding window = 1k.

NumpyMask (ASIS)
Benchmarking attention representative of 1.2b model layer on TPU v5.
ref_fwd:0.2288s, flash_fwd:0.0014s
ref_bwd:0.0218s, flash_bwd:0.0058s
Benchmarking attention representative of 12.6b model layer on TPU v5.
ref_fwd:0.5700s, flash_fwd:0.0032s
ref_bwd:0.0527s, flash_bwd:0.0149s
Benchmarking attention representative of 29.6b model layer on TPU v5.
ref_fwd:0.7958s, flash_fwd:0.0044s
ref_bwd:0.0730s, flash_bwd:0.0205s
Benchmarking attention representative of 65.2b model layer on TPU v5.
ref_fwd:1.0222s, flash_fwd:0.0055s
ref_bwd:0.0949s, flash_bwd:0.0262s
Benchmarking attention representative of 134b model layer on TPU v5.
ref_fwd:1.2486s, flash_fwd:0.0067s
ref_bwd:0.1161s, flash_bwd:0.0318s
Benchmarking attention representative of 261.7b model layer on TPU v5.
ref_fwd:1.5577s, flash_fwd:0.0072s
ref_bwd:0.1348s, flash_bwd:0.0375s

LocalMask (Proposed PR): slight fwd/bwd time speed-up.
Benchmarking attention representative of 1.2b model layer on TPU v5.
ref_fwd:0.2291s, flash_fwd:0.0014s
ref_bwd:0.0217s, flash_bwd:0.0058s
Benchmarking attention representative of 12.6b model layer on TPU v5.
ref_fwd:0.5699s, flash_fwd:0.0032s
ref_bwd:0.0524s, flash_bwd:0.0152s
Benchmarking attention representative of 29.6b model layer on TPU v5.
ref_fwd:0.7957s, flash_fwd:0.0043s
ref_bwd:0.0731s, flash_bwd:0.0204s
Benchmarking attention representative of 65.2b model layer on TPU v5.
ref_fwd:1.0225s, flash_fwd:0.0055s
ref_bwd:0.0948s, flash_bwd:0.0262s
Benchmarking attention representative of 134b model layer on TPU v5.
ref_fwd:1.2485s, flash_fwd:0.0067s
ref_bwd:0.1159s, flash_bwd:0.0313s
Benchmarking attention representative of 261.7b model layer on TPU v5.
ref_fwd:1.5577s, flash_fwd:0.0072s
ref_bwd:0.1349s, flash_bwd:0.0373s
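
Illustration of the lazy-mask idea described in the commit message, as a minimal pure-Python sketch. It is not the Splash Attention API: `LazyCausalMask`, `dense_causal_mask`, and `dense_mask_bytes` are hypothetical names introduced here, and the sliding-window convention (a query sees itself plus `window` earlier positions) is an assumption. The point is only that a dense mask stores T*T entries up front, while a lazy mask stores the shape and computes each tile from indices when a block asks for it.

```python
def dense_causal_mask(seq_len):
    # Materializes every entry up front: seq_len * seq_len booleans (O(T^2)).
    return [[k <= q for k in range(seq_len)] for q in range(seq_len)]


class LazyCausalMask:
    """Hypothetical stand-in for a lazy mask: stores only shape parameters."""

    def __init__(self, seq_len, window=None):
        self.seq_len = seq_len
        # window=None: plain causal mask.
        # window=w (assumed convention): a query sees itself and the w
        # positions before it, i.e. a sliding-window (local) causal mask.
        self.window = window

    def _visible(self, q, k):
        if k > q:
            return False
        return self.window is None or q - k <= self.window

    def block(self, q_start, k_start, block_size):
        # Computes a single tile on demand from index arithmetic alone;
        # nothing of size O(T^2) is ever stored.
        q_stop = min(q_start + block_size, self.seq_len)
        k_stop = min(k_start + block_size, self.seq_len)
        return [
            [self._visible(q, k) for k in range(k_start, k_stop)]
            for q in range(q_start, q_stop)
        ]


def dense_mask_bytes(seq_len, bytes_per_entry=1):
    # Size of a fully materialized mask, assuming one byte per entry.
    return seq_len * seq_len * bytes_per_entry


if __name__ == "__main__":
    lazy = LazyCausalMask(seq_len=8)
    # The lazily computed tile matches the same region of the dense mask.
    dense = dense_causal_mask(8)
    assert lazy.block(4, 0, 4) == [row[0:4] for row in dense[4:8]]
    # At the 32k sequence length benchmarked above: 2**30 bytes (1 GiB).
    print(dense_mask_bytes(32768))
```

Under the one-byte-per-entry assumption, a dense mask at seq_len 32768 takes 1 GiB, whereas each lazily computed tile is only block_size * block_size entries.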