
Running flash_attn/flash_attn_triton_amd/bench.py with sequence length > 4096 causes RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered #1440

Open
jiqimaoke opened this issue Jan 12, 2025 · 0 comments


Hi,

I tried to run ./flash_attn/flash_attn_triton_amd/bench.py and hit an issue while benchmarking FlashAttention-2 on a Triton setup. When both input sequence lengths (-sq and -sk) are greater than 4096, the following error occurs during the backward pass:

Benchmarking prefill in fwd mode...
benchmark-prefill-d128-layoutbhsd-modefwd:
   BATCH    HQ    HK  N_CTX_Q  N_CTX_K     TFLOPS
0    8.0  16.0  16.0   4096.0   4096.0  96.501665
Benchmarking prefill in bwd mode...
Traceback (most recent call last):
  File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 291, in <module>
    main()
  File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 288, in main
    run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode)
  File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 196, in run_benchmark
    bench_function.run(save_path=".", print_data=True)
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/testing.py", line 349, in run
    result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/testing.py", line 292, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
  File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 189, in bench_function
    ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep)
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/testing.py", line 106, in do_bench
    fn()
  File "/nvme/zky/temp/flash-attention-main/flash_attn/flash_attn_triton_amd/bench.py", line 180, in <lambda>
    benchmark_fn = lambda: output.backward(grad_output, retain_graph=True)
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/flash_attn/flash_attn_triton_amd/interface_torch.py", line 54, in backward
    return attention_prefill_backward_triton_impl(
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/flash_attn/flash_attn_triton_amd/bwd_prefill.py", line 625, in attention_prefill_backward_triton_impl
    _bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/data/anaconda3/envs/falcon_mamba/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

Here is the command I used:

python ./flash_attn/flash_attn_triton_amd/bench.py -b 8 -hq 16 -hk 16 -sq 4096 -sk 4096 -d 128 -return_tflops -benchmark_fn prefill -causal

It appears that when both -sq and -sk exceed 4096, the backward kernel fails with an illegal memory access. However, when the sequence lengths are less than or equal to 4096, the benchmarking completes without any issues.
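For reference, here is a minimal sketch (not part of the benchmark itself) that should exercise the same forward/backward path through the public flash_attn_func API. The 8192 sequence length is just an example above the 4096 threshold, and whether the Triton AMD kernels are actually selected depends on how flash-attention was built and on the FLASH_ATTENTION_TRITON_AMD_ENABLE environment variable, which is an assumption here:

# Minimal repro sketch (my own, not from bench.py): forward + backward
# through flash_attn_func with a sequence length above 4096.
# Assumption: the Triton AMD backend is enabled, e.g. via
# FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" at install/run time.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 8, 8192, 16, 128  # 8192 > 4096 (example)

q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16,
                device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = flash_attn_func(q, k, v, causal=True)  # forward pass completes
# Backward pass is where the illegal memory access is reported above.
out.backward(torch.randn_like(out))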

The Triton version is 3.1.0, and the GPU is an NVIDIA H20.

Does anyone have any idea?
