Commit 0cee1ef (parent 455bfe8), showing 12 changed files with 1,574 additions and 0 deletions.
@@ -0,0 +1,175 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    # Helper for int8 quantization (not used in this fp8 benchmark).
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


# Per-model GEMM shapes: each entry is ([K, N], tp_split_dim), where
# tp_split_dim is the dimension sharded under tensor parallelism.
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def scale_shape(shape, group_shape):
    # One scale per quantization group along each dimension.
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
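To make the group-shape bookkeeping concrete, here is a minimal sketch (illustrative, not taken from the diff) of what scale_shape returns for the (1, 128) activation groups and (128, 128) weight blocks used by the benchmark below, assuming M=256, K=4096, N=6144:

# Illustrative only: shapes assumed for demonstration.
a_shape = (256, 4096)                     # activation (M, K)
b_shape = (4096, 6144)                    # weight (K, N), matching b_fp8 after the transpose
print(scale_shape(a_shape, (1, 128)))     # -> (256, 32): one scale per row per 128-wide K group
print(scale_shape(b_shape, (128, 128)))   # -> (32, 48): one scale per 128x128 weight block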
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    # Random fp32 inputs clamped to the fp8 e4m3 range and cast to fp8.
    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

    # Blockwise scales: (1, 128) groups for activations, (128, 128) blocks for weights.
    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
    # Make the scale tensors column-major without changing their logical shape.
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    # Median, 20th- and 80th-percentile timings from triton's do_bench.
    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    # Throughput metric reported on the y-axis (labeled GB/s in the plot).
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a_fp8.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)
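Since both providers are invoked with the same argument order, a quick numerical cross-check can be run before timing. A minimal sketch (illustrative, not part of the commit; it assumes tensors built exactly as in benchmark above and that both kernels return a dense (M, N) tensor of the requested dtype):

# Illustrative cross-check of the two backends on identical inputs.
out_sgl = fp8_blockwise_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16)
out_vllm = vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16)
torch.testing.assert_close(out_sgl, out_vllm, rtol=1e-2, atol=1e-2)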
def prepare_shapes(args):
    # Expand each (model, tp_size) pair into concrete [K, N, model] entries,
    # shrinking the TP-split dimension by the tensor-parallel degree.
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names
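As a rough illustration of the expansion (assumed values, not output captured from the diff), the default Llama-3.1-8B entries with tp sizes 1 and 2 would produce the following; _Args is a hypothetical stand-in for the argparse.Namespace built in __main__:

class _Args:
    # Hypothetical stand-in for the parsed CLI arguments.
    models = ["meta-llama/Llama-3.1-8B-Instruct"]
    tp_sizes = [1, 2]

for K, N, name in prepare_shapes(_Args()):
    print(name, "K =", K, "N =", N)
# tp=1 keeps the listed shapes, e.g. (K=4096, N=6144);
# tp=2 halves the split dimension, e.g. (K=4096, N=3072) and (K=7168, N=4096).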
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_blockwise_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp (125 additions, 0 deletions)
@@ -0,0 +1,125 @@
// Adapted from
// https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp
// Modified from: cutlass/gemm/collective/builders/sm90_gmma_builder.inl
// clang-format off
#pragma once

#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl"

#include "cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp"


/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

// GMMA_TMA_WS_SS (BlockScaled Builders)
template <
  class ElementA,
  class GmemLayoutATag,
  int AlignmentA,
  class ElementB,
  class GmemLayoutBTag,
  int AlignmentB,
  class ElementAccumulator,
  class TileShape_MNK,
  class ClusterShape_MNK,
  class StageCountType,
  int ScaleGranularityM
>
struct CollectiveBuilder<
    arch::Sm90,
    arch::OpClassTensorOp,
    ElementA,
    GmemLayoutATag,
    AlignmentA,
    ElementB,
    GmemLayoutBTag,
    AlignmentB,
    ElementAccumulator,
    TileShape_MNK,
    ClusterShape_MNK,
    StageCountType,
    KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>,
    cute::enable_if_t<
      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
> {
  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>;

  static_assert(is_static<TileShape_MNK>::value);
  static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
  static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
  static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
                "Should meet TMA alignment requirement\n");

  static constexpr bool IsArrayOfPointersGemm = (cute::is_any_of_v<KernelScheduleType,
                                                                   KernelPtrArrayTmaWarpSpecializedCooperative,
                                                                   KernelPtrArrayTmaWarpSpecializedPingpong>);
  static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
  static_assert((!IsFP8Input || !IsArrayOfPointersGemm),
                "KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum is only compatible with FP8 Blocked Scaled version right now.");

  // For fp32 types, map to tf32 MMA value type
  using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
  using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;

  static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
  static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

  static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                          KernelTmaWarpSpecializedCooperative,
                                                          KernelPtrArrayTmaWarpSpecializedCooperative,
                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledSubGroupMAccum<ScaleGranularityM>>;
  using AtomLayoutMNK = cute::conditional_t<IsCooperative,
                                            Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;

  using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<
      ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{}));

  using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
  using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));

  using SmemLayoutAtomA = decltype(detail::ss_smem_selector<
      GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
  using SmemLayoutAtomB = decltype(detail::ss_smem_selector<
      GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());

  static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
  static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingSubGroupMFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM>;

  using SmemCopyAtomA = void;
  using SmemCopyAtomB = void;

  using CollectiveOp = CollectiveMma<
      DispatchPolicy,
      TileShape_MNK,
      ElementA,
      TagToStrideA_t<GmemLayoutATag>,
      ElementB,
      TagToStrideB_t<GmemLayoutBTag>,
      TiledMma,
      GmemTiledCopyA,
      SmemLayoutAtomA,
      SmemCopyAtomA,
      cute::identity,
      GmemTiledCopyB,
      SmemLayoutAtomB,
      SmemCopyAtomB,
      cute::identity
  >;
};


/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////