This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

add test
drisspg committed Jun 23, 2024
1 parent 53921b2 commit ee5d4f9
Showing 3 changed files with 61 additions and 2 deletions.
6 changes: 6 additions & 0 deletions float8_experimental/float8_ops.py
@@ -123,6 +123,12 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     b_data = b._data
 
     if a._mm_config.pad_inner_dim:
+        assert (
+            b._mm_config.pad_inner_dim
+        ), "Both mm configs must have pad_inner_dim set to True"
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
         b_data = pad_tensor_for_matmul(b_data, dims=0)
 
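
The assertions above guard the padding path that follows: both operands must opt in, and their inner dims must agree before pad_tensor_for_matmul pads dim 1 of a and dim 0 of b up to the 16-element alignment the scaled-mm kernel enforces (see the error message quoted in the new test below). As a rough illustration of why this is safe — a minimal sketch only, using a hypothetical helper rather than the library's pad_tensor_for_matmul — zero-filling the shared inner dimension leaves the matmul result unchanged:

import torch

def pad_dim_to_multiple_of_16(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical helper: round `dim` up to the next multiple of 16 and zero-pad.
    size = t.size(dim)
    padded_size = (size + 15) // 16 * 16
    if padded_size == size:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = padded_size - size
    return torch.cat([t, t.new_zeros(pad_shape)], dim=dim)

a = torch.randn(16, 41)
b = torch.randn(41, 128)
# Pad a's columns and b's rows; the zero blocks contribute nothing to the product.
a_pad = pad_dim_to_multiple_of_16(a, dim=1)
b_pad = pad_dim_to_multiple_of_16(b, dim=0)
torch.testing.assert_close(a @ b, a_pad @ b_pad)
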
1 change: 0 additions & 1 deletion float8_experimental/float8_python_api.py
@@ -9,7 +9,6 @@
 to simplify the product code.
 """
 
-
 from typing import Optional
 
 import float8_experimental.float8_aten_api # noqa
56 changes: 55 additions & 1 deletion test/test_base.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import itertools
 import random
+import re
 import unittest
 import warnings

@@ -313,7 +314,7 @@ class TestScaledMM:
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_fast_accum", [True, False])
-    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
+    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum, padded):
         torch.manual_seed(42)
         input_dtype = e4m3_dtype
         output_dtype = base_dtype
@@ -387,6 +388,59 @@ def test_merge_configs(self):
         assert c.use_fast_accum is True
         assert c.fp8_output is False
 
+    @unittest.skipIf(
+        not is_H100,
+        "CUDA not available",
+    )
+    @pytest.mark.parametrize(
+        "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_fast_accum", [True, False])
+    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+        torch.manual_seed(42)
+        input_dtype = torch.float8_e4m3fn
+        compare_type = torch.float32
+
+        a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
+        b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
+
+        a_scale = tensor_to_scale(a, input_dtype).float()
+        b_scale = tensor_to_scale(b, input_dtype).float()
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+
+        with pytest.raises(
+            RuntimeError,
+            match=re.escape(
+                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+            ),
+        ):
+            a_fp8 @ b_fp8
+
+        pad_config = ScaledMMConfig(False, use_fast_accum, False, True)
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)
+        out_padded = a_fp8 @ b_fp8
+        out_padded.to(compare_type)
+
+        emulated_conifg = ScaledMMConfig(True, use_fast_accum, False, False)
+        a_fp8 = Float8Tensor.to_float8(
+            a, a_scale, input_dtype, mm_config=emulated_conifg
+        )
+        b_fp8 = Float8Tensor.to_float8(
+            b, b_scale, input_dtype, mm_config=emulated_conifg
+        )
+        out_emualted = a_fp8 @ b_fp8
+        out_emualted.to(compare_type)
+
+        if base_dtype in {torch.bfloat16, torch.float16}:
+            atol, rtol = 7e-2, 7e-2
+        else:
+            atol, rtol = 2e-3, 2e-3
+        torch.testing.assert_close(out_padded, out_emualted, atol=atol, rtol=rtol)
+
 
 class TestNumerics:
     @pytest.mark.parametrize(
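
A note for readers of the new test: the two ScaledMMConfig instances are built positionally, and from how pad_config and the emulated config are named and used, the fields appear to map to (emulate, use_fast_accum, fp8_output, pad_inner_dim). That reading is an assumption rather than the library's definition — the sketch below merely restates the test's two configs through a hypothetical named-field mirror:

from collections import namedtuple

# Hypothetical mirror of the config, matching how the test passes its booleans:
# ScaledMMConfig(emulate, use_fast_accum, fp8_output, pad_inner_dim)
ScaledMMConfigSketch = namedtuple(
    "ScaledMMConfigSketch",
    ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"],
    defaults=[False, False, False, False],
)

# The two configs from test_pad_inner_dim, spelled out with keywords:
pad_config = ScaledMMConfigSketch(use_fast_accum=True, pad_inner_dim=True)
emulated_config = ScaledMMConfigSketch(emulate=True, use_fast_accum=True)
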
