I tried to run ./flash_attn/flash_attn_triton_amd/bench.py and ran into an issue while benchmarking FlashAttention-2 with the Triton backend. When both input sequence lengths (-sq and -sk) are greater than 4096, the following error occurs during the backward pass:
Benchmarking prefill in fwd mode...
benchmark-prefill-d128-layoutbhsd-modefwd:
BATCH HQ HK N_CTX_Q N_CTX_K TFLOPS
0 8.0 16.0 16.0 4096.0 4096.0 96.501665
Benchmarking prefill in bwd mode...
Traceback (most recent call last):
File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 291, in <module>
main()
File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 288, in main
run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode)
File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 196, in run_benchmark
bench_function.run(save_path=".", print_data=True)
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/testing.py", line 349, in run
result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/testing.py", line 292, in _run
ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 189, in bench_function
ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep)
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/testing.py", line 106, in do_bench
fn()
File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 180, in <lambda>
benchmark_fn = lambda: output.backward(grad_output, retain_graph=True)
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
torch.autograd.backward(
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
_engine_run_backward(
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
return user_fn(self, *args)
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/flash_attn/flash_attn_triton_amd/interface_torch.py", line 54, in backward
return attention_prefill_backward_triton_impl(
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/flash_attn/flash_attn_triton_amd/bwd_prefill.py", line 625, in attention_prefill_backward_triton_impl
_bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/runtime/jit.py", line 691, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
It appears that when both -sq and -sk exceed 4096, the backward kernel fails with an illegal memory access, whereas with sequence lengths of 4096 or less the benchmark completes without any issues.
The Triton version is 3.1.0, and the GPU is an NVIDIA H20.
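For reference, a minimal standalone sketch of what the benchmark exercises is below. This is not the exact bench.py code; it goes through the public flash_attn_func API and assumes the Triton backend is selected via the FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable (I have not re-checked that exact variable name on this setup):

```python
# Hypothetical repro sketch (not taken from bench.py). Assumes the Triton
# backend is enabled through FLASH_ATTENTION_TRITON_AMD_ENABLE; it must be
# set before flash_attn is imported.
import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"

import torch
from flash_attn import flash_attn_func

# Both sequence lengths above 4096, head dim 128 as in the benchmark.
batch, seqlen, nheads, headdim = 2, 8192, 16, 128

q = torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = flash_attn_func(q, k, v)        # forward pass
out.backward(torch.randn_like(out))   # backward pass should hit the illegal memory access
torch.cuda.synchronize()
```

If this standalone version reproduces the crash, that would rule out anything specific to the benchmark harness itself.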
Does anyone have any idea what might be causing this?