support blockwise fp8 matmul kernel
yizhang2077 committed Feb 3, 2025
1 parent 455bfe8 commit 0cee1ef
Showing 12 changed files with 1,574 additions and 0 deletions.
175 changes: 175 additions & 0 deletions sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
@@ -0,0 +1,175 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


WEIGHT_SHAPES = {
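# Each entry is ([K, N], tp_split_dim): a GEMM weight shape for the model, plus the
# index of the dimension that prepare_shapes() divides by the tensor-parallel size.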
"meta-llama/Llama-3.1-8B-Instruct": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-3.3-70B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
"mistralai/Mistral-Large-Instruct-2407": [
([12288, 14336], 1),
([12288, 12288], 0),
([12288, 57344], 1),
([28672, 12288], 0),
],
"Qwen/Qwen2.5-7B-Instruct": [
([3584, 4608], 1),
([3584, 3584], 0),
([3584, 37888], 1),
([18944, 3584], 0),
],
"Qwen/Qwen2.5-32B-Instruct": [
([5120, 7168], 1),
([5120, 5120], 0),
([5120, 55296], 1),
([27648, 5120], 0),
],
"Qwen/Qwen2.5-72B-Instruct": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 59136], 1),
([29568, 8192], 0),
],
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
([2048, 3072], 1),
([2048, 4096], 1),
([2048, 2048], 0),
([2048, 576], 0),
([2048, 21888], 1),
([10944, 2048], 0),
([2048, 2816], 1),
([1408, 2048], 0),
],
}


def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)


def scale_shape(shape, group_shape):
assert len(shape) == len(group_shape)
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
x_log=False,
line_arg="provider",
line_vals=["vllm", "sgl-kernel"],
line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
styles=[("blue", "-"), ("orange", "-")],
ylabel="GB/s",
plot_name="fp8 blockwise scaled matmul",
args={},
)
)
def benchmark(batch_size, provider, N, K):
M = batch_size
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min

a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

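# Blockwise scaling: one scale per 1x128 block of A and per 128x128 block of B.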
scale_a_group_shape = (1, 128)
scale_b_group_shape = (128, 128)
scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
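# .t().contiguous().t() gives the scale tensors a column-major layout without changing values.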
scale_a = scale_a.t().contiguous().t()
scale_b = scale_b.t().contiguous().t()

quantiles = [0.5, 0.2, 0.8]
if provider == "sgl-kernel":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fp8_blockwise_scaled_mm(
a_fp8, b_fp8, scale_a, scale_b, torch.float16
),
quantiles=quantiles,
)
if provider == "vllm":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
gbps = (
lambda ms: (
(2 * M * N * K - M * N) * a_fp8.element_size()
+ (3 * M * N) * scale_a.element_size()
)
* 1e-9
/ (ms * 1e-3)
)
return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
KN_model_names = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
assert model in WEIGHT_SHAPES
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KN.append(model)
KN_model_names.append(KN)
return KN_model_names


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--models",
nargs="+",
type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"],
help="List of models to benchmark",
)
parser.add_argument(
"--tp-sizes",
nargs="+",
type=int,
default=[1],
help="List of tensor parallel sizes",
)
args = parser.parse_args()

KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_fp8_blockwise_res",
N=N,
K=K,
)

print("Benchmark finished!")
1 change: 1 addition & 0 deletions sgl-kernel/setup.py
@@ -82,6 +82,7 @@ def _get_version():
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/fp8_blockwise_gemm_kernel.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu",
"3rdparty/flashinfer/csrc/activation.cu",
2 changes: 2 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -3,6 +3,7 @@
bmm_fp8,
custom_dispose,
custom_reduce,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
fused_add_rmsnorm,
gelu_and_mul,
@@ -29,6 +30,7 @@
"bmm_fp8",
"custom_dispose",
"custom_reduce",
"fp8_blockwise_scaled_mm",
"fp8_scaled_mm",
"fused_add_rmsnorm",
"gelu_and_mul",
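For reference, a minimal sketch of how the newly exported kernel can be called, mirroring the benchmark above; the sizes here are illustrative and assumed to satisfy the kernel's alignment requirements.

import torch
from sgl_kernel import fp8_blockwise_scaled_mm

M, N, K = 16, 4096, 4096  # illustrative sizes
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)       # (M, K) activations
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()   # (K, N) weights, column-major
# One float32 scale per 1x128 block of a and per 128x128 block of b, laid out column-major
# as in the benchmark.
scale_a = torch.rand(M, K // 128, device="cuda", dtype=torch.float32).t().contiguous().t()
scale_b = torch.rand(K // 128, N // 128, device="cuda", dtype=torch.float32).t().contiguous().t()
out = fp8_blockwise_scaled_mm(a, b, scale_a, scale_b, torch.float16)  # (M, N) float16 result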
@@ -0,0 +1,125 @@
// Adapted from
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"

#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"


/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
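// Specialization for the FP8 blockwise-scaled schedule; ScaleGranularityM sets the
// scaling granularity along the M dimension of the accumulator.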
template <
class ElementA,
class GmemLayoutATag,
int AlignmentA,
class ElementB,
class GmemLayoutBTag,
int AlignmentB,
class ElementAccumulator,
class TileShape_MNK,
class ClusterShape_MNK,
class StageCountType,
int ScaleGranularityM
>
struct CollectiveBuilder<
arch::Sm90,
arch::OpClassTensorOp,
ElementA,
GmemLayoutATag,
AlignmentA,
ElementB,
GmemLayoutBTag,
AlignmentB,
ElementAccumulator,
TileShape_MNK,
ClusterShape_MNK,
StageCountType,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
cute::enable_if_t<
not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");

static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedPingpong>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
"KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

// For fp32 types, map to tf32 MMA value type
using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperative,
KernelPtrArrayTmaWarpSpecializedCooperative,
KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

using SmemCopyAtomA = void;
using SmemCopyAtomB = void;

using CollectiveOp = CollectiveMma<
DispatchPolicy,
TileShape_MNK,
ElementA,
TagToStrideA_t<GmemLayoutATag>,
ElementB,
TagToStrideB_t<GmemLayoutBTag>,
TiledMma,
GmemTiledCopyA,
SmemLayoutAtomA,
SmemCopyAtomA,
cute::identity,
GmemTiledCopyB,
SmemLayoutAtomB,
SmemCopyAtomB,
cute::identity
>;
};


/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////