This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Adds utilities for AMD fp8 dtype support, follow up PR to add option to the configs #235

Closed
wants to merge 2 commits from the amd-support branch

Conversation

drisspg
Contributor

@drisspg drisspg commented Mar 7, 2024

Summary

AMD GPUs support a different fp8 dtype than NVIDIA GPUs. These dtypes were added to PyTorch, and this PR updates Float8Tensor construction to use the format appropriate for the architecture.

For a detailed summary see: https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md
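As a rough illustration (not the exact code in this PR; the _is_rocm helper name is made up for this sketch), picking the fp8 flavor based on whether PyTorch is a ROCm build could look like:

import torch

def _is_rocm() -> bool:
    # ROCm builds of PyTorch report a HIP version string; CUDA builds report None.
    return torch.version.hip is not None

# NVIDIA GPUs use the e4m3fn/e5m2 encodings; AMD MI300 uses the fnuz variants,
# which drop infinities and keep a single NaN encoding (see the RFC linked above).
E4M3_DTYPE = torch.float8_e4m3fnuz if _is_rocm() else torch.float8_e4m3fn
E5M2_DTYPE = torch.float8_e5m2fnuz if _is_rocm() else torch.float8_e5m2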

@facebook-github-bot facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Mar 7, 2024
fp8_tensor = to_fp8_no_autograd(
    gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
)
fp8_dtype = torch.float8_e5m2fnuz if IS_AMD else torch.float8_e5m2

should this be configurable in the forward with a reasonable default?


I wonder if it would be better to have a torch.backends.[cuda|hip|mps].supports_dtype(XYZ) API, because I assume XPUs would probably use the fnuz flavor, but, say, ARM would use the e5m2 flavor.
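Purely as a sketch of the idea (no such torch.backends API exists today; the table and helper below are hypothetical), a per-backend capability query might look like:

import torch

# Hypothetical mapping from backend name to the fp8 flavors it supports.
_FP8_SUPPORT = {
    "cuda": {torch.float8_e4m3fn, torch.float8_e5m2},
    "hip": {torch.float8_e4m3fnuz, torch.float8_e5m2fnuz},
}

def backend_supports_dtype(backend: str, dtype: torch.dtype) -> bool:
    # Stand-in for the proposed torch.backends.<backend>.supports_dtype(...)
    return dtype in _FP8_SUPPORT.get(backend, set())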

Contributor

I think this should be configurable instead of depending on the environment; numerics should be as predictable as possible. It's also valuable to be able to emulate numerics without having the hardware, for debugging.

Contributor Author

Okay, so it sounds like we want this to be defined at module construction. It is up to the constructor of the module to ensure that _scaled_mm will work with their module.
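For example, a minimal sketch of "defined at module construction" (Fp8ConfiguredLinear is a hypothetical wrapper, not this repo's module; the caller is responsible for picking dtypes that _scaled_mm supports on their hardware):

import torch
import torch.nn as nn

class Fp8ConfiguredLinear(nn.Module):
    def __init__(self, in_features, out_features,
                 fwd_dtype=torch.float8_e4m3fn, bwd_dtype=torch.float8_e5m2):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features)
        # The fp8 dtypes are fixed at construction time rather than read from
        # the environment, so numerics stay predictable and emulatable.
        self.fwd_dtype = fwd_dtype
        self.bwd_dtype = bwd_dtype

    def forward(self, x):
        # Real float8 code would cast and scale here; this sketch only records
        # which dtypes were chosen.
        return self.inner(x)

On an MI300 one would construct it with fwd_dtype=torch.float8_e4m3fnuz and bwd_dtype=torch.float8_e5m2fnuz.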

I do think that Nikita's backend dtype helper could be a useful PyTorch feature; not sure if it should live here, though.

@drisspg
Contributor Author

drisspg commented Mar 11, 2024

The non-emulated version is failing for me on an MI300 machine using this version of PyTorch:

pytorch-triton-rocm==3.0.0+0a22a91d04
torch==2.3.0.dev20240308+rocm6.0
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape0-False] - AssertionError: -2.7592885494232178 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape1-False] - AssertionError: -3.372152805328369 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape2-False] - AssertionError: -2.8420748710632324 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape0-False] - AssertionError: -2.7584447860717773 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape1-False] - AssertionError: -2.946033239364624 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape2-False] - AssertionError: -2.756319999694824 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape0-False] - AssertionError: -3.377957820892334 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape1-False] - AssertionError: -3.0644452571868896 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape2-False] - AssertionError: -3.091813564300537 is too low

@alugorey
Contributor

@drisspg Is this mergeable? The errors above should have been fixed with this: pytorch/pytorch#125921

@drisspg
Contributor Author

drisspg commented May 23, 2024

@alugorey I was actually thinking that something similar would be landed here: #248

@alugorey
Contributor

@alugorey I was actually thinking that something similar would be landed here: #248

Ah okay, I was working on top of this PR. Do you want me to pull your changes into my PR and abandon this one? Or do you want me to point my PR at this branch?

@drisspg
Contributor Author

drisspg commented May 24, 2024

@alugorey ahh no worries, let me rebase and whip this PR back into shape

@drisspg drisspg force-pushed the amd-support branch 2 times, most recently from b1163d1 to e29cc35 Compare May 28, 2024 17:52
@drisspg drisspg changed the title Add AMD fp8 type support Adds utilities for AMD fp8 dtype support, follow up PR to add option to the configs Jun 1, 2024
@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg drisspg requested a review from vkuzo June 2, 2024 17:59
Comment on lines 43 to 48
elif float8_dtype == torch.float8_e4m3fnuz:
res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2:
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2fnuz:
res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)

Please avoid code duplication

Suggested change
elif float8_dtype == torch.float8_e4m3fnuz:
res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2:
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2fnuz:
res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype in [torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]:
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)


Actually, you don't even need the ifs there, just assert that float8_dtype is indeed one of the fp8 dtypes:

  assert float8_dtype.itemsize == 1 and float8_dtype.is_floating_point
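Putting the two suggestions together, a deduplicated version of the scale computation could look roughly like this (the function name and exact signature are illustrative; EPS follows the snippet above):

import torch

EPS = 1e-12  # placeholder value; the library defines its own epsilon

def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # One branch covers every fp8 flavor: the per-dtype MAX_POS constants
    # collapse into torch.finfo(float8_dtype).max.
    assert float8_dtype.itemsize == 1 and float8_dtype.is_floating_point
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)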

@drisspg drisspg force-pushed the amd-support branch 3 times, most recently from 7feb581 to 5da5b5c Compare June 4, 2024 17:18
@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@drisspg merged this pull request in 5fc07fc.
