support blockwise fp8 matmul kernel #3267

Merged 3 commits on Feb 12, 2025
148 changes: 148 additions & 0 deletions sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
@@ -0,0 +1,148 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def get_weight_shapes(args):
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    # NOTE(HandH1998): These weight shapes only apply to DeepSeek-V3. Modify them if you tune for a different model.
    # cannot be split by TP
    total = [
        # (512 + 64, 7168), # this weight is not supported by current kernel
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can be split by TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can be split by TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
    # only DeepSeek-V3 is supported
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in models_tps:
        assert model in SUPPORT_MODEL
        for t in total:
            new_t = [t[0], t[1], model]
            weight_shapes.append(new_t)
        for n_t in n_tp:
            new_t = [n_t[0] // tp_size, n_t[1], model]
            weight_shapes.append(new_t)
        for k_t in k_tp:
            new_t = [k_t[0], k_t[1] // tp_size, model]
            weight_shapes.append(new_t)
    return weight_shapes
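# Editorial note (not part of this file): with tp_size == 2, the N-split entries above
# are halved along N and the K-split entries along K, e.g. (24576, 1536) -> (12288, 1536)
# and (7168, 18432) -> (7168, 9216); shapes in `total` are never split.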


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
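# Editorial note (not part of this file): for the group shapes used in benchmark() below,
# e.g. M = 256, N = 4096, K = 7168:
#   scale_shape((M, K), (1, 128))   -> (256, 56)   one scale per 1x128 activation group
#   scale_shape((K, N), (128, 128)) -> (56, 32)    one scale per 128x128 weight block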


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a_fp8.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["deepseek-ai/DeepSeek-V3"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    NK_model_names = get_weight_shapes(args)
    for N, K, model_name in NK_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_blockwise_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
1 change: 1 addition & 0 deletions sgl-kernel/setup.py
@@ -96,6 +96,7 @@ def _get_version():
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
"src/sgl-kernel/csrc/eagle_utils.cu",
2 changes: 2 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -14,6 +14,7 @@
build_tree_kernel_efficient,
custom_dispose,
custom_reduce,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
@@ -44,6 +45,7 @@
"bmm_fp8",
"custom_dispose",
"custom_reduce",
"fp8_blockwise_scaled_mm",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
@@ -0,0 +1,125 @@
// Adapted from
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include <cutlass/gemm/collective/builders/sm90_gmma_builder.inl>
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"


/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  int ScaleGranularityM
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<
      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
                                                                   KernelPtrArrayTmaWarpSpecializedCooperative,
                                                                   KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
                "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                          KernelTmaWarpSpecializedCooperative,
                                                          KernelPtrArrayTmaWarpSpecializedCooperative,
                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative,
      Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};


/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////