Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flash2 and supports cross attention and dropout #905

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
929 changes: 396 additions & 533 deletions axlearn/common/flash_attention/gpu_attention.py

Large diffs are not rendered by default.

138 changes: 69 additions & 69 deletions axlearn/common/flash_attention/gpu_attention_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,95 +10,95 @@
"""FlashAttention kernel benchmarks.

To run: python3 gpu_attention_benchmark.py > out.txt
Requires Jax >= 0.4.36. Sample numbers on H100 SXM5:
Requires Jax >= 0.4.36. Sample numbers on H100 SXM5 with Jax == 0.4.36:
is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn
bs=1,seq_len=1024 0.020608 0.018656 0.023680
bs=1,seq_len=4096 0.037856 0.022784 0.056704
bs=1,seq_len=8192 0.033792 0.032768 0.104448
bs=1,seq_len=131072 0.227808 0.198816 1.486752
bs=4,seq_len=1024 0.021440 0.022208 0.024032
bs=4,seq_len=4096 0.069728 0.054624 0.059584
bs=4,seq_len=8192 0.081952 0.076064 0.105920
bs=4,seq_len=131072 0.823104 0.705056 1.488832
bs=8,seq_len=1024 0.032544 0.030688 0.024608
bs=8,seq_len=4096 0.089728 0.071648 0.063584
bs=8,seq_len=8192 0.129184 0.114944 0.109856
bs=8,seq_len=131072 1.616800 1.376288 1.503360
bs=16,seq_len=1024 0.050976 0.048608 0.037504
bs=16,seq_len=4096 0.136768 0.117312 0.104224
bs=16,seq_len=8192 0.234688 0.200128 0.190944
bs=16,seq_len=131072 3.211200 2.727040 2.779872
bs=32,seq_len=1024 0.078656 0.072992 0.061440
bs=32,seq_len=4096 0.236576 0.204512 0.190752
bs=32,seq_len=8192 0.443488 0.372352 0.361216
bs=32,seq_len=131072 6.392320 5.453344 5.495488
bs=1,seq_len=1024 0.020832 0.017536 0.024128
bs=1,seq_len=4096 0.037472 0.021248 0.058656
bs=1,seq_len=8192 0.034016 0.032576 0.108576
bs=1,seq_len=131072 0.229856 0.198944 1.558464
bs=4,seq_len=1024 0.021632 0.023296 0.024352
bs=4,seq_len=4096 0.068064 0.055168 0.061312
bs=4,seq_len=8192 0.080352 0.075968 0.109696
bs=4,seq_len=131072 0.824576 0.703360 1.560768
bs=8,seq_len=1024 0.033536 0.030304 0.024448
bs=8,seq_len=4096 0.089056 0.071712 0.062944
bs=8,seq_len=8192 0.128960 0.114848 0.112736
bs=8,seq_len=131072 1.620032 1.373088 1.566208
bs=16,seq_len=1024 0.050368 0.048064 0.036608
bs=16,seq_len=4096 0.134816 0.116320 0.104320
bs=16,seq_len=8192 0.234880 0.200384 0.191936
bs=16,seq_len=131072 3.219008 2.726912 2.784768
bs=32,seq_len=1024 0.078112 0.070816 0.061568
bs=32,seq_len=4096 0.235648 0.203296 0.191936
bs=32,seq_len=8192 0.442080 0.371936 0.365152
bs=32,seq_len=131072 6.404832 5.448480 5.541504
is_decode=True, use_bwd=False, num_heads=8, seq_len=32768, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn
bs=1,num_kv_heads=1 0.049280 0.059296 0.378304
bs=1,num_kv_heads=8 0.076352 0.070912 0.377344
bs=8,num_kv_heads=1 0.111072 0.080480 0.377696
bs=8,num_kv_heads=8 0.425536 0.368576 0.386880
bs=1,num_kv_heads=1 0.027648 0.058464 0.398816
bs=1,num_kv_heads=8 0.076096 0.070368 0.398912
bs=8,num_kv_heads=1 0.101696 0.078560 0.399040
bs=8,num_kv_heads=8 0.426656 0.367616 0.403360
is_decode=True, use_bwd=False, num_heads=8, num_kv_heads=8, per_head_dim=128
jax axlearn jax-cudnn
bs=1,seq_len=131072,sw_sz=-1 0.228640 0.199040 1.476928
bs=1,seq_len=131072,sw_sz=4096 0.232320 0.053824 4.441376
bs=1,seq_len=131072,sw_sz=16384 0.233696 0.061120 4.420992
bs=8,seq_len=131072,sw_sz=-1 1.621696 1.374080 1.496224
bs=8,seq_len=131072,sw_sz=4096 1.626016 0.193792 4.463296
bs=8,seq_len=131072,sw_sz=16384 1.628704 0.318176 4.451648
bs=1,seq_len=131072,sw_sz=-1 0.230336 0.199968 1.559168
bs=1,seq_len=131072,sw_sz=4096 0.235296 0.051296 4.414048
bs=1,seq_len=131072,sw_sz=16384 0.235904 0.062976 4.385216
bs=8,seq_len=131072,sw_sz=-1 1.619008 1.372768 1.570272
bs=8,seq_len=131072,sw_sz=4096 1.635424 0.194720 4.390976
bs=8,seq_len=131072,sw_sz=16384 1.632832 0.321280 4.361984
is_decode=False, use_bwd=False, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
bs=2 3.502944 0.915360 0.467744 0.845792
bs=4 6.969376 1.753152 0.890496 1.617280
bs=8 13.962816 3.415232 1.735232 3.150752
bs=2 3.583424 0.894912 0.488480 0.852960
bs=4 7.107168 1.712448 0.922592 1.629888
bs=8 14.202400 3.341568 1.801920 3.184064
is_decode=False, use_bwd=False, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
num_heads=12 1.262560 0.393536 0.205952 0.362304
num_heads=16 1.786816 0.498304 0.257664 0.459936
num_heads=32 3.507488 2.591456 0.468672 2.443296
num_heads=48 5.246336 1.338272 0.675968 1.231328
num_heads=72 7.866848 1.961152 0.995712 1.805376
num_heads=12 1.287712 0.383200 0.214400 0.365120
num_heads=16 1.803232 0.485408 0.270496 0.463040
num_heads=32 3.578208 0.896576 0.488544 2.468096
num_heads=48 5.346112 1.305856 0.707872 1.241728
num_heads=72 8.001568 1.915776 1.035200 1.820288
is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
seq_len=128 0.030592 0.011584 0.013024 0.012960
seq_len=256 0.051520 0.015648 0.016640 0.015744
seq_len=512 0.118720 0.038976 0.028224 0.037152
seq_len=1024 0.310880 0.096256 0.054784 0.090368
seq_len=2048 0.931072 0.277312 0.150784 0.256928
seq_len=4096 3.516672 2.595872 0.465568 2.448128
seq_len=256 0.049184 0.015360 0.016352 0.015488
seq_len=512 0.110400 0.038624 0.028480 0.037760
seq_len=1024 0.302304 0.094560 0.056736 0.090464
seq_len=2048 0.936832 0.269856 0.154304 0.258944
seq_len=4096 3.584800 0.895776 0.487104 2.462560
seq_len=8192 14.260608 3.268320 1.742048 3.104640
is_decode=False, use_bwd=False, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
per_head_dim=16 3.220960 0.487808 0.332928 0.478720
per_head_dim=32 3.277824 0.530240 0.334624 0.515040
per_head_dim=64 3.345376 0.696480 0.338944 0.631296
per_head_dim=128 3.515616 2.594208 0.465824 2.442784
per_head_dim=16 3.262592 0.518912 0.356544 0.477120
per_head_dim=32 3.323552 0.563520 0.358944 0.533344
per_head_dim=64 3.411744 0.690464 0.360192 0.635296
per_head_dim=128 3.585920 0.896032 0.488416 2.461696
is_decode=False, use_bwd=True, num_heads=32, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
bs=2 10.780096 4.573344 2.080672 4.487104
bs=4 21.426336 9.336192 3.988224 9.159904
bs=8 42.808033 18.926559 7.975296 18.075487
bs=2 10.878624 3.924992 2.123008 4.504256
bs=4 21.626017 8.043040 4.071552 9.186080
bs=8 43.269279 16.195999 8.124896 18.184799
is_decode=False, use_bwd=True, bs=2, num_kv_heads=None, seq_len=4096, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
num_heads=12 4.128352 1.738016 0.882976 1.696704
num_heads=16 5.467808 2.307488 1.120608 2.247904
num_heads=32 10.782432 4.559456 2.082592 4.488448
num_heads=48 16.119776 6.958272 3.027808 6.858144
num_heads=72 24.140833 10.706656 4.560288 10.279136
num_heads=12 4.159424 1.519680 0.898816 1.711808
num_heads=16 5.486912 2.001952 1.142144 2.256960
num_heads=32 10.886848 3.928896 2.114496 4.502976
num_heads=48 16.224319 6.085408 3.093696 6.888640
num_heads=72 24.367489 9.190560 4.642720 10.323552
is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, per_head_dim=128, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
seq_len=128 0.058944 0.037824 0.039040 0.036384
seq_len=256 0.100384 0.069024 0.052608 0.067872
seq_len=512 0.317056 0.159904 0.111840 0.158912
seq_len=1024 0.906400 0.431104 0.244160 0.421792
seq_len=2048 2.861056 1.319648 0.655840 1.297728
seq_len=4096 10.762560 4.576864 2.079904 4.489056
seq_len=256 0.094496 0.060096 0.053184 0.065760
seq_len=512 0.297440 0.139328 0.112736 0.161664
seq_len=1024 0.886304 0.361536 0.246848 0.418720
seq_len=2048 2.857952 1.118368 0.675168 1.294144
seq_len=4096 10.880512 3.914048 2.119808 4.503936
seq_len=8192 43.000095 14.913824 7.484128 16.730017
is_decode=False, use_bwd=True, bs=2, num_heads=32, num_kv_heads=None, seq_len=4096, sw_sz=-1
jax axlearn jax-cudnn jax-pallas
per_head_dim=16 10.084800 1.744640 1.263264 1.711296
per_head_dim=32 10.204480 2.098816 1.291104 2.041184
per_head_dim=64 10.374720 2.649888 1.335200 2.510304
per_head_dim=128 10.779680 4.568096 2.079264 4.489792
per_head_dim=16 10.150080 1.826656 1.288192 1.718688
per_head_dim=32 10.277440 2.028608 1.316512 2.048864
per_head_dim=64 10.463904 2.569408 1.364448 2.540512
per_head_dim=128 10.875328 3.929568 2.124192 4.502912
"""
# pylint: enable=line-too-long
import itertools
Expand Down Expand Up @@ -365,8 +365,8 @@ def bench_flash_attention_fwd_bwd(use_bwd: bool):
libraries = ["jax", "axlearn", "jax-cudnn", "jax-pallas"]
benchmark_sweep(libraries, common_kwargs, bs=[2, 4, 8])
benchmark_sweep(libraries, common_kwargs, num_heads=[12, 16, 32, 48, 72])
# 128 to 4096.
benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(7, 13)])
# 256 to 8192.
benchmark_sweep(libraries, common_kwargs, seq_len=[int(2**i) for i in range(8, 14)])
benchmark_sweep(libraries, common_kwargs, per_head_dim=[16, 32, 64, 128])


Expand Down
Loading
Loading