From 7accfa5a604944d1667b31e76e39491d95a63a64 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Sat, 22 Jun 2024 17:20:41 -0700
Subject: [PATCH] bigger sweep

---
 benchmarks/bench_padding.py | 59 ++++++++++++++++++++++---------------
 test/test_base.py           |  2 +-
 2 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/benchmarks/bench_padding.py b/benchmarks/bench_padding.py
index 3a10b8e..301e4b6 100644
--- a/benchmarks/bench_padding.py
+++ b/benchmarks/bench_padding.py
@@ -4,9 +4,11 @@
 import fire

 import torch
-import torch.utils.benchmark as benchmark
+from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
 from float8_experimental.float8_utils import pad_tensor_for_matmul
 from tabulate import tabulate
+from torch._inductor.utils import do_bench_using_profiling
+from tqdm import tqdm

 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
@@ -26,14 +28,9 @@


 def benchmark_fn_in_usec(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(4):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-    measurement = t0.blocked_autorange()
-    return measurement.mean * 1e6
+    no_args = lambda: f(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3


 def get_tops_info(tops, time, peak_tops):
@@ -44,23 +41,27 @@ def get_tops_info(tops, time, peak_tops):


 def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
-    A_fp8 = A.to(fp8_dtype)
-    B_fp8 = B.to(fp8_dtype).t()  # view
-
     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

-    A_pad = pad_tensor_for_matmul(A_fp8, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t()  # mem copy
+    a_config = ScaledMMConfig(
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+    )
+    b_config = ScaledMMConfig(
+        emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
+    )

-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
+    b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)
+
+    return a_fp8 @ b_fp8


 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
+    # Breaks with compile due to trying to pad on fp8 dtype
+    # return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
     A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
-    B_pad = pad_tensor_for_matmul(B, dims=[0, 1])  # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=0)  # mem copy

     scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
     scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
@@ -70,9 +71,9 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):

     B_pad = B_pad.t().contiguous().t()  # mem copy

-    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
-        : A.shape[0], : B.shape[1]
-    ]
+    return torch._scaled_mm(
+        A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
+    )


 def do_hp_matmul(A, B):
@@ -92,7 +93,18 @@ def __iter__(self):


 def gen_configs():
-    shapes = [(8192, 2500, 5000), (64, 255, 4096)]
+    shapes = [
+        (8193, 2501, 5008),
+        (65, 253, 4096),
+        (1023, 1029, 2512),
+        (4095, 511, 10000),
+        (2047, 3073, 8192),
+        (511, 769, 7504),
+        (127, 4097, 12288),
+        (32769, 15, 15024),
+        (9217, 8191, 20480),
+        (16385, 1025, 25008),
+    ]
     output_dtype = torch.bfloat16
     fp8_dtype = torch.float8_e4m3fn
     return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
@@ -112,7 +124,8 @@ def run(compile: bool = False, n_limit: Optional[int] = None):
         "Ref % Peak",
         "FP8 % Peak",
     ]
-    for experiment in experiments:
+
+    for experiment in tqdm(experiments):
         M, K, N, output_dtype, fp8_dtype = experiment

         tops = 2 * M * N * K

diff --git a/test/test_base.py b/test/test_base.py
index 14e3241..b688ccb 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -314,7 +314,7 @@ class TestScaledMM:
         "base_dtype", [torch.float16, torch.bfloat16, torch.float32]
     )
     @pytest.mark.parametrize("use_fast_accum", [True, False])
-    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum, padded):
+    def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
         input_dtype = e4m3_dtype
         output_dtype = base_dtype
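Note on the pad-first path benchmarked above: do_fp8_pad_first_matmul relies on pad_tensor_for_matmul rounding the requested matmul dimensions up to a multiple of 16 (via an explicit memory copy) so that torch._scaled_mm will accept the fp8 operands, while the Float8Tensor path pushes the same work into the mm op with pad_inner_dim=True. The following is a minimal sketch of that padding idea only; the helper name and behavior are assumptions for illustration, not the float8_experimental implementation.

# Illustrative sketch: zero-pad the requested dims of a tensor up to the next
# multiple of 16, which is the alignment the fp8 _scaled_mm path needs.
# Hypothetical helper, not pad_tensor_for_matmul itself.
import torch
import torch.nn.functional as F


def pad_dims_to_multiple_of_16(t: torch.Tensor, dims=(0, 1)) -> torch.Tensor:
    if isinstance(dims, int):
        dims = (dims,)
    # F.pad expects (left, right) pairs starting from the last dim, so build a
    # zero pad list and fill in only the right-hand padding for each chosen dim.
    pad = [0, 0] * t.dim()
    for d in dims:
        pad[2 * (t.dim() - 1 - d) + 1] = (-t.shape[d]) % 16
    return F.pad(t, pad)


if __name__ == "__main__":
    A = torch.randn(65, 253)                                   # neither dim is 16-aligned
    print(pad_dims_to_multiple_of_16(A, dims=1).shape)         # torch.Size([65, 256])
    print(pad_dims_to_multiple_of_16(A, dims=(0, 1)).shape)    # torch.Size([80, 256])

The extra copies implied by this padding are exactly what the "# mem copy" comments mark, and the widened shape sweep above compares their cost against the pad_inner_dim=True Float8Tensor route.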