From b4473ede90a5688c4be4b5ff6535d93e463cfbd0 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Sun, 16 Jun 2024 15:42:05 -0700
Subject: [PATCH] add test

---
 float8_experimental/float8_ops.py |  6 ++++++
 test/test_base.py                 | 50 ++++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+)

diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index 8e94dfc..5341600 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -123,6 +123,12 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
     b_data = b._data
 
     if a._mm_config.pad_inner_dim:
+        assert (
+            b._mm_config.pad_inner_dim
+        ), "Both mm configs must have pad_inner_dim set to True"
+        assert a._data.size(1) == b._data.size(
+            0
+        ), f"Inner dims must match for mm, got {a._data.size(1)} and {b._data.size(0)}"
         a_data = pad_tensor_for_matmul(a_data, dims=1)
         b_data = pad_tensor_for_matmul(b_data, dims=0)
 
diff --git a/test/test_base.py b/test/test_base.py
index 6e7a34c..bc6e92c 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import itertools
 import random
+import re
 import unittest
 import warnings
 
@@ -393,6 +394,55 @@ def test_merge_configs(self):
         assert c.use_fast_accum is True
         assert c.fp8_output is False
 
+    @pytest.mark.parametrize(
+        "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    )
+    @pytest.mark.parametrize("use_fast_accum", [True, False])
+    def test_pad_inner_dim(self, base_dtype, use_fast_accum):
+        torch.manual_seed(42)
+        input_dtype = torch.float8_e4m3fn
+        compare_type = torch.float32
+
+        a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
+        b = torch.randn(41, 128, device="cuda", dtype=base_dtype)
+
+        a_scale = tensor_to_scale(a, input_dtype).float()
+        b_scale = tensor_to_scale(b, input_dtype).float()
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
+
+        with pytest.raises(
+            RuntimeError,
+            match=re.escape(
+                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
+            ),
+        ):
+            a_fp8 @ b_fp8
+
+        pad_config = ScaledMMConfig(False, use_fast_accum, False, True)
+
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype, mm_config=pad_config)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype, mm_config=pad_config)
+        out_padded = a_fp8 @ b_fp8
+        out_padded = out_padded.to(compare_type)
+
+        emulated_config = ScaledMMConfig(True, use_fast_accum, False, False)
+        a_fp8 = Float8Tensor.to_float8(
+            a, a_scale, input_dtype, mm_config=emulated_config
+        )
+        b_fp8 = Float8Tensor.to_float8(
+            b, b_scale, input_dtype, mm_config=emulated_config
+        )
+        out_emulated = a_fp8 @ b_fp8
+        out_emulated = out_emulated.to(compare_type)
+
+        if base_dtype in {torch.bfloat16, torch.float16}:
+            atol, rtol = 7e-2, 7e-2
+        else:
+            atol, rtol = 2e-3, 2e-3
+        torch.testing.assert_close(out_padded, out_emulated, atol=atol, rtol=rtol)
+
 
 class TestNumerics:
     @pytest.mark.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
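
Reviewer note (not part of the patch): the new asserts guard the path where
pad_inner_dim zero-pads both operands so that torch._scaled_mm's
divisible-by-16 shape requirement is met. As a rough illustration of why
zero-padding the inner dimension leaves the matmul result unchanged, here is
a minimal standalone sketch; pad_dim_to_multiple_of_16 is a hypothetical
stand-in for the repo's pad_tensor_for_matmul, not its actual implementation,
and fp8 quantization is left out:

import torch

def pad_dim_to_multiple_of_16(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical stand-in for pad_tensor_for_matmul: zero-pad `dim`
    # up to the next multiple of 16.
    pad = (16 - t.size(dim) % 16) % 16
    if pad == 0:
        return t
    pad_shape = list(t.shape)
    pad_shape[dim] = pad
    return torch.cat([t, t.new_zeros(pad_shape)], dim=dim)

# Inner dim 41 is padded to 48 on both operands: the extra zero columns
# of `a` multiply the extra zero rows of `b` and contribute nothing,
# so the padded product matches the unpadded one.
a = torch.randn(16, 41)
b = torch.randn(41, 128)
out = pad_dim_to_multiple_of_16(a, 1) @ pad_dim_to_multiple_of_16(b, 0)
torch.testing.assert_close(out, a @ b)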