
FA3 regression on H100 80GB? #1432

Open
bastianhagedorn opened this issue Jan 9, 2025 · 8 comments

@bastianhagedorn

Hi, I just updated my local installation of this repo to v2.7.2 and saw a significant drop in performance when running non-causal FA3 on an H100 80GB. Most likely this is because the MMA instruction shape for the first HGMMAs suddenly changed to m64n16k16; I think it was m64n176k16 before. Was there some change in the heuristics, or might the problem be on my side?

I can provide more details (specific environment, CUDA version, problem sizes, etc.) as needed; just wanted to raise awareness for now.

Thanks
Bastian

@tridao
Contributor

tridao commented Jan 9, 2025

Are you installing FA3 from the hopper directory? Which two versions (tag or commit) are you comparing?

@bastianhagedorn
Author

bastianhagedorn commented Jan 9, 2025

Yes, I run python setup.py install from inside the hopper directory. (Btw, pulling the latest main commit fails to compile for me, so I checked out tag v2.7.2 explicitly, which seems to work.) Unfortunately, I no longer have the original version I was using, since I just pulled the latest one; it was at least a few months old.
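
For reference, roughly the steps I'm building with (a sketch of my setup; the repo URL and clone step are assumptions based on a standard checkout — the submodule path matches what the build prints later):

```bash
# Sketch of my build steps; paths and env are specific to my machine.
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
git checkout v2.7.2
git submodule update --init csrc/cutlass   # CUTLASS headers used by the hopper build
cd hopper
python setup.py install
```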

@tridao
Contributor

tridao commented Jan 9, 2025

Can you try this commit?
68bf390

@bastianhagedorn
Author

Then I'm seeing this:

Submodule path '../csrc/cutlass': checked out 'c506e16788cb08416a4a57e11a9067beeee29420'


torch.__version__  = 2.5.1+cu124


copy /home/bhagedorn/.flashattn/nvidia/nvcc/bin to /home/scratch.bhagedorn_sw/flash-attention/hopper/../third_party/nvidia/backend/bin ...
copy /home/bhagedorn/.flashattn/nvidia/nvcc/nvvm/bin to /home/scratch.bhagedorn_sw/flash-attention/hopper/../third_party/nvidia/backend/bin ...
/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/setuptools/dist.py:509: InformationOnly: Normalizing '3.0.0.b1' to '3.0.0b1'
  self.metadata.version = self._normalize_version(
running install
/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/setuptools/_distutils/cmd.py:66: EasyInstallDeprecationWarning: easy_install command is deprecated.
!!

        ********************************************************************************
        Please avoid running ``setup.py`` and ``easy_install``.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://github.com/pypa/setuptools/issues/917 for details.
        ********************************************************************************

!!
  self.initialize_options()
running bdist_egg
running egg_info
writing flash_attn.egg-info/PKG-INFO
writing dependency_links to flash_attn.egg-info/dependency_links.txt
writing requirements to flash_attn.egg-info/requires.txt
writing top-level names to flash_attn.egg-info/top_level.txt
reading manifest file 'flash_attn.egg-info/SOURCES.txt'
writing manifest file 'flash_attn.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
copying flash_attn_interface.py -> build/lib.linux-x86_64-cpython-310
running build_ext
/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/utils/cpp_extension.py:416: UserWarning: The detected CUDA version (12.9) has a minor version mismatch with the version that was used to compile PyTorch (12.4). Most likely this shouldn't be a problem.
  warnings.warn(CUDA_MISMATCH_WARN.format(cuda_str_version, torch.version.cuda))
/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/utils/cpp_extension.py:426: UserWarning: There are no g++ version bounds defined for CUDA version 12.9
  warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
building 'flash_attn_3_cuda' extension
Emitting ninja build file /home/scratch.bhagedorn_sw/flash-attention/hopper/build/temp.linux-x86_64-cpython-310/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/60] c++ -MMD -MF /home/scratch.bhagedorn_sw/flash-attention/hopper/build/temp.linux-x86_64-cpython-310/flash_api.o.d -pthread -B /home/bhagedorn/scratch/miniconda3/envs/graphene2/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /home/bhagedorn/scratch/miniconda3/envs/graphene2/include -fPIC -O2 -isystem /home/bhagedorn/scratch/miniconda3/envs/graphene2/include -fPIC -I/home/scratch.bhagedorn_sw/flash-attention/hopper -I/home/scratch.bhagedorn_sw/flash-attention/csrc/cutlass/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/TH -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/THC -I/home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/include/python3.10 -c -c /home/scratch.bhagedorn_sw/flash-attention/hopper/flash_api.cpp -o /home/scratch.bhagedorn_sw/flash-attention/hopper/build/temp.linux-x86_64-cpython-310/flash_api.o -O3 -std=c++17 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_attn_3_cuda -D_GLIBCXX_USE_CXX11_ABI=0
[2/60] /home/scratch.bhagedorn_sw/flash-attention/hopper/../third_party/nvidia/backend/bin/nvcc --generate-dependencies-with-compile --dependency-output /home/scratch.bhagedorn_sw/flash-attention/hopper/build/temp.linux-x86_64-cpython-310/instantiations/flash_fwd_hdimall_bf16_paged_sm90.o.d -I/home/scratch.bhagedorn_sw/flash-attention/hopper -I/home/scratch.bhagedorn_sw/flash-attention/csrc/cutlass/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/TH -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/THC -I/home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/include/python3.10 -c -c /home/scratch.bhagedorn_sw/flash-attention/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu -o /home/scratch.bhagedorn_sw/flash-attention/hopper/build/temp.linux-x86_64-cpython-310/instantiations/flash_fwd_hdimall_bf16_paged_sm90.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' --threads 4 -O3 -std=c++17 --ftemplate-backtrace-limit=0 --use_fast_math --resource-usage -lineinfo -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED -DCUTLASS_DEBUG_TRACE_LEVEL=0 -DNDEBUG -gencode arch=compute_90a,code=sm_90a -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_attn_3_cuda -D_GLIBCXX_USE_CXX11_ABI=0
FAILED: /home/scratch.bhagedorn_sw/flash-attention/hopper/build/temp.linux-x86_64-cpython-310/instantiations/flash_fwd_hdimall_bf16_paged_sm90.o
/home/scratch.bhagedorn_sw/flash-attention/hopper/../third_party/nvidia/backend/bin/nvcc --generate-dependencies-with-compile --dependency-output /home/scratch.bhagedorn_sw/flash-attention/hopper/build/temp.linux-x86_64-cpython-310/instantiations/flash_fwd_hdimall_bf16_paged_sm90.o.d -I/home/scratch.bhagedorn_sw/flash-attention/hopper -I/home/scratch.bhagedorn_sw/flash-attention/csrc/cutlass/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/TH -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/lib/python3.10/site-packages/torch/include/THC -I/home/scratch.svc_compute_arch/release/cuda_toolkit/internal/latest/include -I/home/bhagedorn/scratch/miniconda3/envs/graphene2/include/python3.10 -c -c /home/scratch.bhagedorn_sw/flash-attention/hopper/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu -o /home/scratch.bhagedorn_sw/flash-attention/hopper/build/temp.linux-x86_64-cpython-310/instantiations/flash_fwd_hdimall_bf16_paged_sm90.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' --threads 4 -O3 -std=c++17 --ftemplate-backtrace-limit=0 --use_fast_math --resource-usage -lineinfo -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED -DCUTLASS_DEBUG_TRACE_LEVEL=0 -DNDEBUG -gencode arch=compute_90a,code=sm_90a -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_attn_3_cuda -D_GLIBCXX_USE_CXX11_ABI=0
Killed

@tridao
Contributor

tridao commented Jan 9, 2025

You'd need MAX_JOBS=4 (the exact number depends on your RAM) to avoid getting OOM-killed during compilation.
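
E.g. (4 is just a starting point; each nvcc job can use a lot of RAM):

```bash
# Cap the number of parallel nvcc jobs so the build isn't OOM-killed.
# Your ninja log above notes the worker count is overridable via MAX_JOBS.
MAX_JOBS=4 python setup.py install
```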

@bastianhagedorn
Author

That took forever, but it built successfully, thanks :) The problem remains, though: performance is not great, and I still see HGMMA.64x16x16.F32 used for the first GEMM.

@tridao
Contributor

tridao commented Jan 10, 2025

Very strange, we never use HGMMA.64x16x16. I just dumped the SASS and it's using HGMMA.64x176x16.F32.BF16 as expected.
Maybe it's something to do with the way you're compiling it. And please make sure you're actually running the newly compiled version and not some old one.
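
If you want to check on your end, something like this should work (the exact .so name and path depend on your install; the grep pattern is just a rough filter):

```bash
# Confirm which compiled extension Python actually loads
python -c "import flash_attn_3_cuda; print(flash_attn_3_cuda.__file__)"

# Dump the SASS from that .so and count the HGMMA shapes it uses
cuobjdump -sass /path/to/flash_attn_3_cuda.so \
  | grep -o 'HGMMA\.[0-9x]*\.[A-Z0-9.]*' | sort | uniq -c
```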

@bastianhagedorn
Author

Thanks for looking into this. I'll take a closer look at everything and get back to you sometime next week 👍
