
FA 2.4.2 is failing unit test on A6000 and A5880 #1409

Open
BoxiangW opened this issue Dec 23, 2024 · 4 comments

@BoxiangW
Contributor

I am using NVIDIA's 24.10-py3 PyTorch container. To reproduce:

git clone https://github.com/Dao-AILab/flash-attention.git
python -c 'import flash_attn; print(flash_attn.__version__)'  # verify the installed flash-attn version
cd flash-attention && git checkout v2.4.2 && cd ..  # make sure you are not in the flash-attention dir when you run the test
py.test flash-attention/tests/test_flash_attn.py::test_flash_attn_qkvpacked[0.17-2048-160-True-True-False-False-dtype0]

I verified that this only happens on A5880 and A6000; it could be affecting A100 as well. H100 is not affected.
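For reference, here is a minimal sketch (not part of the original report) that prints the installed flash-attn build and each GPU's compute capability before running the test, since the failure appears to depend on the GPU architecture:

import torch
import flash_attn

# Print the flash-attn version and each visible GPU with its compute capability,
# so failing setups (e.g. A6000/A5880) can be distinguished from passing ones (H100).
print("flash-attn version:", flash_attn.__version__)
for i in range(torch.cuda.device_count()):
    name = torch.cuda.get_device_name(i)
    major, minor = torch.cuda.get_device_capability(i)
    print(f"cuda:{i}: {name} (sm_{major}{minor})")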

@Wonder1905

I'm seeing a similar failure.

@tridao
Contributor

tridao commented Jan 10, 2025

What's the error?

@BoxiangW
Contributor Author

>       assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
E       AssertionError: assert 1.78125 <= (2 * 0.00390625)
E        +  where 1.78125 = <built-in method item of Tensor object at 0x70a093d98c20>()
E        +    where <built-in method item of Tensor object at 0x70a093d98c20> = tensor(1.7812, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>).item
E        +      where tensor(1.7812, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>) = <built-in method max of Tensor object at 0x70a093d98b80>()
E        +        where <built-in method max of Tensor object at 0x70a093d98b80> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...5.0690e-02,\n           2.7985e-02, 1.0124e-02]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<AbsBackward0>).max
E        +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...5.0690e-02,\n           2.7985e-02, 1.0124e-02]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<AbsBackward0>) = <built-in method abs of Tensor object at 0x70a093d98900>()
E        +            where <built-in method abs of Tensor object at 0x70a093d98900> = (tensor([[[[-1.4463e+00,  1.9756e+00, -5.2917e-02,  ..., -9.6143e-01,\n           -2.6953e-01, -5.7275e-01],\n          [...   -7.3303e-02,  1.9407e-03]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<FlashAttnQKVPackedFuncBackward>) - tensor([[[[-1.4463e+00,  1.9756e+00, -5.2917e-02,  ..., -9.6143e-01,\n           -2.6953e-01, -5.7275e-01],\n          [...2e-02,\n           -4.5319e-02,  1.2062e-02]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<ToCopyBackward0>)).abs
E        +  and   0.00390625 = <built-in method item of Tensor object at 0x70a08dee6a20>()
E        +    where <built-in method item of Tensor object at 0x70a08dee6a20> = tensor(0.0039, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>).item
E        +      where tensor(0.0039, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>) = <built-in method max of Tensor object at 0x70a08dee6980>()
E        +        where <built-in method max of Tensor object at 0x70a08dee6980> = tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...7.6294e-06,\n           3.0518e-05, 2.2888e-05]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<AbsBackward0>).max
E        +          where tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,\n           0.0000e+00, 0.0000e+00],\n          [0.0000...7.6294e-06,\n           3.0518e-05, 2.2888e-05]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<AbsBackward0>) = <built-in method abs of Tensor object at 0x70a093d984f0>()
E        +            where <built-in method abs of Tensor object at 0x70a093d984f0> = (tensor([[[[-1.4463e+00,  1.9756e+00, -5.2917e-02,  ..., -9.6143e-01,\n           -2.6953e-01, -5.7275e-01],\n          [...780e-02,\n           -4.5349e-02,  1.2039e-02]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<ViewBackward0>) - tensor([[[[-1.4463e+00,  1.9756e+00, -5.2917e-02,  ..., -9.6143e-01,\n           -2.6953e-01, -5.7275e-01],\n          [...2e-02,\n           -4.5319e-02,  1.2062e-02]]]], device='cuda:0', dtype=torch.float16,\n       grad_fn=<ToCopyBackward0>)).abs

flash-attention/tests/test_flash_attn.py:704: AssertionError
--------------------------------------------------------------------------------------------- Captured stdout call ----------------------------------------------------------------------------------------------
Actual dropout fraction: 0.15972204506397247
Output max diff: 1.78125
Output mean diff: 0.031402587890625
Pytorch max diff: 0.00390625
Pytorch mean diff: 3.9577484130859375e-05
Attention max diff: 0.69384765625
Attention Pytorch max diff: 0.00048828125
dQ max diff: 1.93359375
dK max diff: 1.8076171875
dV max diff: 1.9775390625
dQKV mean diff: 0.0284881591796875
dQ Pytorch max diff: 0.00390625
dK Pytorch max diff: 0.00390625
dV Pytorch max diff: 0.00390625
dQKV Pytorch mean diff: 4.2378902435302734e-05
============================================================================================ short test summary info ============================================================================================
FAILED flash-attention/tests/test_flash_attn.py::test_flash_attn_qkvpacked[0.17-2048-160-True-True-False-False-dtype0] - AssertionError: assert 1.78125 <= (2 * 0.00390625)

This is the assertion error on A5880.
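For context, the failing check requires that flash-attn's output error against an fp32 reference be at most 2x the error of a plain PyTorch fp16 attention baseline. The following is a simplified sketch of that comparison, not the test itself: dropout is omitted and attention_ref is reduced to plain softmax attention here, whereas the failing test uses dropout_p=0.17 and the reference helper in test_flash_attn.py.

import math
import torch
from flash_attn import flash_attn_qkvpacked_func

torch.manual_seed(0)
# Packed QKV of shape (batch, seqlen, 3, nheads, headdim) in fp16,
# roughly matching the failing case's seqlen=2048 and headdim=160.
qkv = torch.randn(1, 2048, 3, 4, 160, device="cuda", dtype=torch.float16)

def attention_ref(qkv):
    # Plain scaled-dot-product attention used as the reference in this sketch.
    q, k, v = qkv.unbind(dim=2)
    scores = torch.einsum("bshd,bthd->bhst", q, k) / math.sqrt(q.shape[-1])
    return torch.einsum("bhst,bthd->bshd", torch.softmax(scores, dim=-1), v)

out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0)      # flash-attn output (fp16)
out_ref = attention_ref(qkv.float()).to(torch.float16)   # fp32 reference, cast back
out_pt = attention_ref(qkv)                               # fp16 PyTorch baseline

# The test asserts flash-attn's max error is within 2x the fp16 baseline's max error;
# on the affected GPUs the left-hand side blows up to ~1.78 instead.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()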

@tridao
Contributor

tridao commented Jan 11, 2025

Does the newer version of flash-attn (e.g. 2.7.2) have the same error?
