[ROCm] Tune flex-attention and decode to num_stages=1 (pytorch#139883)
Fixes pytorch#139755 pytorch#139621

The new stream pipeliner in the AMD Triton backend makes num_stages behave the same way it does on the NVIDIA backend. With the Triton 3.2 upgrade, the existing num_stages=3 setting therefore causes out-of-shared-memory (OOM) failures in flex attention; we have tuned this to num_stages=1, which is the best setting for flash-attention-style kernels and avoids the shmem issues.

We will follow up this PR with further config tuning for the AMD backend.
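
As a rough illustration of why num_stages=3 blows the shared-memory budget, here is a back-of-envelope sketch (an assumption-laden model, not how Triton actually sizes its buffers: it assumes ~64 KiB of LDS per CU, as on CDNA-class AMD GPUs, and that the pipeliner keeps num_stages copies of the K and V tiles resident):

```python
def approx_lds_bytes(block_n: int, head_dim: int, dtype_bytes: int, num_stages: int) -> int:
    """Rough model: num_stages pipelined copies of the K and V tiles
    (each BLOCK_N x head_dim elements) held in shared memory / LDS."""
    return num_stages * 2 * block_n * head_dim * dtype_bytes

LDS_PER_CU = 64 * 1024  # assumption: 64 KiB of LDS per CU on CDNA-class AMD GPUs

for stages in (3, 1):
    usage = approx_lds_bytes(block_n=64, head_dim=128, dtype_bytes=2, num_stages=stages)
    print(f"num_stages={stages}: ~{usage // 1024} KiB of {LDS_PER_CU // 1024} KiB LDS")
# num_stages=3 -> ~96 KiB (over budget); num_stages=1 -> ~32 KiB (fits comfortably)
```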

Pull Request resolved: pytorch#139883
Approved by: https://github.com/bertmaher
jataylo authored and pytorchmergebot committed Nov 7, 2024
1 parent 36e0f11 commit 8d070d2
Showing 2 changed files with 41 additions and 3 deletions.
34 changes: 32 additions & 2 deletions torch/_inductor/kernel/flex_attention.py
@@ -597,6 +597,18 @@ def _use_flex_decoding(query, kernel_options):
(torch.float16, 256): (32, 64, 4, 3),
}

_rocm_default_config = {
(torch.float32, 64): (128, 32, 4, 1),
(torch.float32, 128): (128, 32, 4, 1),
(torch.float32, 256): (64, 16, 4, 1),
(torch.bfloat16, 64): (128, 64, 8, 1),
(torch.bfloat16, 128): (128, 64, 8, 1),
(torch.bfloat16, 256): (32, 64, 8, 1),
(torch.float16, 64): (128, 64, 8, 1),
(torch.float16, 128): (128, 64, 8, 1),
(torch.float16, 256): (32, 64, 4, 1),
}


def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
dtype = query.get_dtype()
@@ -615,6 +627,12 @@ def _get_default_config_fwd(query) -> Tuple[int, int, int, int]:
else:
default_config = (128, 64, 4, 3)
default_config = _a100_default_config.get((dtype, head_dim), default_config)
elif head_dim <= 256 and torch.version.hip:
if dtype == torch.float32:
default_config = (64, 64, 4, 1)
else:
default_config = (128, 64, 8, 1)
default_config = _rocm_default_config.get((dtype, head_dim), default_config)
else: # modest hardware or extremely large head_dim
if dtype == torch.float32:
default_config = (32, 16, 4, 3)
@@ -630,7 +648,14 @@ def _get_default_config_bwd(query) -> Tuple[int, int, int, int]:

if dtype == torch.float32:
return (16, 16, 4, 1)
if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100
if head_dim <= 256 and torch.version.hip:
if head_dim == 64:
return (64, 64, 4, 1)
elif head_dim == 128:
return (64, 128, 4, 1)
else:
return (64, 64, 4, 1)
elif head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100
if head_dim == 64:
return (64, 64, 4, 3)
elif head_dim == 128:
@@ -845,6 +870,10 @@ def flex_attention(
(64, 64, 4, 3),
]

# On ROCm convert num_stages to 1 to avoid shmem issues
if torch.version.hip:
configs = [(c[0], c[1], c[2], 1) for c in configs]

# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
@@ -1780,13 +1809,14 @@ def flex_attention_backward(*args, **kwargs):
configs: List[Tuple[int, int, int, int]] = []
configs.append(_get_default_config_bwd(query))
if config.max_autotune:
num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1]
configs.extend(
[
(BLOCK1, BLOCK2, w, s)
for BLOCK1 in [32, 64]
for BLOCK2 in [32, 64, 128]
for w in [4, 8]
for s in [1, 3, 4, 5]
for s in num_stages_list
if BLOCK2 % BLOCK1 == 0
]
)
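
For reference, the four-element config tuples in this file follow the (BLOCK1, BLOCK2, num_warps, num_stages) layout spelled out in the backward autotune list above. Below is a minimal sketch of how one such tuple would typically be turned into a triton.Config; the BLOCK_M/BLOCK_N meta-parameter names are illustrative assumptions rather than the exact names inductor uses:

```python
import triton


def to_triton_config(cfg):
    """Sketch: map a (block_m, block_n, num_warps, num_stages) tuple onto a
    triton.Config; BLOCK_M/BLOCK_N are placeholder meta-parameter names."""
    block_m, block_n, num_warps, num_stages = cfg
    return triton.Config(
        {"BLOCK_M": block_m, "BLOCK_N": block_n},
        num_warps=num_warps,
        num_stages=num_stages,
    )


# Example: the ROCm default for (torch.float16, head_dim=128) from _rocm_default_config.
rocm_fp16_hd128 = to_triton_config((128, 64, 8, 1))
```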
10 changes: 9 additions & 1 deletion torch/_inductor/kernel/flex_decoding.py
@@ -311,7 +311,10 @@ def _get_decoding_default_config(key) -> Tuple[int, int, int]:
if sm_version >= (9, 0):
if head_dim > 128 and dtype == torch.float32:
return default_config
return (64, 2, 3)
if torch.version.hip is None:
return (64, 2, 3)
else:
return (64, 2, 1)
return default_config


@@ -407,6 +410,11 @@ def create_flex_decoding_kernel(*args, **kwargs):
(32, 2, 3),
(128, 2, 3),
]

# Use num_stages=1 on ROCm to avoid shmem limitation
if torch.version.hip:
configs = [(c[0], c[1], 1) for c in configs]

# TODO: fix autotuning.

kernel_options.setdefault("SM_SCALE", scale)
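
As a usage sketch (assuming PyTorch 2.5+ where torch.nn.attention.flex_attention is public, and a ROCm build, on which torch.version.hip is a version string rather than None and the HIP device is exposed as "cuda"), a decode-shaped call that exercises these kernels under torch.compile might look like:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

assert torch.version.hip is not None  # ROCm build: the branches above pick num_stages=1

# Decode-shaped inputs: a single query token attending over a long KV cache.
q = torch.randn(1, 8, 1, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.float16)

compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v)  # short query lengths should dispatch to the flex-decoding kernel (see _use_flex_decoding)
print(out.shape)  # torch.Size([1, 8, 1, 128])
```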
