Customizable SM90 prefill kernels. (#704)
hyhieu authored Dec 29, 2024
1 parent 1312409 commit 4ba91c0
Showing 15 changed files with 2,011 additions and 596 deletions.
84 changes: 47 additions & 37 deletions aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -18,12 +18,7 @@
import sys
from pathlib import Path

from .literal_map import (
dtype_literal,
idtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)
from .literal_map import dtype_literal, idtype_literal, mask_mode_literal


def get_cu_file_str(
@@ -36,40 +31,56 @@ def get_cu_file_str(
dtype_out,
idtype,
):
pos_encoding_mode = None
allow_fp16_qk_reduction = None

def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
return """
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
{attention_variant}>
(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant}>
(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
{attention_variant}>
(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant}>
(Params& params, cudaStream_t stream);
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
]
head_dim=head_dim,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)

dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
content = f""" // batch_paged_prefill_sm90 template inst
#include <flashinfer/attention/hopper/params.cuh>
#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
@@ -82,9 +93,9 @@ def get_insts(attention_variant):
using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;
{get_insts("LogitsSoftCap")}
{get_insts("LogitsSoftCap<Params>")}
{get_insts("StandardAttention")}
{get_insts("StandardAttention<Params>")}
}}"""
return content
@@ -93,12 +104,11 @@ def get_insts(attention_variant):
if __name__ == "__main__":
pattern = (
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_"
r"dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
compiled_pattern = re.compile(pattern)
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)

with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
path.write_text(get_cu_file_str(*match.groups()))
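For orientation, the sketch below shows roughly what one generated batch-paged instantiation file looks like after this change. The concrete values (head_dim 128, causal mask, f16 mapped to half, int32 indices) and the namespace wrapper are illustrative assumptions, not copied from this diff.

// Sketch of a generated batch_paged_prefill_*_sm90.cu file (illustrative values only).
#include <flashinfer/attention/hopper/params.cuh>
#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>

namespace flashinfer {

using DTypeQ = cutlass_dtype_t<half>;  // assumed expansion of the "f16" dtype literal
using DTypeKV = DTypeQ;
using DTypeO = DTypeQ;
using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, int32_t>;

// One of the four instantiations emitted per attention variant
// (USE_SWA x SAME_SCHEDULE_FOR_ALL_HEADS), repeated for
// LogitsSoftCap<Params> and StandardAttention<Params>:
template cudaError_t BatchPrefillWithPagedKVCacheDispatched
    <128, MaskMode::kCausal, /*USE_SWA=*/true,
     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, LogitsSoftCap<Params>>
    (Params& params, cudaStream_t stream);

}  // namespace flashinfer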
57 changes: 33 additions & 24 deletions aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
@@ -38,39 +38,48 @@ def get_cu_file_str(
):

def get_insts(attention_variant):
return "\n".join(
[
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
return """
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
{attention_variant}>(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/true,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant}>(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
{attention_variant}>(Params& params, cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched
<{head_dim},
{mask_mode},
/*USE_SWA=*/false,
/*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
{attention_variant}>(Params& params, cudaStream_t stream);
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
attention_variant=attention_variant,
)
]
)

dtype_q = dtype_literal[dtype_q]
dtype_kv = dtype_literal[dtype_kv]
dtype_out = dtype_literal[dtype_out]
idtype = idtype_literal[idtype]

content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
content = f""" // batch_ragged_prefill_sm90 template inst
#include <flashinfer/attention/hopper/params.cuh>
#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
@@ -83,9 +92,9 @@ def get_insts(attention_variant):
using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;
{get_insts("LogitsSoftCap")}
{get_insts("LogitsSoftCap<Params>")}
{get_insts("StandardAttention")}
{get_insts("StandardAttention<Params>")}
}}
"""
38 changes: 20 additions & 18 deletions aot_build_utils/generate_single_prefill_sm90_inst.py
@@ -30,7 +30,9 @@ def get_cu_file_str(
dtype_kv,
dtype_out,
):
content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh>
content = """ // single_prefill_sm90 template inst
#include <flashinfer/attention/hopper/params.cuh>
#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
@@ -42,31 +44,32 @@ def get_cu_file_str(
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched
<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap<Params>>
(Params& params, cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched
<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap<Params>>
(Params& params, cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched
<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention<Params>>
(Params& params, cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched
<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention<Params>>
(Params& params, cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>(
Params& params,
cudaStream_t stream);
}}
""".format(
head_dim=head_dim,
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
allow_fp16_qk_reduction=allow_fp16_qk_reduction,
# pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
# allow_fp16_qk_reduction=allow_fp16_qk_reduction,
mask_mode=mask_mode_literal[int(mask_mode)],
dtype_q=dtype_literal[dtype_q],
dtype_kv=dtype_literal[dtype_kv],
dtype_out=dtype_literal[dtype_out],
use_custom_mask="true" if int(mask_mode) == 2 else "false",
# use_custom_mask="true" if int(mask_mode) == 2 else "false",
)
return content

@@ -81,5 +84,4 @@ def get_cu_file_str(
path = Path(sys.argv[1])
fname = path.name
match = compiled_pattern.match(fname)
with open(path, "w") as f:
f.write(get_cu_file_str(*match.groups()))
path.write_text(get_cu_file_str(*match.groups()))
26 changes: 14 additions & 12 deletions csrc/batch_prefill_sm90.cu
@@ -29,16 +29,14 @@
namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT& params,
cudaStream_t stream);

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename DTypeQ,
typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(
BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>& params, cudaStream_t stream);
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT& params,
cudaStream_t stream);

} // namespace flashinfer

@@ -110,7 +108,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
using DTypeO = DTypeQ;
using IdType = int32_t;

BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType> params;
using BatchPrefillRaggedParams = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
BatchPrefillRaggedParams params;

params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
@@ -160,7 +159,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run(
return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<BatchPrefillRaggedParams>,
StandardAttention<BatchPrefillRaggedParams>>;
cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched<
HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
params, stream);
@@ -220,7 +220,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
using DTypeO = DTypeQ;
using IdType = int32_t;

BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType> params;
using BatchPrefillPagedParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
BatchPrefillPagedParams params;

params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(paged_k_cache.data_ptr());
@@ -272,7 +273,8 @@ void BatchPrefillWithPagedKVCacheSM90Run(
return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<BatchPrefillPagedParams>,
StandardAttention<BatchPrefillPagedParams>>;
cudaError_t status = BatchPrefillWithPagedKVCacheDispatched<
HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(
params, stream);
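The forward declarations above carry the core interface change: the SM90 dispatchers are no longer templated on DTypeQ/DTypeKV/DTypeO/IdType and instead take the parameter struct published by the attention variant itself. A minimal, self-contained sketch of that idiom follows; the names with a Sketch suffix and their bodies are stand-ins, not FlashInfer's real definitions.

// Minimal sketch of the AttentionVariant::ParamsT idiom (stand-in types only).
#include <cstdint>
#include <cuda_runtime.h>

template <typename DTypeQ, typename DTypeKV, typename DTypeO, typename IdType>
struct PagedParamsSketch {
  DTypeQ* q_ptr;
  DTypeKV* k_ptr;
  DTypeO* o_ptr;
  IdType* kv_indices;
};

// Each variant is a class template over its parameter struct and re-exports
// it as ParamsT, so the dispatcher can name the struct without extra
// dtype/idtype template parameters.
template <typename Params>
struct StandardAttentionSketch {
  using ParamsT = Params;
};

template <uint32_t HEAD_DIM, typename AttentionVariant>
cudaError_t DispatchSketch(typename AttentionVariant::ParamsT& params, cudaStream_t stream) {
  // A real dispatcher would configure and launch the SM90 kernel here.
  (void)params;
  (void)stream;
  return cudaSuccess;
}

int main() {
  using Params = PagedParamsSketch<float, float, float, int32_t>;
  Params params{};
  // ParamsT is a non-deduced context, so the variant is spelled out explicitly,
  // mirroring the explicit template arguments in the dispatch calls above.
  return DispatchSketch<128u, StandardAttentionSketch<Params>>(params, nullptr) == cudaSuccess ? 0 : 1;
}

This is also what lets the wrappers above select between LogitsSoftCap<Params> and StandardAttention<Params> with a single std::conditional_t and forward only the AttentionVariant type to the dispatcher.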
10 changes: 6 additions & 4 deletions csrc/single_prefill_sm90.cu
@@ -28,8 +28,8 @@
namespace flashinfer {

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLINDING_WINDOW,
typename AttentionVariant, typename DTypeQ, typename DTypeKV, typename DTypeO>
cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>& params,
typename AttentionVariant>
cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT& params,
cudaStream_t stream);

} // namespace flashinfer
@@ -59,7 +59,8 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
using DTypeQ = cutlass_dtype_t<q_type>;
using DTypeKV = DTypeQ;
using DTypeO = DTypeQ;
SinglePrefillParams<DTypeQ, DTypeKV, DTypeO> params;
using SinglePrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
SinglePrefillParams params;
params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
params.v_ptr = static_cast<DTypeKV*>(v.data_ptr());
@@ -96,7 +97,8 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q
return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] {
return DISPATCH_BOOL(use_swa, USE_SWA, [&] {
using AttentionVariant =
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap, StandardAttention>;
std::conditional_t<USE_LOGITS_SOFT_CAP, LogitsSoftCap<SinglePrefillParams>,
StandardAttention<SinglePrefillParams>>;
cudaError_t status =
SinglePrefillWithKVCacheDispatched<HEAD_DIM, MASK_MODE, USE_SWA, AttentionVariant>(
params, stream);
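A small C++ detail shared by both .cu files in this commit: the new line using SinglePrefillParams = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>; declares a block-scope alias whose name shadows the class template, so later uses such as LogitsSoftCap<SinglePrefillParams> refer to the concrete specialization. The standalone sketch below illustrates the pattern with placeholder struct bodies that are not FlashInfer's definitions.

// Standalone sketch of the block-scope alias pattern (placeholder types only).
#include <type_traits>

template <typename DTypeQ, typename DTypeKV, typename DTypeO>
struct SinglePrefillParams {
  DTypeQ* q_ptr = nullptr;
  DTypeKV* k_ptr = nullptr;
  DTypeO* o_ptr = nullptr;
};

template <typename Params>
struct VariantSketch {
  using ParamsT = Params;
};

int main() {
  // The right-hand side still names the class template; from here on the
  // alias shadows the template name inside this block.
  using SinglePrefillParams = SinglePrefillParams<float, float, float>;
  SinglePrefillParams params;
  using AttentionVariant = VariantSketch<SinglePrefillParams>;
  static_assert(std::is_same_v<AttentionVariant::ParamsT, SinglePrefillParams>,
                "the alias and the variant's ParamsT are the same concrete type");
  return params.q_ptr == nullptr ? 0 : 1;
}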