Adds utilities for AMD fp8 dtype support, with a follow-up PR to add the option to the configs #235
Conversation
fp8_tensor = to_fp8_no_autograd(
    gradY, gradY_scale, torch.float8_e5m2, ctx.emulate
)
fp8_dtype = torch.float8_e5m2fnuz if IS_AMD else torch.float8_e5m2
should this be configurable in the forward with a reasonable default?
I wonder if it would be better to have a torch.backends.[cuda|hip|mps].supports_dtype(XYZ) API, because I assume XPUs would probably use the fnuz flavor, but say ARM would be using the e5m2 flavor.
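
For illustration, a minimal sketch of what such a capability query could look like, assuming a ROCm build is detected via torch.version.hip. The function name backend_supports_dtype is hypothetical, and the proposed torch.backends API does not exist; this only mirrors the idea in the comment above:

import torch

def backend_supports_dtype(dtype: torch.dtype) -> bool:
    # Hypothetical stand-in for a torch.backends.<device>.supports_dtype API.
    # Only the fp8 flavors discussed in this thread are distinguished.
    is_rocm = torch.version.hip is not None
    if dtype in (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz):
        return is_rocm       # fnuz flavors: AMD (and possibly XPU)
    if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return not is_rocm   # OCP flavors: NVIDIA (and, per the thread, ARM)
    return True              # non-fp8 dtypes: assume supported here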
I think this should be configurable instead of depending on the environment; numerics should be as predictable as possible. It's also valuable to be able to emulate numerics without having the hardware, for debugging.
Okay, so it sounds like we want this to be defined at module construction. It is up to the constructor of the module to ensure that _scaled_mm will work with their module.
I do think that Nikita's backend dtype helper could be a useful PyTorch feature; not sure if that should live here, though.
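
A rough sketch of the "defined at module construction" direction. The class name and the fwd_fp8_dtype/bwd_fp8_dtype/emulate constructor arguments are hypothetical and do not reflect the actual Float8Linear interface in this repo:

import torch
import torch.nn as nn

class ConfigurableFloat8Linear(nn.Linear):
    # Illustrative only: the caller chooses the fp8 formats when building the
    # module, so numerics do not silently change with the environment, and it
    # is the caller's job to pick dtypes that _scaled_mm supports on their
    # hardware (or to set emulate=True when the hardware is not available).
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        fwd_fp8_dtype: torch.dtype = torch.float8_e4m3fn,
        bwd_fp8_dtype: torch.dtype = torch.float8_e5m2,
        emulate: bool = False,
    ):
        super().__init__(in_features, out_features, bias=bias)
        self.fwd_fp8_dtype = fwd_fp8_dtype
        self.bwd_fp8_dtype = bwd_fp8_dtype
        self.emulate = emulate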
The non-emulated version is failing for me on an MI300 machine using this version of PyTorch:

pytorch-triton-rocm==3.0.0+0a22a91d04
torch==2.3.0.dev20240308+rocm6.0

FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape0-False] - AssertionError: -2.7592885494232178 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape1-False] - AssertionError: -3.372152805328369 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[True-LinearType.DYNAMIC-x_shape2-False] - AssertionError: -2.8420748710632324 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape0-False] - AssertionError: -2.7584447860717773 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape1-False] - AssertionError: -2.946033239364624 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DELAYED-x_shape2-False] - AssertionError: -2.756319999694824 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape0-False] - AssertionError: -3.377957820892334 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape1-False] - AssertionError: -3.0644452571868896 is too low
FAILED test/test_base.py::TestFloat8Linear::test_linear_nobias[False-LinearType.DYNAMIC-x_shape2-False] - AssertionError: -3.091813564300537 is too low
@drisspg Is this mergeable? The errors above should have been fixed with this: pytorch/pytorch#125921
@alugorey ahh no worries, let me rebase and whip this PR back into shape
Force-pushed from b1163d1 to e29cc35
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
float8_experimental/float8_utils.py (Outdated)
elif float8_dtype == torch.float8_e4m3fnuz:
    res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2:
    res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2fnuz:
    res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
Please avoid code duplication.
Suggested change:

- elif float8_dtype == torch.float8_e4m3fnuz:
-     res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
- elif float8_dtype == torch.float8_e5m2:
-     res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
- elif float8_dtype == torch.float8_e5m2fnuz:
-     res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+ elif float8_dtype in [torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]:
+     res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
Actually, you don't even need the ifs there; just assert that float8_dtype is indeed an fp8 dtype:

assert float8_dtype.itemsize == 1 and float8_dtype.is_floating_point
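
Combining both suggestions, the scale computation could collapse to something like the sketch below. The function name amax_to_scale and the EPS constant are assumed from the surrounding file and may differ from the actual code:

import torch

EPS = 1e-12  # placeholder; the real module defines its own epsilon

def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Any 1-byte floating-point dtype is an fp8 flavor, and torch.finfo exposes
    # its max representable value, so no per-dtype branches are needed.
    assert float8_dtype.itemsize == 1 and float8_dtype.is_floating_point
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)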
Force-pushed from 7feb581 to 5da5b5c
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
AMD GPUs support a different fp8 dtype than NVIDIA. These dtypes were added to PyTorch, and this PR updates Float8Tensor construction to use the format that matches the architecture.
For a detailed summary see: https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md
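
For context, the arch-dependent selection boils down to something like the sketch below. Detecting a ROCm build via torch.version.hip is an assumption about how IS_AMD could be computed, not necessarily how this PR implements it:

import torch

# Assumption: a ROCm build (torch.version.hip is set) means AMD fp8 flavors.
IS_AMD = torch.version.hip is not None

# OCP formats on NVIDIA, fnuz formats on AMD (see the StableHLO RFC above).
e4m3_dtype = torch.float8_e4m3fnuz if IS_AMD else torch.float8_e4m3fn
e5m2_dtype = torch.float8_e5m2fnuz if IS_AMD else torch.float8_e5m2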