Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
bigger sweep
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Jun 24, 2024
1 parent 017e858 commit 7accfa5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 24 deletions.
59 changes: 36 additions & 23 deletions benchmarks/bench_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import fire

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_utils import pad_tensor_for_matmul
from tabulate import tabulate
from torch._inductor.utils import do_bench_using_profiling
from tqdm import tqdm

# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N
Expand All @@ -26,14 +28,9 @@


def benchmark_fn_in_usec(f, *args, **kwargs):
# Manual warmup
for _ in range(4):
f(*args, **kwargs)
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
)
measurement = t0.blocked_autorange()
return measurement.mean * 1e6
no_args = lambda: f(*args, **kwargs)
time = do_bench_using_profiling(no_args)
return time * 1e3


def get_tops_info(tops, time, peak_tops):
Expand All @@ -44,23 +41,27 @@ def get_tops_info(tops, time, peak_tops):


def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
A_fp8 = A.to(fp8_dtype)
B_fp8 = B.to(fp8_dtype).t() # view

scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

A_pad = pad_tensor_for_matmul(A_fp8, dims=1) # mem copy
B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t() # mem copy
a_config = ScaledMMConfig(
emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
)
b_config = ScaledMMConfig(
emulate=False, use_fast_accum=True, fp8_output=True, pad_inner_dim=True
)

return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
: A.shape[0], : B.shape[1]
]
a_fp8 = Float8Tensor.to_float8(A, scale_a, fp8_dtype, mm_config=a_config)
b_fp8 = Float8Tensor.to_float8(B, scale_b, fp8_dtype, mm_config=b_config)

return a_fp8 @ b_fp8


def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
# Breaks with compile due to trying to pad on fp8 dtype
# return do_fp8_matmul(A, B, fp8_dtype, out_dtype)
A_pad = pad_tensor_for_matmul(A, dims=1) # mem copy
B_pad = pad_tensor_for_matmul(B, dims=[0, 1]) # mem copy
B_pad = pad_tensor_for_matmul(B, dims=0) # mem copy

scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)
Expand All @@ -70,9 +71,9 @@ def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):

B_pad = B_pad.t().contiguous().t() # mem copy

return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
: A.shape[0], : B.shape[1]
]
return torch._scaled_mm(
A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype, use_fast_accum=True
)


def do_hp_matmul(A, B):
Expand All @@ -92,7 +93,18 @@ def __iter__(self):


def gen_configs():
shapes = [(8192, 2500, 5000), (64, 255, 4096)]
shapes = shapes = [
(8193, 2501, 5008),
(65, 253, 4096),
(1023, 1029, 2512),
(4095, 511, 10000),
(2047, 3073, 8192),
(511, 769, 7504),
(127, 4097, 12288),
(32769, 15, 15024),
(9217, 8191, 20480),
(16385, 1025, 25008),
]
output_dtype = torch.bfloat16
fp8_dtype = torch.float8_e4m3fn
return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]
Expand All @@ -112,7 +124,8 @@ def run(compile: bool = False, n_limit: Optional[int] = None):
"Ref % Peak",
"FP8 % Peak",
]
for experiment in experiments:

for experiment in tqdm(experiments):
M, K, N, output_dtype, fp8_dtype = experiment
tops = 2 * M * N * K

Expand Down
2 changes: 1 addition & 1 deletion test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class TestScaledMM:
"base_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("use_fast_accum", [True, False])
def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum, padded):
def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
torch.manual_seed(42)
input_dtype = e4m3_dtype
output_dtype = base_dtype
Expand Down

0 comments on commit 7accfa5

Please sign in to comment.