diff --git a/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
index b5e1ebcd..35bbda41 100644
--- a/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
+++ b/aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -18,12 +18,7 @@
 import sys
 from pathlib import Path

-from .literal_map import (
-    dtype_literal,
-    idtype_literal,
-    mask_mode_literal,
-    pos_encoding_mode_literal,
-)
+from .literal_map import dtype_literal, idtype_literal, mask_mode_literal


 def get_cu_file_str(
@@ -36,32 +31,46 @@ def get_cu_file_str(
     dtype_out,
     idtype,
 ):
+    pos_encoding_mode = None
+    allow_fp16_qk_reduction = None
+
     def get_insts(attention_variant):
-        return "\n".join(
-            [
-                """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
-
-template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>(
-    Params& params,
-    cudaStream_t stream);
+        return """
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/true,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
+     {attention_variant}>
+    (Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/true,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
+     {attention_variant}>
+    (Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/false,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true,
+     {attention_variant}>
+    (Params& params, cudaStream_t stream);
+
+template cudaError_t BatchPrefillWithPagedKVCacheDispatched
+    <{head_dim},
+     {mask_mode},
+     /*USE_SWA=*/false,
+     /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false,
+     {attention_variant}>
+    (Params& params, cudaStream_t stream);
 """.format(
-            head_dim=head_dim,
-            pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
-            allow_fp16_qk_reduction=allow_fp16_qk_reduction,
-            mask_mode=mask_mode_literal[int(mask_mode)],
-            attention_variant=attention_variant,
-        )
-    ]
+            head_dim=head_dim,
+            mask_mode=mask_mode_literal[int(mask_mode)],
+            attention_variant=attention_variant,
         )

     dtype_q = dtype_literal[dtype_q]
@@ -69,7 +78,9 @@ def get_insts(attention_variant):
     dtype_out = dtype_literal[dtype_out]
     idtype = idtype_literal[idtype]

-    content = f"""#include
+    content = f""" // batch_paged_prefill_sm90 template inst
+#include
+#include
 #include
 #include
@@ -82,9 +93,9 @@ def get_insts(attention_variant):
 using Params = BatchPrefillPagedParams;

-{get_insts("LogitsSoftCap")}
+{get_insts("LogitsSoftCap")}

-{get_insts("StandardAttention")}
+{get_insts("StandardAttention")}

 }}"""
     return content
@@ -93,12 +104,11 @@ def get_insts(attention_variant):
 if __name__ == "__main__":
     pattern = (
         r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
-
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" + r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_" + r"dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" ) compiled_pattern = re.compile(pattern) path = Path(sys.argv[1]) fname = path.name match = compiled_pattern.match(fname) - - with open(path, "w") as f: - f.write(get_cu_file_str(*match.groups())) + path.write_text(get_cu_file_str(*match.groups())) diff --git a/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py index ad53dc31..98dee8e2 100644 --- a/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py +++ b/aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py @@ -38,39 +38,48 @@ def get_cu_file_str( ): def get_insts(attention_variant): - return "\n".join( - [ - """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>( - Params& params, - cudaStream_t stream); - -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, {attention_variant}>( - Params& params, - cudaStream_t stream); - -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>( - Params& params, - cudaStream_t stream); - -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, {attention_variant}>( - Params& params, - cudaStream_t stream); + return """ +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{head_dim}, + {mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, + {attention_variant}>(Params& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{head_dim}, + {mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, + {attention_variant}>(Params& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{head_dim}, + {mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULE_FOR_ALL_HEADS=*/true, + {attention_variant}>(Params& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{head_dim}, + {mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULE_FOR_ALL_HEADS=*/false, + {attention_variant}>(Params& params, cudaStream_t stream); """.format( head_dim=head_dim, - pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - allow_fp16_qk_reduction=allow_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], attention_variant=attention_variant, ) - ] - ) dtype_q = dtype_literal[dtype_q] dtype_kv = dtype_literal[dtype_kv] dtype_out = dtype_literal[dtype_out] idtype = idtype_literal[idtype] - content = f"""#include + content = f""" // batch_ragged_prefill_sm90 template inst +#include +#include #include #include @@ -83,9 +92,9 @@ def get_insts(attention_variant): using Params = BatchPrefillRaggedParams; -{get_insts("LogitsSoftCap")} +{get_insts("LogitsSoftCap")} -{get_insts("StandardAttention")} +{get_insts("StandardAttention")} }} """ diff --git a/aot_build_utils/generate_single_prefill_sm90_inst.py b/aot_build_utils/generate_single_prefill_sm90_inst.py index 13e57999..e37fce25 100644 --- 
a/aot_build_utils/generate_single_prefill_sm90_inst.py +++ b/aot_build_utils/generate_single_prefill_sm90_inst.py @@ -30,7 +30,9 @@ def get_cu_file_str( dtype_kv, dtype_out, ): - content = """#include + content = """ // single_prefill_sm90 template inst +#include +#include #include #include @@ -42,31 +44,32 @@ def get_cu_file_str( using Params = SinglePrefillParams; -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>( - Params& params, - cudaStream_t stream); +template cudaError_t SinglePrefillWithKVCacheDispatched + <{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap> + (Params& params, cudaStream_t stream); -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>( - Params& params, - cudaStream_t stream); +template cudaError_t SinglePrefillWithKVCacheDispatched + <{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap> + (Params& params, cudaStream_t stream); -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>( - Params& params, - cudaStream_t stream); +template cudaError_t SinglePrefillWithKVCacheDispatched + <{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention> + (Params& params, cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched + <{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention> + (Params& params, cudaStream_t stream); -template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>( - Params& params, - cudaStream_t stream); }} """.format( head_dim=head_dim, - pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - allow_fp16_qk_reduction=allow_fp16_qk_reduction, + # pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], + # allow_fp16_qk_reduction=allow_fp16_qk_reduction, mask_mode=mask_mode_literal[int(mask_mode)], dtype_q=dtype_literal[dtype_q], dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], - use_custom_mask="true" if int(mask_mode) == 2 else "false", + # use_custom_mask="true" if int(mask_mode) == 2 else "false", ) return content @@ -81,5 +84,4 @@ def get_cu_file_str( path = Path(sys.argv[1]) fname = path.name match = compiled_pattern.match(fname) - with open(path, "w") as f: - f.write(get_cu_file_str(*match.groups())) + path.write_text(get_cu_file_str(*match.groups())) diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu index ea53eb12..7eb44221 100644 --- a/csrc/batch_prefill_sm90.cu +++ b/csrc/batch_prefill_sm90.cu @@ -29,16 +29,14 @@ namespace flashinfer { template -cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - BatchPrefillRaggedParams& params, cudaStream_t stream); + bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant> +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT& params, + cudaStream_t stream); template -cudaError_t BatchPrefillWithPagedKVCacheDispatched( - BatchPrefillPagedParams& params, cudaStream_t stream); + bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant> +cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT& params, + cudaStream_t stream); } // namespace flashinfer @@ -110,7 +108,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run( using DTypeO = DTypeQ; using IdType = int32_t; - BatchPrefillRaggedParams params; + using BatchPrefillRaggedParams = BatchPrefillRaggedParams; + BatchPrefillRaggedParams 
params; params.q_ptr = static_cast(q.data_ptr()); params.k_ptr = static_cast(k.data_ptr()); @@ -160,7 +159,8 @@ void BatchPrefillWithRaggedKVCacheSM90Run( return DISPATCH_BOOL(use_swa, USE_SWA, [&] { return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { using AttentionVariant = - std::conditional_t; + std::conditional_t, + StandardAttention>; cudaError_t status = BatchPrefillWithRaggedKVCacheDispatched< HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>( params, stream); @@ -220,7 +220,8 @@ void BatchPrefillWithPagedKVCacheSM90Run( using DTypeO = DTypeQ; using IdType = int32_t; - BatchPrefillPagedParams params; + using BatchPrefillPagedParams = BatchPrefillPagedParams; + BatchPrefillPagedParams params; params.q_ptr = static_cast(q.data_ptr()); params.k_ptr = static_cast(paged_k_cache.data_ptr()); @@ -272,7 +273,8 @@ void BatchPrefillWithPagedKVCacheSM90Run( return DISPATCH_BOOL(use_swa, USE_SWA, [&] { return DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { using AttentionVariant = - std::conditional_t; + std::conditional_t, + StandardAttention>; cudaError_t status = BatchPrefillWithPagedKVCacheDispatched< HEAD_DIM, MASK_MODE, USE_SWA, SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>( params, stream); diff --git a/csrc/single_prefill_sm90.cu b/csrc/single_prefill_sm90.cu index 0b254b8c..41b73e9f 100644 --- a/csrc/single_prefill_sm90.cu +++ b/csrc/single_prefill_sm90.cu @@ -28,8 +28,8 @@ namespace flashinfer { template -cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams& params, + typename AttentionVariant> +cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT& params, cudaStream_t stream); } // namespace flashinfer @@ -59,7 +59,8 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q using DTypeQ = cutlass_dtype_t; using DTypeKV = DTypeQ; using DTypeO = DTypeQ; - SinglePrefillParams params; + using SinglePrefillParams = SinglePrefillParams; + SinglePrefillParams params; params.q_ptr = static_cast(q.data_ptr()); params.k_ptr = static_cast(k.data_ptr()); params.v_ptr = static_cast(v.data_ptr()); @@ -96,7 +97,8 @@ void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q return DISPATCH_BOOL(use_logits_soft_cap, USE_LOGITS_SOFT_CAP, [&] { return DISPATCH_BOOL(use_swa, USE_SWA, [&] { using AttentionVariant = - std::conditional_t; + std::conditional_t, + StandardAttention>; cudaError_t status = SinglePrefillWithKVCacheDispatched( params, stream); diff --git a/flashinfer/jit/attention.py b/flashinfer/jit/attention.py index 1ead1215..92115395 100644 --- a/flashinfer/jit/attention.py +++ b/flashinfer/jit/attention.py @@ -26,6 +26,8 @@ from .batch_prefill_sm90_templ import ( batch_prefill_sm90_suffix, batch_prefill_sm90_templ, + customizable_batch_prefill_sm90_suffix, + customizable_batch_prefill_sm90_templ, ) from .batch_prefill_templ import batch_prefill_suffix, batch_prefill_templ from .core import load_cuda_ops, sm90a_nvcc_flags @@ -38,6 +40,8 @@ from .single_prefill_sm90_templ import ( single_prefill_sm90_suffix, single_prefill_sm90_templ, + customizable_single_prefill_sm90_suffix, + customizable_single_prefill_sm90_templ, ) from .single_prefill_templ import ( customizable_single_prefill_templ, @@ -101,11 +105,11 @@ def get_single_decode_uri( ) -def gen_single_decode_module(*args): +def gen_single_decode_module(*args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR os.makedirs(gen_directory, exist_ok=True) - 
uri = get_single_decode_uri(*args) - sources = get_single_decode_sources(*args) + uri = get_single_decode_uri(*args, **kwargs) + sources = get_single_decode_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(single_decode_suffix, sources): path = gen_directory / f"{uri}{suffix}" @@ -161,10 +165,10 @@ def get_batch_decode_uri( ) -def gen_batch_decode_module(*args): +def gen_batch_decode_module(*args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR - uri = get_batch_decode_uri(*args) - sources = get_batch_decode_sources(*args) + uri = get_batch_decode_uri(*args, **kwargs) + sources = get_batch_decode_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(batch_decode_suffix, sources): path = gen_directory / f"{uri}{suffix}" @@ -218,10 +222,10 @@ def get_batch_decode_mla_uri( ) -def gen_batch_decode_mla_module(*args): +def gen_batch_decode_mla_module(*args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR - uri = get_batch_decode_mla_uri(*args) - sources = get_batch_decode_mla_sources(*args) + uri = get_batch_decode_mla_uri(*args, **kwargs) + sources = get_batch_decode_mla_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(batch_decode_mla_suffix, sources): path = gen_directory / f"{uri}{suffix}" @@ -306,14 +310,14 @@ def get_single_prefill_uri( ) -def get_single_prefill_sm90_uri(*args): - return get_single_prefill_uri(*args) + "_sm90" +def get_single_prefill_sm90_uri(*args, **kwargs): + return get_single_prefill_uri(*args, **kwargs) + "_sm90" -def gen_single_prefill_module(*args): +def gen_single_prefill_module(*args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR - uri = get_single_prefill_uri(*args) - sources = get_single_prefill_sources(*args) + uri = get_single_prefill_uri(*args, **kwargs) + sources = get_single_prefill_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(single_prefill_suffix, sources): path = gen_directory / f"{uri}{suffix}" @@ -323,10 +327,10 @@ def gen_single_prefill_module(*args): return load_cuda_ops(uri, source_paths) -def gen_single_prefill_sm90_module(*args): +def gen_single_prefill_sm90_module(*args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR - uri = get_single_prefill_sm90_uri(*args) - sources = get_single_prefill_sm90_sources(*args) + uri = get_single_prefill_sm90_uri(*args, **kwargs) + sources = get_single_prefill_sm90_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(single_prefill_sm90_suffix, sources): path = gen_directory / f"{uri}{suffix}" @@ -418,14 +422,14 @@ def get_batch_prefill_uri( ) -def get_batch_prefill_sm90_uri(*args): - return get_batch_prefill_uri(*args) + "_sm90" +def get_batch_prefill_sm90_uri(*args, **kwargs): + return get_batch_prefill_uri(*args, **kwargs) + "_sm90" -def gen_batch_prefill_module(*args): +def gen_batch_prefill_module(*args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR - uri = get_batch_prefill_uri(*args) - sources = get_batch_prefill_sources(*args) + uri = get_batch_prefill_uri(*args, **kwargs) + sources = get_batch_prefill_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(batch_prefill_suffix, sources): path = gen_directory / f"{uri}{suffix}" @@ -435,10 +439,10 @@ def gen_batch_prefill_module(*args): return load_cuda_ops(uri, source_paths) -def gen_batch_prefill_sm90_module(*args): +def gen_batch_prefill_sm90_module(*args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR - uri = get_batch_prefill_sm90_uri(*args) - sources = get_batch_prefill_sm90_sources(*args) + uri = 
get_batch_prefill_sm90_uri(*args, **kwargs) + sources = get_batch_prefill_sm90_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(batch_prefill_sm90_suffix, sources): path = gen_directory / f"{uri}{suffix}" @@ -610,9 +614,9 @@ def get_customize_single_prefill_sources( ) -def gen_customize_single_decode_module(module_name, *args): +def gen_customize_single_decode_module(module_name, *args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR - sources = get_customize_single_decode_sources(*args) + sources = get_customize_single_decode_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(single_decode_suffix, sources): path = gen_directory / f"{module_name}{suffix}" @@ -622,9 +626,9 @@ def gen_customize_single_decode_module(module_name, *args): return load_cuda_ops(module_name, source_paths) -def gen_customize_single_prefill_module(module_name, *args): +def gen_customize_single_prefill_module(module_name, *args, **kwargs): gen_directory = FLASHINFER_GEN_SRC_DIR - sources = get_customize_single_prefill_sources(*args) + sources = get_customize_single_prefill_sources(*args, **kwargs) source_paths = [] for suffix, source in zip(single_prefill_suffix, sources): path = gen_directory / f"{module_name}{suffix}" @@ -632,3 +636,165 @@ def gen_customize_single_prefill_module(module_name, *args): write_if_different(path, source) return load_cuda_ops(module_name, source_paths) + + +def get_customize_batch_prefill_sm90_sources( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + dtype_idx: torch.dtype, + head_dim: int, + additional_input_tensor_var_names: List[str], + additional_input_tensor_var_types: List[str], + additional_input_scalar_var_names: List[str], + additional_input_scalar_var_types: List[str], + variant_name: str, + variant_decl: str, +) -> List[str]: + additional_params_decl = ";\n ".join( + [ + f"{dtype}* {var}_ptr" + for dtype, var in zip( + additional_input_tensor_var_types, additional_input_tensor_var_names + ) + ] + + [ + f"{dtype} {var}" + for dtype, var in zip( + additional_input_scalar_var_types, additional_input_scalar_var_names + ) + ] + ) + additional_func_params = ",\n ".join( + [f"at::Tensor {var}" for var in additional_input_tensor_var_names] + + [ + f"{dtype} {var}" + for dtype, var in zip( + additional_input_scalar_var_types, additional_input_scalar_var_names + ) + ] + ) + additional_params_setter = ";\n ".join( + [ + f"params.{var}_ptr = static_cast<{dtype}*>({var}.data_ptr())" + for dtype, var in zip( + additional_input_tensor_var_types, additional_input_tensor_var_names + ) + ] + + [f"params.{var} = {var}" for var in additional_input_scalar_var_names] + ) + + if additional_func_params: + additional_func_params += "," + + return render_templates( + customizable_batch_prefill_sm90_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "dtype_idx": dtype_map[dtype_idx], + "head_dim": head_dim, + "variant_decl": variant_decl, + "variant_name": variant_name, + "use_sliding_window": "false", + "additional_params_decl": additional_params_decl, + "additional_params_setter": additional_params_setter, + "additional_func_params": additional_func_params, + }, + ) + + +def gen_customize_batch_prefill_sm90_module(module_name, *args, **kwargs): + gen_directory = FLASHINFER_GEN_SRC_DIR + sources = get_customize_batch_prefill_sm90_sources(*args, **kwargs) + source_paths = [] + for suffix, source in zip(customizable_batch_prefill_sm90_suffix, sources): + path = gen_directory / 
f"{module_name}{suffix}" + source_paths.append(path) + write_if_different(path, source) + return load_cuda_ops( + module_name, + source_paths, + extra_cuda_cflags=["-gencode=arch=compute_90a,code=sm_90a"], + ) + + +def get_customize_single_prefill_sm90_sources( + dtype_q: torch.dtype, + dtype_kv: torch.dtype, + dtype_o: torch.dtype, + head_dim: int, + additional_input_tensor_var_names: List[str], + additional_input_tensor_var_types: List[str], + additional_input_scalar_var_names: List[str], + additional_input_scalar_var_types: List[str], + variant_name: str, + variant_decl: str, +) -> List[str]: + additional_params_decl = ";\n ".join( + [ + f"{dtype}* {var}_ptr" + for dtype, var in zip( + additional_input_tensor_var_types, additional_input_tensor_var_names + ) + ] + + [ + f"{dtype} {var}" + for dtype, var in zip( + additional_input_scalar_var_types, additional_input_scalar_var_names + ) + ] + ) + additional_func_params = ",\n ".join( + [f"at::Tensor {var}" for var in additional_input_tensor_var_names] + + [ + f"{dtype} {var}" + for dtype, var in zip( + additional_input_scalar_var_types, additional_input_scalar_var_names + ) + ] + ) + additional_params_setter = ";\n ".join( + [ + f"params.{var}_ptr = static_cast<{dtype}*>({var}.data_ptr())" + for dtype, var in zip( + additional_input_tensor_var_types, additional_input_tensor_var_names + ) + ] + + [f"params.{var} = {var}" for var in additional_input_scalar_var_names] + ) + + if additional_func_params: + additional_func_params += "," + + return render_templates( + customizable_single_prefill_sm90_templ, + { + "dtype_q": dtype_map[dtype_q], + "dtype_kv": dtype_map[dtype_kv], + "dtype_o": dtype_map[dtype_o], + "head_dim": head_dim, + "variant_decl": variant_decl, + "variant_name": variant_name, + "use_sliding_window": "false", + "additional_params_decl": additional_params_decl, + "additional_params_setter": additional_params_setter, + "additional_func_params": additional_func_params, + }, + ) + + +def gen_customize_single_prefill_sm90_module(module_name, *args, **kwargs): + gen_directory = FLASHINFER_GEN_SRC_DIR + sources = get_customize_single_prefill_sm90_sources(*args, **kwargs) + source_paths = [] + for suffix, source in zip(customizable_single_prefill_sm90_suffix, sources): + path = gen_directory / f"{module_name}{suffix}" + source_paths.append(path) + write_if_different(path, source) + return load_cuda_ops( + module_name, + source_paths, + extra_cuda_cflags=["-gencode=arch=compute_90a,code=sm_90a"], + ) diff --git a/flashinfer/jit/batch_prefill_sm90_templ.py b/flashinfer/jit/batch_prefill_sm90_templ.py index e36e8ded..2748b651 100644 --- a/flashinfer/jit/batch_prefill_sm90_templ.py +++ b/flashinfer/jit/batch_prefill_sm90_templ.py @@ -14,149 +14,716 @@ limitations under the License. 
""" +batch_prefill_plan_func = r"""std::vector BatchPrefillWithKVCacheSM90Plan( + bool causal, + at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, + at::Tensor kv_indptr, + at::Tensor kv_len_arr, + unsigned int total_num_rows, + unsigned int batch_size, + unsigned int num_qo_heads, + unsigned int num_kv_heads, + unsigned int page_size, + bool enable_cuda_graph, + int64_t cuda_stream)""" + + +batch_prefill_plan_impl = f"""// _plan.cu +#include +#include "pytorch_extension_utils.h" + +using namespace flashinfer; + +{batch_prefill_plan_func} {{ + + size_t float_workspace_size_in_bytes = + float_workspace_buffer.numel() * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.numel() * int_workspace_buffer.element_size(); + + PrefillPlanSM90Info plan_info; + cudaStream_t stream = reinterpret_cast(cuda_stream); + + cudaError_t status = PrefillSM90Plan( + float_workspace_buffer.data_ptr(), + float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), + page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, + plan_info, + qo_indptr.data_ptr<{{{{ dtype_idx }}}}>(), + kv_indptr.data_ptr<{{{{ dtype_idx }}}}>(), + kv_len_arr.data_ptr<{{{{ dtype_idx }}}}>(), + total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + {{{{ head_dim }}}}, + page_size, + causal, + enable_cuda_graph, + sizeof({{{{dtype_o}}}}), + stream); + + TORCH_CHECK(status == cudaSuccess, + "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); + + return plan_info.ToVector(); +}} +""" + + batch_prefill_sm90_suffix = [ "_plan.cu", - *[f"_ragged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], + "_ragged_kernel_mask_0.cu", + "_ragged_kernel_mask_1.cu", + "_ragged_kernel_mask_2.cu", "_ragged_run.cu", - *[f"_paged_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], + "_paged_kernel_mask_0.cu", + "_paged_kernel_mask_r.cu", + "_paged_kernel_mask_2.cu", "_paged_run.cu", "_pybind.cc", ] +batch_prefill_ragged_func_templ = r"""void BatchPrefillWithRaggedKVCacheSM90Run( + unsigned int mask_mode_code, + at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + std::vector plan_info_vec, + at::Tensor q, + at::Tensor k, + at::Tensor v, + std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, + at::Tensor kv_indptr, + std::optional maybe_qk_indptr, + at::Tensor o, + unsigned int layout, + int32_t window_left, + float logits_soft_cap, + float sm_scale, + float rope_scale, + float rope_theta, + std::optional maybe_lse, + {{ additional_func_params }} + int64_t cuda_stream)""" + + def ragged_prefill_sm90_inst_templ(mask_mode: str) -> str: - return ( - r"""#include + return f"""// batch_prefill_ragged_sm90 template inst +#include +#include #include #include #include -namespace flashinfer { +namespace flashinfer {{ -using DTypeQ = cutlass_dtype_t<{{dtype_q}}>; -using DTypeKV = cutlass_dtype_t<{{dtype_kv}}>; -using DTypeO = cutlass_dtype_t<{{dtype_o}}>; -using IdType = cutlass_dtype_t<{{dtype_idx}}>; +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; +using IdType = cutlass_dtype_t<{{{{ dtype_idx }}}}>; -using RaggedParams = - BatchPrefillRaggedParams; -using AttentionVariant = std::conditional_t<{{use_logits_soft_cap}}, LogitsSoftCap, StandardAttention>; +using RaggedParams = BatchPrefillRaggedParams; 
+using AttentionVariant = std::conditional_t< + {{{{use_logits_soft_cap}}}}, LogitsSoftCap, StandardAttention>; + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, + AttentionVariant>(RaggedParams& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, + AttentionVariant>(RaggedParams& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, + AttentionVariant>(RaggedParams& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, + AttentionVariant>(RaggedParams& params, cudaStream_t stream); +}}""" + + +batch_prefill_paged_func_templ = r"""void BatchPrefillWithPagedKVCacheSM90Run( + unsigned int mask_mode_code, + at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + std::vector plan_info_vec, + at::Tensor q, + at::Tensor paged_k_cache, + at::Tensor paged_v_cache, + std::optional maybe_custom_mask, + std::optional maybe_alibi_slopes, + at::Tensor qo_indptr, + at::Tensor paged_kv_indptr, + at::Tensor paged_kv_indices, + at::Tensor paged_kv_last_page_len, + std::optional maybe_qk_indptr, + at::Tensor o, + unsigned int layout, + int32_t window_left, + float logits_soft_cap, + float sm_scale, + float rope_scale, + float rope_theta, + std::optional maybe_lse, + {{ additional_func_params }} + int64_t cuda_stream)""" -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """ - + mask_mode - + r""", /*USE_SWA=*/true, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, AttentionVariant>( - RaggedParams& params, - cudaStream_t stream); -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """ - + mask_mode - + r""", /*USE_SWA=*/false, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, AttentionVariant>( - RaggedParams& params, - cudaStream_t stream); +def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: + return f"""// batch_prefill_paged_sm90 template inst +#include +#include +#include +#include +#include -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """ - + mask_mode - + r""", /*USE_SWA=*/true, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, AttentionVariant>( - RaggedParams& params, - cudaStream_t stream); +namespace flashinfer {{ -template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, """ - + mask_mode - + r""", /*USE_SWA=*/false, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, AttentionVariant>( - RaggedParams& params, - cudaStream_t stream); +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; +using IdType = cutlass_dtype_t<{{{{ dtype_idx }}}}>; -}""" - ) +using PagedParams = BatchPrefillPagedParams; +using AttentionVariant = std::conditional_t< + {{{{use_logits_soft_cap}}}}, LogitsSoftCap, StandardAttention>; + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, + AttentionVariant>(PagedParams& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched 
+ <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, + AttentionVariant>(PagedParams& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, + AttentionVariant>(PagedParams& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, + AttentionVariant>(PagedParams& params, cudaStream_t stream); + +}}""" -def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: - return ( - r"""#include +batch_prefill_sm90_templ = [ + batch_prefill_plan_impl, + ragged_prefill_sm90_inst_templ("MaskMode::kNone"), + ragged_prefill_sm90_inst_templ("MaskMode::kCausal"), + ragged_prefill_sm90_inst_templ("MaskMode::kCustom"), + f"""// _ragged_run.cu +#include #include +#include +#include +#include +#include #include -#include +#include + +#include "pytorch_extension_utils.h" -namespace flashinfer { +using namespace flashinfer; -using DTypeQ = cutlass_dtype_t<{{dtype_q}}>; -using DTypeKV = cutlass_dtype_t<{{dtype_kv}}>; -using DTypeO = cutlass_dtype_t<{{dtype_o}}>; -using IdType = cutlass_dtype_t<{{dtype_idx}}>; +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; +using IdType = cutlass_dtype_t<{{{{ dtype_idx }}}}>; -using PagedParams = BatchPrefillPagedParams; -using AttentionVariant = std::conditional_t<{{use_logits_soft_cap}}, LogitsSoftCap, StandardAttention>; +using RaggedParams = BatchPrefillRaggedParams; -template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """ - + mask_mode - + r""", /*USE_SWA=*/true, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, AttentionVariant>( - PagedParams& params, - cudaStream_t stream); +{batch_prefill_ragged_func_templ} {{ + PrefillPlanSM90Info plan_info; + plan_info.FromVector(plan_info_vec); -template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """ - + mask_mode - + r""", /*USE_SWA=*/false, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, AttentionVariant>( - PagedParams& params, - cudaStream_t stream); + if (maybe_lse) {{ + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + }} -template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """ - + mask_mode - + r""", /*USE_SWA=*/true, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, AttentionVariant>( - PagedParams& params, - cudaStream_t stream); + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); -template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, """ - + mask_mode - + r""", /*USE_SWA=*/false, /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, AttentionVariant>( - PagedParams& params, - cudaStream_t stream); + auto q_scalar_type = q.scalar_type(); -}""" - ) + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + RaggedParams params; + + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) {{ + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + }} else {{ + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + }} + params.nnz_qo = q.size(0); + params.nnz_kv = k.size(0); + params.head_dim = {{{{ head_dim }}}}; + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + + bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {{ + DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {{ + using AttentionVariant = + std::conditional_t<{{{{ use_logits_soft_cap }}}}, + LogitsSoftCap, + StandardAttention>; + cudaError_t status = + BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + MASK_MODE, + {{{{ use_sliding_window }}}}, + SAME_SCHEDULER_FOR_ALL_HEADS, + AttentionVariant>(params, stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + return true; + }}); + }}); +}} +""", + paged_prefill_sm90_inst_templ("MaskMode::kNone"), + paged_prefill_sm90_inst_templ("MaskMode::kCausal"), + paged_prefill_sm90_inst_templ("MaskMode::kCustom"), + f"""// _paged_run.cu +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include -batch_prefill_sm90_templ = [ - r"""#include #include "pytorch_extension_utils.h" using namespace flashinfer; -std::vector BatchPrefillWithKVCacheSM90Plan( - bool causal, at::Tensor float_workspace_buffer, - at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, - at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, int64_t cuda_stream) { - size_t float_workspace_size_in_bytes = - float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); - size_t int_workspace_size_in_bytes = - int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; +using IdType = cutlass_dtype_t<{{{{ dtype_idx }}}}>; +using PagedParams = BatchPrefillPagedParams; + 
+{batch_prefill_paged_func_templ} {{ PrefillPlanSM90Info plan_info; + plan_info.FromVector(plan_info_vec); + + if (maybe_lse) {{ + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); + }} + QKVLayout kv_layout = static_cast(layout); + unsigned int num_kv_heads, page_size; + unsigned int head_dim = q.size(2); + if (kv_layout == QKVLayout::kHND) {{ + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); + }} else {{ + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); + }} + + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); + + auto q_scalar_type = q.scalar_type(); + cudaStream_t stream = reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); - cudaError_t status = PrefillSM90Plan( - float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, - int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), - int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<{{ dtype_idx }}>(), - kv_indptr.data_ptr<{{ dtype_idx }}>(), kv_len_arr.data_ptr<{{ dtype_idx }}>(), - total_num_rows, batch_size, num_qo_heads, num_kv_heads, {{ head_dim }}, page_size, - causal, enable_cuda_graph, sizeof({{dtype_o}}), stream); + PagedParams params; - TORCH_CHECK(status == cudaSuccess, - "PrefillSM90Plan failed with error: ", cudaGetErrorString(status)); + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(paged_k_cache.data_ptr()); + params.v_ptr = static_cast(paged_v_cache.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) {{ + // (num_pages, page_size, num_heads, head_dim) + params.k_stride_n = paged_k_cache.stride(1); + params.k_stride_h = paged_k_cache.stride(2); + params.v_stride_n = paged_v_cache.stride(1); + params.v_stride_h = paged_v_cache.stride(2); + }} else {{ + // (num_pages, num_heads, page_size, head_dim) + params.k_stride_h = paged_k_cache.stride(1); + params.k_stride_n = paged_k_cache.stride(2); + params.v_stride_h = paged_v_cache.stride(1); + params.v_stride_n = paged_v_cache.stride(2); + }} + params.nnz_qo = q.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = num_kv_heads; + params.group_size = params.num_qo_heads / num_kv_heads; + params.page_size = page_size; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; + params.causal = mask_mode_code == 1; + params.qo_tile_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); + params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); + params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); + params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); + params.head_indices = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); + params.kv_indices = 
static_cast(paged_kv_indices.data_ptr()); - return plan_info.ToVector(); -} + bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {{ + DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {{ + using AttentionVariant = + std::conditional_t<{{{{ use_logits_soft_cap }}}}, + LogitsSoftCap, + StandardAttention>; + cudaError_t status = + BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + MASK_MODE, + {{{{ use_sliding_window }}}}, + SAME_SCHEDULER_FOR_ALL_HEADS, + AttentionVariant>(params, stream); + TORCH_CHECK(status == cudaSuccess, + "BatchPrefillWithPagedKVCacheSM90Run failed with error: ", + cudaGetErrorString(status)); + return true; + }}); + }}); +}} +""", + f"""// _pybind.cc +#include "pytorch_extension_utils.h" + +{batch_prefill_plan_func}; + +{batch_prefill_ragged_func_templ}; + +{batch_prefill_paged_func_templ}; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{ + m.def("plan", &BatchPrefillWithKVCacheSM90Plan); + m.def("ragged_run", &BatchPrefillWithRaggedKVCacheSM90Run); + m.def("paged_run", &BatchPrefillWithPagedKVCacheSM90Run); +}} """, - *[ - ragged_prefill_sm90_inst_templ(mask_mode) - for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] - ], - r""" +] + + +# stuffs beyond this line are not tested + + +customizable_batch_prefill_ragged_params_templ = r""" +struct BatchPrefillRaggedParams { + using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; + using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; + using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; + using IdType = cutlass_dtype_t<{{ dtype_idx }}>; + + // The QKV matrices. + DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + // Additional params + {{ additional_params_decl }}; + + IdType* qo_tile_indices; + IdType* qo_indptr; + IdType* kv_indptr; + IdType* qo_lens; + IdType* kv_lens; + IdType* head_indices; + IdType* work_indptr; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + int64_t nnz_qo; + int64_t nnz_kv; + + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int window_left; + + float logits_soft_cap; + float sm_scale_log2; + bool causal; +}; +""" + + +def customizable_ragged_prefill_sm90_inst_templ(mask_mode: str) -> str: + return f"""// customizable_ragged_prefill_sm90_inst +#include +#include +#include +#include + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; +using IdType = cutlass_dtype_t<{{{{ dtype_idx }}}}>; + +{customizable_batch_prefill_ragged_params_templ} + +{{{{ variant_decl }}}} + +using AttentionVariant = {{{{ variant_name }}}}; + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + /*mask_mode*/{mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, + AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + /*mask_mode*/{mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, + AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + /*mask_mode*/{mask_mode}, + 
/*USE_SWA=*/false, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, + AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + /*mask_mode*/{mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, + AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); +}} // namespace flashinfer +""" + + +customizable_batch_prefill_paged_params_templ = r""" +struct BatchPrefillPagedParams { + using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; + using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; + using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; + using IdType = cutlass_dtype_t<{{ dtype_idx }}>; + + // The QKV matrices. + DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + // Additional params + {{ additional_params_decl }}; + + IdType* qo_tile_indices; + IdType* qo_indptr; + IdType* kv_indptr; + IdType* kv_indices; + IdType* qo_lens; + IdType* kv_lens; + IdType* head_indices; + IdType* work_indptr; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + int64_t nnz_qo; + + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int page_size; + int window_left; + + float logits_soft_cap; + float sm_scale_log2; + bool causal; +}; +""" + + +def customizable_paged_prefill_sm90_inst_templ(mask_mode: str) -> str: + return f"""// sm90_batch_paged_prefill template inst +#include +#include +#include +#include + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; +using IdType = cutlass_dtype_t<{{{{ dtype_idx }}}}>; + +{customizable_batch_prefill_paged_params_templ} + +{{{{ variant_decl }}}} + +using AttentionVariant = {{{{ variant_name }}}}; + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, + AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/true, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, + AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/true, + AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); + +template cudaError_t BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + {mask_mode}, + /*USE_SWA=*/false, + /*SAME_SCHEDULER_FOR_ALL_HEADS=*/false, + AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); +}} +""" + + +customizable_batch_prefill_sm90_suffix = [ + "_plan.cu", + "_ragged_kernel_mask_0.cu", + "_ragged_kernel_mask_1.cu", + "_ragged_kernel_mask_2.cu", + "_ragged_run.cu", + "_paged_kernel_mask_0.cu", + "_paged_kernel_mask_1.cu", + "_paged_kernel_mask_2.cu", + "_paged_run.cu", + "_pybind.cu", +] + + +customizable_batch_prefill_sm90_templ = [ + batch_prefill_plan_impl, + customizable_ragged_prefill_sm90_inst_templ("MaskMode::kNone"), + customizable_ragged_prefill_sm90_inst_templ("MaskMode::kCausal"), + 
customizable_ragged_prefill_sm90_inst_templ("MaskMode::kCustom"), + f"""// _ragged_run.cu +#include #include #include -#include #include #include #include @@ -164,40 +731,32 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: #include "pytorch_extension_utils.h" -namespace flashinfer { +namespace flashinfer {{ + +{customizable_batch_prefill_ragged_params_templ} -template -cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - BatchPrefillRaggedParams& params, cudaStream_t stream); +{{{{ variant_decl }}}} -}; // namespace flashinfer +}}; // namespace flashinfer using namespace flashinfer; -using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; -using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; -using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; -using IdType = cutlass_dtype_t<{{ dtype_idx }}>; +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; +using IdType = cutlass_dtype_t<{{{{ dtype_idx }}}}>; -using RaggedParams = BatchPrefillRaggedParams; +using RaggedParams = BatchPrefillRaggedParams; -void BatchPrefillWithRaggedKVCacheSM90Run( - unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, - std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, - at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, - at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream) { +{batch_prefill_ragged_func_templ} {{ PrefillPlanSM90Info plan_info; plan_info.FromVector(plan_info_vec); - if (maybe_lse) { + if (maybe_lse) {{ const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - } + }} void* float_buffer_ptr = float_workspace_buffer.data_ptr(); void* int_buffer_ptr = int_workspace_buffer.data_ptr(); @@ -210,29 +769,29 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: RaggedParams params; - params.q_ptr = static_cast(q.data_ptr()); - params.k_ptr = static_cast(k.data_ptr()); - params.v_ptr = static_cast(v.data_ptr()); - params.o_ptr = static_cast(o.data_ptr()); + params.q_ptr = static_cast*>(q.data_ptr()); + params.k_ptr = static_cast*>(k.data_ptr()); + params.v_ptr = static_cast*>(v.data_ptr()); + params.o_ptr = static_cast*>(o.data_ptr()); params.lse_ptr = maybe_lse ? 
static_cast(maybe_lse->data_ptr()) : nullptr; params.q_stride_n = q.stride(0); params.q_stride_h = q.stride(1); params.o_stride_n = o.stride(0); params.o_stride_h = o.stride(1); - if (kv_layout == QKVLayout::kNHD) { + if (kv_layout == QKVLayout::kNHD) {{ params.k_stride_n = k.stride(0); params.k_stride_h = k.stride(1); params.v_stride_n = v.stride(0); params.v_stride_h = v.stride(1); - } else { + }} else {{ params.k_stride_h = k.stride(0); params.k_stride_n = k.stride(1); params.v_stride_h = v.stride(0); params.v_stride_n = v.stride(1); - } + }} params.nnz_qo = q.size(0); params.nnz_kv = k.size(0); - params.head_dim = {{ head_dim }}; + params.head_dim = {{{{ head_dim }}}}; params.num_qo_heads = q.size(1); params.num_kv_heads = k.size(1); params.group_size = params.num_qo_heads / params.num_kv_heads; @@ -240,39 +799,40 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: params.logits_soft_cap = logits_soft_cap; params.sm_scale_log2 = sm_scale * math::log2e; params.causal = mask_mode_code == 1; - params.qo_tile_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_tile_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); - params.head_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.head_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); - bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + {{{{ additional_params_setter }}}}; - DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { - DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { - using AttentionVariant = - std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>; + bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {{ + DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {{ + using AttentionVariant = {{{{ variant_name }}}}; cudaError_t status = - BatchPrefillWithRaggedKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }}, - SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(params, stream); + BatchPrefillWithRaggedKVCacheDispatched + <{{{{ head_dim }}}}, + MASK_MODE, + {{{{ use_sliding_window }}}}, + false, + AttentionVariant>(params, stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithRaggedKVCacheSM90Run failed with error: ", cudaGetErrorString(status)); - return true; - }); - }); -} + }}); + }}); +}} """, - *[ - paged_prefill_sm90_inst_templ(mask_mode) - for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] - ], - r"""#include -#include + customizable_paged_prefill_sm90_inst_templ("MaskMode::kNone"), + customizable_paged_prefill_sm90_inst_templ("MaskMode::kCausal"), + customizable_paged_prefill_sm90_inst_templ("MaskMode::kCustom"), + f"""// _paged_run.cu +#include +#include #include #include #include @@ -283,52 +843,42 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: #include "pytorch_extension_utils.h" -namespace flashinfer { +namespace flashinfer {{ -template 
-cudaError_t BatchPrefillWithPagedKVCacheDispatched( - BatchPrefillPagedParams& params, cudaStream_t stream); +{customizable_batch_prefill_paged_params_templ} -}; // namespace flashinfer +{{{{ variant_decl }}}} + +}}; // namespace flashinfer using namespace flashinfer; -using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; -using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; -using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; -using IdType = cutlass_dtype_t<{{ dtype_idx }}>; +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; +using IdType = cutlass_dtype_t<{{{{ dtype_idx }}}}>; -using PagedParams = BatchPrefillPagedParams; +using PagedParams = BatchPrefillPagedParams; -void BatchPrefillWithPagedKVCacheSM90Run( - unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, - at::Tensor paged_v_cache, std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, - at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, - std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse, int64_t cuda_stream) { +{batch_prefill_paged_func_templ} {{ PrefillPlanSM90Info plan_info; plan_info.FromVector(plan_info_vec); - if (maybe_lse) { + if (maybe_lse) {{ const auto& lse = *maybe_lse; TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); - } + }} QKVLayout kv_layout = static_cast(layout); unsigned int num_kv_heads, page_size; unsigned int head_dim = q.size(2); - if (kv_layout == QKVLayout::kHND) { + if (kv_layout == QKVLayout::kHND) {{ num_kv_heads = paged_k_cache.size(1); page_size = paged_k_cache.size(2); - } else { + }} else {{ page_size = paged_k_cache.size(1); num_kv_heads = paged_k_cache.size(2); - } + }} void* float_buffer_ptr = float_workspace_buffer.data_ptr(); void* int_buffer_ptr = int_workspace_buffer.data_ptr(); @@ -349,19 +899,19 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: params.q_stride_h = q.stride(1); params.o_stride_n = o.stride(0); params.o_stride_h = o.stride(1); - if (kv_layout == QKVLayout::kNHD) { + if (kv_layout == QKVLayout::kNHD) {{ // (num_pages, page_size, num_heads, head_dim) params.k_stride_n = paged_k_cache.stride(1); params.k_stride_h = paged_k_cache.stride(2); params.v_stride_n = paged_v_cache.stride(1); params.v_stride_h = paged_v_cache.stride(2); - } else { + }} else {{ // (num_pages, num_heads, page_size, head_dim) params.k_stride_h = paged_k_cache.stride(1); params.k_stride_n = paged_k_cache.stride(2); params.v_stride_h = paged_v_cache.stride(1); params.v_stride_n = paged_v_cache.stride(2); - } + }} params.nnz_qo = q.size(0); params.head_dim = head_dim; params.num_qo_heads = q.size(1); @@ -372,64 +922,47 @@ def paged_prefill_sm90_inst_templ(mask_mode: str) -> str: params.logits_soft_cap = logits_soft_cap; params.sm_scale_log2 = sm_scale * math::log2e; params.causal = mask_mode_code == 1; - params.qo_tile_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); + params.qo_tile_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_tile_indices_offset); params.qo_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_indptr_offset); params.kv_indptr = 
GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); params.qo_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.qo_len_offset); params.kv_lens = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); - params.head_indices = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); + params.head_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); params.kv_indices = static_cast(paged_kv_indices.data_ptr()); - bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + {{{{ additional_params_setter }}}}; - DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { - DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] { - using AttentionVariant = - std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>; - cudaError_t status = - BatchPrefillWithPagedKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }}, - SAME_SCHEDULER_FOR_ALL_HEADS, AttentionVariant>(params, stream); + bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads; + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {{ + DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {{ + using AttentionVariant = {{{{ variant_name }}}}; + cudaError_t status = BatchPrefillWithPagedKVCacheDispatched + <{{{{ head_dim }}}}, + MASK_MODE, + {{{{ use_sliding_window }}}}, + false, + AttentionVariant>(params, stream); TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCacheSM90Run failed with error: ", cudaGetErrorString(status)); return true; - }); - }); -} + }}); + }}); +}} """, - r"""#include "pytorch_extension_utils.h" - -std::vector BatchPrefillWithKVCacheSM90Plan( - bool causal, at::Tensor float_workspace_buffer, - at::Tensor int_workspace_buffer, at::Tensor page_locked_int_workspace_buffer, - at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr, unsigned int total_num_rows, - unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size, - bool enable_cuda_graph, int64_t cuda_stream); - -void BatchPrefillWithRaggedKVCacheSM90Run( - unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v, - std::optional maybe_custom_mask, std::optional maybe_alibi_slopes, - at::Tensor qo_indptr, at::Tensor kv_indptr, std::optional maybe_qk_indptr, - at::Tensor o, unsigned int layout, int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta, std::optional maybe_lse, int64_t cuda_stream); - -void BatchPrefillWithPagedKVCacheSM90Run( - unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, - std::vector plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, - at::Tensor paged_v_cache, std::optional maybe_custom_mask, - std::optional maybe_alibi_slopes, at::Tensor qo_indptr, at::Tensor paged_kv_indptr, - at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, - std::optional maybe_qk_indptr, at::Tensor o, unsigned int layout, - int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, - std::optional maybe_lse, int64_t cuda_stream); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + f"""// _pybind.cu +#include "pytorch_extension_utils.h" + +{batch_prefill_plan_func}; + +{batch_prefill_ragged_func_templ}; + 
+{batch_prefill_paged_func_templ}; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{ m.def("plan", &BatchPrefillWithKVCacheSM90Plan); m.def("ragged_run", &BatchPrefillWithRaggedKVCacheSM90Run); m.def("paged_run", &BatchPrefillWithPagedKVCacheSM90Run); -} +}} """, ] diff --git a/flashinfer/jit/single_prefill_sm90_templ.py b/flashinfer/jit/single_prefill_sm90_templ.py index ac9f7112..3b2d6ba8 100644 --- a/flashinfer/jit/single_prefill_sm90_templ.py +++ b/flashinfer/jit/single_prefill_sm90_templ.py @@ -15,51 +15,72 @@ """ single_prefill_sm90_suffix = [ - *[f"_kernel_mask_{mask_mode}.cu" for mask_mode in [0, 1, 2]], + "_kernel_mask_0.cu", + "_kernel_mask_1.cu", + "_kernel_mask_2.cu", ".cu", "_pybind.cc", ] +single_prefill_sm90_func = r"""void single_prefill_with_kv_cache_sm90( + unsigned int mask_mode_code, + at::Tensor q, + at::Tensor k, + at::Tensor v, + std::optional maybe_packed_custom_mask, + std::optional maybe_alibi_slopes, + at::Tensor o, + unsigned int layout, + int32_t window_left, + float logits_soft_cap, + float sm_scale, + float rope_scale, + float rope_theta, + std::optional maybe_lse, + {{ additional_func_params }} + int64_t cuda_stream)""" + + def single_prefill_sm90_inst_templ(mask_mode: str) -> str: - return ( - r"""#include + return f""" // single_prefill_sm90 template instantiation +#include +#include #include #include -namespace flashinfer { +namespace flashinfer {{ -using DTypeQ = cutlass_dtype_t<{{dtype_q}}>; -using DTypeKV = cutlass_dtype_t<{{dtype_kv}}>; -using DTypeO = cutlass_dtype_t<{{dtype_o}}>; +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; using Params = SinglePrefillParams; -using AttentionVariant = std::conditional_t<{{use_logits_soft_cap}}, LogitsSoftCap, StandardAttention>; +using AttentionVariant = std::conditional_t< + {{{{use_logits_soft_cap}}}}, + LogitsSoftCap, + StandardAttention>; -template cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }},""" - f"{mask_mode}" - r""", /*USE_SWA=*/false, AttentionVariant>( - Params& params, - cudaStream_t stream); +template cudaError_t SinglePrefillWithKVCacheDispatched + <{{{{ head_dim }}}}, {mask_mode}, /*USE_SWA=*/true, AttentionVariant>( + Params& params, cudaStream_t stream); -template cudaError_t SinglePrefillWithKVCacheDispatched<{{ head_dim }},""" - f"{mask_mode}" - r""", /*USE_SWA=*/true, AttentionVariant>( - Params& params, - cudaStream_t stream); +template cudaError_t SinglePrefillWithKVCacheDispatched + <{{{{ head_dim }}}}, {mask_mode}, /*USE_SWA=*/false, AttentionVariant>( + Params& params, cudaStream_t stream); -} // namespace flashinfer +}} // namespace flashinfer """ - ) single_prefill_sm90_templ = [ - *[ - single_prefill_sm90_inst_templ(mask_mode) - for mask_mode in ["MaskMode::kNone", "MaskMode::kCausal", "MaskMode::kCustom"] - ], - r"""#include + single_prefill_sm90_inst_templ("MaskMode::kNone"), + single_prefill_sm90_inst_templ("MaskMode::kCausal"), + single_prefill_sm90_inst_templ("MaskMode::kCustom"), + f"""// _run.cu +#include #include +#include #include #include #include @@ -68,25 +89,196 @@ def single_prefill_sm90_inst_templ(mask_mode: str) -> str: using namespace flashinfer; -namespace flashinfer { +{single_prefill_sm90_func} {{ + unsigned int head_dim = q.size(2); + unsigned int num_qo_heads = q.size(1); + unsigned int qo_len = q.size(0); + + auto q_scalar_type = q.scalar_type(); + + QKVLayout kv_layout = static_cast(layout); + cudaStream_t stream = 
reinterpret_cast(cuda_stream); + const MaskMode mask_mode = static_cast(mask_mode_code); + + using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; + using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; + using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; + + using SinglePrefillParams = SinglePrefillParams; + SinglePrefillParams params; + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? (static_cast(maybe_lse->data_ptr())) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); + if (kv_layout == QKVLayout::kNHD) {{ + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); + }} else {{ + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); + }} + params.qo_len = q.size(0); + params.kv_len = k.size(0); + params.head_dim = head_dim; + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); + params.causal = mask_mode == MaskMode::kCausal; + params.group_size = params.num_qo_heads / params.num_kv_heads; + params.window_left = window_left; + params.logits_soft_cap = logits_soft_cap; + params.sm_scale_log2 = sm_scale * math::log2e; -template -cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams& params, - cudaStream_t stream); + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {{ + using AttentionVariant = + std::conditional_t<{{{{ use_logits_soft_cap }}}}, + LogitsSoftCap, + StandardAttention>; + cudaError_t status = + SinglePrefillWithKVCacheDispatched + <{{{{ head_dim }}}}, MASK_MODE, {{{{ use_sliding_window }}}}, AttentionVariant> + (params, stream); + TORCH_CHECK(status == cudaSuccess, + "SinglePrefillWithKVCacheDispatched failed with error: ", + cudaGetErrorString(status)); + }}); +}} +""", + f"""// _pybind.cc +#include "pytorch_extension_utils.h" -} // namespace flashinfer +{single_prefill_sm90_func}; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{ + m.def("run", &single_prefill_with_kv_cache_sm90, + "Single-request prefill attention with KV-Cache operator"); +}} +""", +] + + +customizable_single_prefill_sm90_func = r"""void single_prefill_with_kv_cache_sm90( + unsigned int mask_mode_code, + at::Tensor q, + at::Tensor k, + at::Tensor v, + at::Tensor buffer, + at::Tensor o, + unsigned int layout, + int32_t window_left, + std::optional maybe_lse, + {{ additional_func_params }} + int64_t cuda_stream)""" + + +customizable_single_prefill_sm90_params_templ = r""" +struct SinglePrefillParams { + using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>; + using DTypeKV = cutlass_dtype_t<{{ dtype_kv }}>; + using DTypeO = cutlass_dtype_t<{{ dtype_o }}>; + using IdType = cutlass_dtype_t; + + // The QKV matrices. + DTypeQ* q_ptr; + DTypeKV* k_ptr; + DTypeKV* v_ptr; + DTypeO* o_ptr; + float* lse_ptr; + + // Additional params + {{ additional_params_decl }}; + + int64_t q_stride_n; + int64_t k_stride_n; + int64_t v_stride_n; + int64_t o_stride_n; + int64_t q_stride_h; + int64_t k_stride_h; + int64_t v_stride_h; + int64_t o_stride_h; + + int qo_len; + int kv_len; + int head_dim; + int num_qo_heads; + int num_kv_heads; + int group_size; + int window_left; + + bool causal; + + // these are bad arguments. we should remove them from default in prefill_sm90.cuh. 
+ float logits_soft_cap = 0.; + float sm_scale_log2 = 0.; +}; +""" + + +def customizable_single_prefill_sm90_inst_templ(mask_mode: str) -> str: + return f"""// single_prefill_sm90 template instantiation +#include +#include +#include + +namespace flashinfer {{ + +using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; +using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; +using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; + +{customizable_single_prefill_sm90_params_templ} + +{{{{ variant_decl }}}} + +using AttentionVariant = {{{{ variant_name }}}}; + +template cudaError_t SinglePrefillWithKVCacheDispatched + <{{{{ head_dim }}}}, {mask_mode}, /*USE_SWA=*/true, AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); + +template cudaError_t SinglePrefillWithKVCacheDispatched + <{{{{ head_dim }}}}, {mask_mode}, /*USE_SWA=*/false, AttentionVariant> + (typename AttentionVariant::ParamsT& params, cudaStream_t stream); + +}} // namespace flashinfer +""" + + +customizable_single_prefill_sm90_suffix = single_prefill_sm90_suffix + + +customizable_single_prefill_sm90_templ = [ + customizable_single_prefill_sm90_inst_templ("MaskMode::kNone"), + customizable_single_prefill_sm90_inst_templ("MaskMode::kCausal"), + customizable_single_prefill_sm90_inst_templ("MaskMode::kCustom"), + f"""// _run.cu +#include +#include +#include +#include +#include +#include +#include "pytorch_extension_utils.h" + +namespace flashinfer {{ + +{customizable_single_prefill_sm90_params_templ} + +{{{{ variant_decl }}}} + +}}; // namespace flashinfer using namespace flashinfer; -void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, - at::Tensor v, - std::optional maybe_packed_custom_mask, - std::optional maybe_alibi_slopes, at::Tensor o, - unsigned int layout, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, std::optional maybe_lse, - int64_t cuda_stream) { +{customizable_single_prefill_sm90_func} {{ unsigned int head_dim = q.size(2); unsigned int num_qo_heads = q.size(1); unsigned int qo_len = q.size(0); @@ -97,11 +289,11 @@ def single_prefill_sm90_inst_templ(mask_mode: str) -> str: cudaStream_t stream = reinterpret_cast(cuda_stream); const MaskMode mask_mode = static_cast(mask_mode_code); - using DTypeQ = cutlass_dtype_t<{{dtype_q}}>; - using DTypeKV = cutlass_dtype_t<{{dtype_kv}}>; - using DTypeO = cutlass_dtype_t<{{dtype_o}}>; + using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>; + using DTypeKV = cutlass_dtype_t<{{{{ dtype_kv }}}}>; + using DTypeO = cutlass_dtype_t<{{{{ dtype_o }}}}>; - SinglePrefillParams params; + SinglePrefillParams params; params.q_ptr = static_cast(q.data_ptr()); params.k_ptr = static_cast(k.data_ptr()); params.v_ptr = static_cast(v.data_ptr()); @@ -111,17 +303,17 @@ def single_prefill_sm90_inst_templ(mask_mode: str) -> str: params.q_stride_h = q.stride(1); params.o_stride_n = o.stride(0); params.o_stride_h = o.stride(1); - if (kv_layout == QKVLayout::kNHD) { + if (kv_layout == QKVLayout::kNHD) {{ params.k_stride_n = k.stride(0); params.k_stride_h = k.stride(1); params.v_stride_n = v.stride(0); params.v_stride_h = v.stride(1); - } else { + }} else {{ params.k_stride_h = k.stride(0); params.k_stride_n = k.stride(1); params.v_stride_h = v.stride(0); params.v_stride_n = v.stride(1); - } + }} params.qo_len = q.size(0); params.kv_len = k.size(0); params.head_dim = head_dim; @@ -130,34 +322,28 @@ def single_prefill_sm90_inst_templ(mask_mode: str) -> str: params.causal = mask_mode == 
MaskMode::kCausal; params.group_size = params.num_qo_heads / params.num_kv_heads; params.window_left = window_left; - params.logits_soft_cap = logits_soft_cap; - params.sm_scale_log2 = sm_scale * math::log2e; - DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { - using AttentionVariant = - std::conditional_t<{{ use_logits_soft_cap }}, LogitsSoftCap, StandardAttention>; + {{{{ additional_params_setter }}}}; + + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, {{ + using AttentionVariant = {{{{ variant_name }}}}; cudaError_t status = - SinglePrefillWithKVCacheDispatched<{{ head_dim }}, MASK_MODE, {{ use_sliding_window }}, AttentionVariant>( - params, stream); + SinglePrefillWithKVCacheDispatched + <{{{{ head_dim }}}}, MASK_MODE, {{{{ use_sliding_window }}}}, AttentionVariant> + (params, stream); TORCH_CHECK(status == cudaSuccess, - "single_prefill_with_kv_cache_sm90 failed with error: " + - std::string(cudaGetErrorString(status))); - }); -} + "SinglePrefillWithKVCacheDispatched failed with error: ", + cudaGetErrorString(status)); + }}); +}} """, - r"""#include "pytorch_extension_utils.h" - -void single_prefill_with_kv_cache_sm90(unsigned int mask_mode_code, at::Tensor q, at::Tensor k, - at::Tensor v, - std::optional maybe_packed_custom_mask, - std::optional maybe_alibi_slopes, at::Tensor o, - unsigned int layout, int32_t window_left, - float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, std::optional maybe_lse, - int64_t cuda_stream); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + f"""// _pybind.cc +#include "pytorch_extension_utils.h" + +{customizable_single_prefill_sm90_func}; + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{ m.def("run", &single_prefill_with_kv_cache_sm90, "Single-request prefill attention with KV-Cache operator"); -} +}} """, ] diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 8dd45ed7..b997ab0e 100644 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -60,7 +60,7 @@ _single_prefill_sm90_modules = {} _batch_prefill_modules = {} _batch_prefill_sm90_modules = {} - +_batch_prefill_jit_modules = {} def get_single_prefill_sm90_module(*args): global _single_prefill_sm90_modules @@ -628,6 +628,306 @@ def _fake_paged_run( return _batch_prefill_modules[args] +def get_batch_prefill_jit_module(module_name: str, jit_module: Any): + global _batch_prefill_jit_modules + if module_name in _batch_prefill_jit_modules: + return _batch_prefill_jit_modules[module_name] + + plan_func = jit_module.plan + ragged_run_func = jit_module.ragged_run + paged_run_func = jit_module.paged_run + + # torch library for ragged_run + @register_custom_op( + f"flashinfer::{module_name}_ragged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "maybe_lse", + ), + ) + def ragged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + with q.device as device: # device guard + o = torch.empty_like(q) + ragged_run_func( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + k, + v, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, 
+ kv_indptr, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o + + @register_fake_op(f"flashinfer::{module_name}_ragged_run") + def _fake_ragged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # torch library for paged_run + @register_custom_op( + f"flashinfer::{module_name}_paged_run", + mutates_args=( + "float_workspace_buffer", + "int_workspace_buffer", + "paged_k_cache", + "paged_v_cache", + "maybe_lse", + ), + ) + def paged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + with q.device as device: # device guard + o = torch.empty_like(q) + paged_run_func( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + get_cuda_stream(device), + ) + return o + + @register_fake_op(f"flashinfer::{module_name}_paged_run") + def _fake_paged_run( + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + ) -> torch.Tensor: + return torch.empty_like(q) + + # Register the module. + # + # Note that plan is not part of model logic. It should not be included in + # Cuda Graph or torch.compile. So, we don't provide a torch library for plan. 
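+    #
+    # (illustration only, not in the original patch) Sketch of how the cached
+    # namespace below is consumed; names follow the functions defined above:
+    #
+    #   module = get_batch_prefill_jit_module("my_variant", jit_module)
+    #   plan_info = module.plan(...)   # eager only, never captured in CUDA Graph
+    #   out = module.ragged_run(...)   # registered torch custom op
+    #   out = module.paged_run(...)    # registered torch custom op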
+ _batch_prefill_jit_modules[module_name] = SimpleNamespace( + plan=plan_func, ragged_run=ragged_run, paged_run=paged_run, + ) + + return _batch_prefill_jit_modules[module_name] + + +def batch_prefill_with_ragged_kv_cache_with_jit_module( + jit_module: Any, + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + kv_indptr: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + *args, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Same arg list with jit_module.ragged_run, but with *args.""" + with q.device as device: # device guard + o = torch.empty_like(q) + jit_module.ragged_run( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + k, + v, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + kv_indptr, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + *args, + get_cuda_stream(device), + ) + return o + + +def batch_prefill_with_paged_kv_cache_with_jit_module( + jit_module: Any, + mask_mode: int, + float_workspace_buffer: torch.Tensor, + int_workspace_buffer: torch.Tensor, + plan_info_vec: List[int], + q: torch.Tensor, + paged_k_cache: torch.Tensor, + paged_v_cache: torch.Tensor, + maybe_custom_mask: Optional[torch.Tensor], + maybe_alibi_slopes: Optional[torch.Tensor], + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + maybe_qk_indptr: Optional[torch.Tensor], + layout: int, + window_left: int, + logits_soft_cap: float, + sm_scale: float, + rope_scale: float, + rope_theta: float, + maybe_lse: Optional[torch.Tensor], + *args, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Same arg list with jit_module.paged_run, but with *args.""" + with q.device as device: # device guard + o = torch.empty_like(q) + jit_module.paged_run( + mask_mode, + float_workspace_buffer, + int_workspace_buffer, + plan_info_vec, + q, + paged_k_cache, + paged_v_cache, + maybe_custom_mask, + maybe_alibi_slopes, + qo_indptr, + paged_kv_indptr, + paged_kv_indices, + paged_kv_last_page_len, + maybe_qk_indptr, + o, + layout, + window_left, + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + maybe_lse, + *args, + get_cuda_stream(device), + ) + return o + + def single_prefill_with_kv_cache_with_jit_module( jit_module: Any, q: torch.Tensor, @@ -1046,6 +1346,7 @@ def __init__( custom_mask_buf: Optional[torch.Tensor] = None, qk_indptr_buf: Optional[torch.Tensor] = None, backend: str = "auto", + jit_module: Any = None, ) -> None: r"""Constructor of :class:`BatchPrefillWithPagedKVCacheWrapper`. @@ -1103,16 +1404,25 @@ def __init__( device architecture and kernel availability. 
""" _check_kv_layout(kv_layout) + if backend.startswith("jit"): + if jit_module is None: + raise ValueError("backend is `jit` but jit_module is not provided.") + self._jit_backend = True + self._jit_module = jit_module + else: + if jit_module is not None: + raise ValueError(f"backend is `{backend}` but jit_module is provided.") + self._jit_backend = False self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) - if backend in ["fa3", "auto"]: - # NOTE(Zihao): assume maximum accumulate kv length is 4M + if backend in ["fa3", "auto"] or self._jit_backend: + self._int_workspace_buffer = torch.empty( + (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + # NOTE(Zihao): assume maximum accumulate kv length is 16M self._vector_sparse_indices_buffer = torch.empty( - (4 * 1024 * 1024,), dtype=torch.int32, device=self.device + (16 * 1024 * 1024,), dtype=torch.int32, device=self.device ) # NOTE(Zihao): assume maximum batch size is 32768 self._vector_sparse_indptr_buffer = torch.empty( @@ -1121,6 +1431,10 @@ def __init__( self._kv_lens_buffer = torch.empty( (32768,), dtype=torch.int32, device=self.device ) + else: + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=self._int_workspace_buffer.dtype, @@ -1400,46 +1714,8 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - if self._backend == "auto": - self._backend = determine_attention_backend( - self.device, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - self._custom_mask_buf is not None, # use_custom_mask - q_data_type, - kv_data_type, - ) - - get_module_args = ( - q_data_type, - kv_data_type, - q_data_type, - paged_kv_indptr.dtype, - head_dim, - PosEncodingMode[pos_encoding_mode].value, - window_left >= 0, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap - allow_fp16_qk_reduction, - ) - if self._backend == "fa2": - self._cached_module = get_batch_prefill_module(*get_module_args) - with self.device as device: - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - paged_kv_indptr_host, - self._max_total_num_rows or total_num_rows, - batch_size, - num_qo_heads, - num_kv_heads, - page_size, - self.is_cuda_graph_enabled, - get_cuda_stream(device), - ) - else: - self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + if self._jit_backend: + self._cached_module = get_batch_prefill_jit_module(self._backend, self._jit_module) paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu") kv_lens_arr_host = get_seq_lens( paged_kv_indptr_host, paged_kv_last_page_len_host, page_size @@ -1478,6 +1754,85 @@ def plan( self.is_cuda_graph_enabled, get_cuda_stream(device), ) + else: + if self._backend == "auto": + self._backend = determine_attention_backend( + self.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + self._custom_mask_buf is not None, # use_custom_mask + q_data_type, + kv_data_type, + ) + + get_module_args = ( + q_data_type, + kv_data_type, + q_data_type, + paged_kv_indptr.dtype, + head_dim, + PosEncodingMode[pos_encoding_mode].value, + window_left >= 0, # use_sliding_window + 
logits_soft_cap > 0, # use_logits_soft_cap + allow_fp16_qk_reduction, + ) + if self._backend == "fa2": + self._cached_module = get_batch_prefill_module(*get_module_args) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + paged_kv_indptr_host, + self._max_total_num_rows or total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + else: + self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu") + kv_lens_arr_host = get_seq_lens( + paged_kv_indptr_host, paged_kv_last_page_len_host, page_size + ) + self._kv_lens_buffer[: len(kv_lens_arr_host)].copy_( + kv_lens_arr_host, non_blocking=non_blocking + ) + if page_size != 1: + vector_sparse_indptr_host = torch.cat( + [ + torch.tensor([0], dtype=torch.int32), + torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32), + ], + dim=0, + ) + self._vector_sparse_indptr_buffer[ + : len(vector_sparse_indptr_host) + ].copy_(vector_sparse_indptr_host, non_blocking=non_blocking) + else: + vector_sparse_indptr_host = paged_kv_indptr_host + + with self.device as device: + self._plan_info = self._cached_module.plan( + causal, + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + vector_sparse_indptr_host, + kv_lens_arr_host, + self._max_total_num_rows or total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) self._causal = causal self._pos_encoding_mode = pos_encoding_mode @@ -1521,6 +1876,7 @@ def run( self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + *args, k_scale: Optional[float] = None, v_scale: Optional[float] = None, return_lse: Literal[False] = False, @@ -1531,6 +1887,7 @@ def run( self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + *args, k_scale: Optional[float] = None, v_scale: Optional[float] = None, return_lse: Literal[True] = True, @@ -1540,6 +1897,7 @@ def run( self, q: torch.Tensor, paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + *args, k_scale: Optional[float] = None, v_scale: Optional[float] = None, return_lse: bool = False, @@ -1620,7 +1978,7 @@ def run( else: mask_mode = MaskMode.NON_CAUSAL.value - if self._backend == "fa3": + if self._jit_backend or self._backend == "fa3": # NOTE(Zihao): we divide both stride_block and stride_n by stride_n # because we will multiply stride_n back in the kernel sparse_indices = block_sparse_indices_to_vector_sparse_offsets( @@ -1636,7 +1994,7 @@ def run( else: sparse_indices = self._paged_kv_indices_buf - out = self._cached_module.paged_run( + run_args = [ mask_mode, self._float_workspace_buffer, self._int_workspace_buffer, @@ -1658,7 +2016,12 @@ def run( rope_scale, rope_theta, lse, - ) + ] + if self._jit_backend: + run_args.extend(list(args)) + out = batch_prefill_with_paged_kv_cache_with_jit_module(self._jit_module, *run_args) + else: + out = self._cached_module.paged_run(*run_args) if v_scale is not None: out *= v_scale @@ -1813,6 +2176,7 @@ def __init__( custom_mask_buf: Optional[torch.Tensor] = None, qk_indptr_buf: Optional[torch.Tensor] = None, backend: str = "auto", + jit_module: Any = None, ) -> None: r"""Constructor of 
:class:`BatchPrefillWithRaggedKVCacheWrapper`. @@ -1858,12 +2222,26 @@ def __init__( device architecture and kernel availability. """ _check_kv_layout(kv_layout) + if backend.startswith("jit"): + if jit_module is None: + raise ValueError("backend is `jit` but jit_module is not provided.") + self._jit_backend = True + self._jit_module = jit_module + else: + if jit_module is not None: + raise ValueError(f"backend is `{backend}` but jit_module is provided.") + self._jit_backend = False self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device - self._int_workspace_buffer = torch.empty( - (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device - ) + if backend in ["fa3", "auto"] or self._jit_backend: + self._int_workspace_buffer = torch.empty( + (64 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) + else: + self._int_workspace_buffer = torch.empty( + (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device + ) self._pin_memory_int_workspace_buffer = torch.empty( self._int_workspace_buffer.shape, dtype=torch.uint8, pin_memory=True ) @@ -2087,46 +2465,8 @@ def plan( self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - if self._backend == "auto": - self._backend = determine_attention_backend( - self.device, - PosEncodingMode[pos_encoding_mode].value, - allow_fp16_qk_reduction, - self._custom_mask_buf is not None, # use_custom_mask - q_data_type, - kv_data_type, - ) - - get_module_args = ( - q_data_type, - kv_data_type, - q_data_type, - kv_indptr.dtype, - head_dim, - PosEncodingMode[pos_encoding_mode].value, - window_left >= 0, # use_sliding_window - logits_soft_cap > 0, # use_logits_soft_cap - allow_fp16_qk_reduction, - ) - if self._backend == "fa2": - self._cached_module = get_batch_prefill_module(*get_module_args) - with self.device as device: - self._plan_info = self._cached_module.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._pin_memory_int_workspace_buffer, - qo_indptr_host, - kv_indptr_host, - self._max_total_num_rows or total_num_rows, - batch_size, - num_qo_heads, - num_kv_heads, - 1, # page_size - self.is_cuda_graph_enabled, - get_cuda_stream(device), - ) - else: - self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + if self._jit_backend: + self._cached_module = get_batch_prefill_jit_module(self._backend, self._jit_module) kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1] with self.device as device: # NOTE(Zihao): there are some interface differences between fa2 and fa3 @@ -2147,6 +2487,67 @@ def plan( self.is_cuda_graph_enabled, get_cuda_stream(device), ) + else: + if self._backend == "auto": + self._backend = determine_attention_backend( + self.device, + PosEncodingMode[pos_encoding_mode].value, + allow_fp16_qk_reduction, + self._custom_mask_buf is not None, # use_custom_mask + q_data_type, + kv_data_type, + ) + + get_module_args = ( + q_data_type, + kv_data_type, + q_data_type, + kv_indptr.dtype, + head_dim, + PosEncodingMode[pos_encoding_mode].value, + window_left >= 0, # use_sliding_window + logits_soft_cap > 0, # use_logits_soft_cap + allow_fp16_qk_reduction, + ) + if self._backend == "fa2": + self._cached_module = get_batch_prefill_module(*get_module_args) + with self.device as device: + self._plan_info = self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + self._max_total_num_rows or total_num_rows, + 
batch_size, + num_qo_heads, + num_kv_heads, + 1, # page_size + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) + else: + self._cached_module = get_batch_prefill_sm90_module(*get_module_args) + kv_len_arr = kv_indptr_host[1:] - kv_indptr_host[:-1] + with self.device as device: + # NOTE(Zihao): there are some interface differences between fa2 and fa3 + # we should align the interface in the future + self._plan_info = self._cached_module.plan( + causal, + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + qo_indptr_host, + kv_indptr_host, + kv_len_arr, + self._max_total_num_rows or total_num_rows, + batch_size, + num_qo_heads, + num_kv_heads, + 1, # page_size + self.is_cuda_graph_enabled, + get_cuda_stream(device), + ) self._causal = causal self._pos_encoding_mode = pos_encoding_mode @@ -2190,6 +2591,7 @@ def run( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *args, return_lse: Literal[False] = False, ) -> torch.Tensor: ... @@ -2199,6 +2601,7 @@ def run( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *args, return_lse: Literal[True] = True, ) -> Tuple[torch.Tensor, torch.Tensor]: ... @@ -2207,6 +2610,7 @@ def run( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + *args, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Compute batch prefill/append attention between query and kv-cache stored as @@ -2271,7 +2675,7 @@ def run( else: mask_mode = MaskMode.NON_CAUSAL.value - out = self._cached_module.ragged_run( + run_args = [ mask_mode, self._float_workspace_buffer, self._int_workspace_buffer, @@ -2291,7 +2695,12 @@ def run( rope_scale, rope_theta, lse, - ) + ] + if self._jit_backend: + run_args.extend(list(args)) + out = batch_prefill_with_ragged_kv_cache_with_jit_module(self._jit_module, *run_args) + else: + out = self._cached_module.ragged_run(*run_args) return (out, lse) if return_lse else out diff --git a/include/flashinfer/attention/hopper/attention_updater.cuh b/include/flashinfer/attention/hopper/attention_updater.cuh index f9fc1abb..aa46abf0 100644 --- a/include/flashinfer/attention/hopper/attention_updater.cuh +++ b/include/flashinfer/attention/hopper/attention_updater.cuh @@ -10,6 +10,8 @@ #include #include +#include "flashinfer/attention/hopper/utils.cuh" + namespace flashinfer { using namespace cute; diff --git a/include/flashinfer/attention/hopper/kernel_traits.cuh b/include/flashinfer/attention/hopper/kernel_traits.cuh index a144b708..31adeb49 100644 --- a/include/flashinfer/attention/hopper/kernel_traits.cuh +++ b/include/flashinfer/attention/hopper/kernel_traits.cuh @@ -40,13 +40,14 @@ struct SharedStorageQKVO { }; template struct AttentionKernelTraits { - using DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using IdType = IdType_; + using AttentionVariant = AttentionVariant_; + + using DTypeQ = cutlass_dtype_t; + using DTypeKV = cutlass_dtype_t; + using DTypeO = cutlass_dtype_t; + using IdType = cutlass_dtype_t; using DTypeQKAccum = float; static constexpr int CTA_Q = CTA_Q_; @@ -61,7 +62,6 @@ struct AttentionKernelTraits { // where only one warp inside a warp group is used for TMA. 
static constexpr int NUM_PRODUCER_THREADS = cutlass::NumThreadsPerWarp; - using AttentionVariant = AttentionVariant_; using TileShape_QKD = Shape, Int, Int>; static constexpr int NUM_STAGES = NUM_STAGES_; diff --git a/include/flashinfer/attention/hopper/params.cuh b/include/flashinfer/attention/hopper/params.cuh index fcd80a95..deb075a6 100644 --- a/include/flashinfer/attention/hopper/params.cuh +++ b/include/flashinfer/attention/hopper/params.cuh @@ -22,11 +22,12 @@ namespace flashinfer { -template +template struct SinglePrefillParams { using DTypeQ = DTypeQ_; using DTypeKV = DTypeKV_; using DTypeO = DTypeO_; + using IdType = IdType_; // The QKV matrices. DTypeQ* q_ptr; DTypeKV* k_ptr; diff --git a/include/flashinfer/attention/hopper/prefill_sm90.cuh b/include/flashinfer/attention/hopper/prefill_sm90.cuh index 1a1f0027..68e2dcfc 100644 --- a/include/flashinfer/attention/hopper/prefill_sm90.cuh +++ b/include/flashinfer/attention/hopper/prefill_sm90.cuh @@ -7,6 +7,7 @@ #ifndef FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_ #define FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_ +#include #include #include #include @@ -16,6 +17,7 @@ #include #include +#include #include "../../cutlass_utils.cuh" #include "../../exception.h" @@ -26,7 +28,6 @@ #include "kernel_traits.cuh" #include "mainloop.cuh" #include "mainloop_mma.cuh" -#include "params.cuh" #include "sparse_mainloop.cuh" #include "tile_scheduler.cuh" #include "utils.cuh" @@ -38,19 +39,18 @@ using namespace cute; template __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp, 1) - PrefillWithKVCacheKernel(CUTE_GRID_CONSTANT - typename CollectiveMainloop::Params const mainloop_params, - CUTE_GRID_CONSTANT - typename CollectiveEpilogue::Params const epilogue_params, - CUTE_GRID_CONSTANT - typename TileScheduler::Params const scheduler_params) { + PrefillWithKVCacheKernel( + CUTE_GRID_CONSTANT typename Ktraits::AttentionVariant::ParamsT const variant_params, + CUTE_GRID_CONSTANT typename CollectiveMainloop::Params const mainloop_params, + CUTE_GRID_CONSTANT typename CollectiveEpilogue::Params const epilogue_params, + CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params) { using DTypeQ = typename Ktraits::DTypeQ; using DTypeKV = typename Ktraits::DTypeKV; using DTypeO = typename Ktraits::DTypeO; using DTypeQKAccum = typename Ktraits::DTypeQKAccum; using TileShape_QKD = typename Ktraits::TileShape_QKD; using AttentionVariant = typename Ktraits::AttentionVariant; - AttentionVariant variant(mainloop_params); + AttentionVariant variant(variant_params); static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS; static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup; @@ -230,11 +230,9 @@ __global__ void __launch_bounds__(Ktraits::NUM_WARPS* cutlass::NumThreadsPerWarp } } -template +template cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched( - SinglePrefillParams& params, - cudaStream_t stream) { + typename KernelTraits::AttentionVariant::ParamsT& params, cudaStream_t stream) { using DTypeQ = typename KernelTraits::DTypeQ; using DTypeKV = typename KernelTraits::DTypeKV; using DTypeO = typename KernelTraits::DTypeO; @@ -272,7 +270,7 @@ cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched( auto kernel = (void*)PrefillWithKVCacheKernel; + LEFT_SLIDING_WINDOW, CAUSAL, Scheduler>; int smem_size = sizeof(typename KernelTraits::SharedStorage); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -285,21 +283,20 @@ 
cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched( dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); static constexpr int num_ctas = KernelTraits::NUM_WARPS * 32; dim3 block_dims(num_ctas); - void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + void* args[] = {¶ms, &mainloop_params, &epilogue_params, &scheduler_params}; FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); return cudaSuccess; } -template +template cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched( - BatchPrefillPagedParams& params, - cudaStream_t stream) { + typename KernelTraits::AttentionVariant::ParamsT& params, cudaStream_t stream) { using DTypeQ = typename KernelTraits::DTypeQ; using DTypeKV = typename KernelTraits::DTypeKV; using DTypeO = typename KernelTraits::DTypeO; + using IdType = typename KernelTraits::IdType; using TileShape_QKD = typename KernelTraits::TileShape_QKD; using CollectiveMainloop = SparseCollectiveMainloop; @@ -327,6 +324,7 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched( params.o_stride_h), // layout_O params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE }); + typename Scheduler::Arguments scheduler_args = { params.work_indptr, params.head_indices, params.qo_tile_indices, params.qo_indptr, @@ -338,7 +336,7 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched( // Get the ptr to kernel function. auto kernel = (void*)PrefillWithKVCacheKernel; + LEFT_SLIDING_WINDOW, CAUSAL, Scheduler>; int smem_size = sizeof(typename KernelTraits::SharedStorage); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -351,22 +349,20 @@ cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched( dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; dim3 block_dims(ctaSize); - - void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + void* args[] = {¶ms, &mainloop_params, &epilogue_params, &scheduler_params}; FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); return cudaSuccess; } -template +template cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched( - BatchPrefillRaggedParams& params, - cudaStream_t stream) { + typename KernelTraits::AttentionVariant::ParamsT& params, cudaStream_t stream) { using DTypeQ = typename KernelTraits::DTypeQ; using DTypeKV = typename KernelTraits::DTypeKV; using DTypeO = typename KernelTraits::DTypeO; + using IdType = typename KernelTraits::IdType; using TileShape_QKD = typename KernelTraits::TileShape_QKD; using CollectiveMainloop = CollectiveMainloop; @@ -406,7 +402,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched( // Get the ptr to kernel function. 
auto kernel = (void*)PrefillWithKVCacheKernel; + LEFT_SLIDING_WINDOW, CAUSAL, Scheduler>; int smem_size = sizeof(typename KernelTraits::SharedStorage); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -419,15 +415,15 @@ cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched( dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32; dim3 block_dims(ctaSize); - void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params}; + void* args[] = {¶ms, &mainloop_params, &epilogue_params, &scheduler_params}; FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream)); return cudaSuccess; } -template -cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams& params, +template +cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::ParamsT& params, cudaStream_t stream) { static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); if (MASK_MODE == MaskMode::kCustom) { @@ -436,34 +432,35 @@ cudaError_t SinglePrefillWithKVCacheDispatched(SinglePrefillParams, - LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + /*NUM_STAGES_=*/2, AttentionVariant>, + LEFT_SLIDING_WINDOW, CAUSAL>(params, stream); } else if constexpr (HEAD_DIM == 128) { SinglePrefillWithKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL>(params, stream); } else { // HEAD_DIM == 256; SinglePrefillWithKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLINDING_WINDOW, CAUSAL>(params, stream); + AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL>(params, stream); } cudaError_t status = cudaGetLastError(); return status; } -template -cudaError_t BatchPrefillWithRaggedKVCacheDispatched( - BatchPrefillRaggedParams& params, cudaStream_t stream) { +template +cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::ParamsT& params, + cudaStream_t stream) { static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); if (MASK_MODE == MaskMode::kCustom) { return cudaErrorNotSupported; // Not supported yet. 
@@ -471,32 +468,35 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal; if constexpr (HEAD_DIM == 64) { BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + /*NUM_STAGES_=*/2, AttentionVariant>, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); } else if constexpr (HEAD_DIM == 128) { BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); } else { // HEAD_DIM == 256; BatchPrefillWithRaggedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); } cudaError_t status = cudaGetLastError(); return status; } -template -cudaError_t BatchPrefillWithPagedKVCacheDispatched( - BatchPrefillPagedParams& params, cudaStream_t stream) { +template +cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT& params, + cudaStream_t stream) { static_assert(HEAD_DIM == 64 || HEAD_DIM == 128 || HEAD_DIM == 256); if (MASK_MODE == MaskMode::kCustom) { return cudaErrorNotSupported; // Not supported yet. @@ -505,23 +505,27 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( if constexpr (HEAD_DIM == 64) { // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 64, need to optimize later BatchPrefillWithPagedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + /*NUM_STAGES_=*/2, AttentionVariant>, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); } else if constexpr (HEAD_DIM == 128) { BatchPrefillWithPagedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + AttentionKernelTraits, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); } else { // HEAD_DIM == 256; // NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later BatchPrefillWithPagedKVCacheKernelTraitsDispatched< - AttentionKernelTraits, - LEFT_SLINDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); + /*NUM_STAGES_=*/2, AttentionVariant>, + LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream); } cudaError_t status = cudaGetLastError(); return status; diff --git a/include/flashinfer/attention/hopper/variants.cuh b/include/flashinfer/attention/hopper/variants.cuh index 75d7c7bc..94a460a4 100644 --- a/include/flashinfer/attention/hopper/variants.cuh +++ b/include/flashinfer/attention/hopper/variants.cuh @@ -23,36 +23,48 @@ namespace flashinfer { +template struct StandardAttention { + using ParamsT = ParamsT_; + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + using IdType = typename ParamsT::IdType; + template using Updater = OnlineSoftmaxWithScale; - template __device__ StandardAttention(const ParamsT& params) {} - template - __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, + template + __device__ __forceinline__ T LogitsTransform(const 
MainloopParams& params, T logits, + uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) { return logits; } }; +template struct LogitsSoftCap { + using ParamsT = ParamsT_; + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + using IdType = typename ParamsT::IdType; float pre_tanh_scale; float post_tanh_scale; + template using Updater = OnlineSoftmaxWithoutScale; - template __device__ LogitsSoftCap(const ParamsT& params) { pre_tanh_scale = (params.sm_scale_log2 * math::loge2) * math::ptx_rcp(params.logits_soft_cap); post_tanh_scale = math::log2e * params.logits_soft_cap; } - template - __device__ __forceinline__ T LogitsTransform(const ParamsT& params, T logits, uint32_t batch_idx, - uint32_t qo_idx, uint32_t kv_idx, + template + __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits, + uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) { return math::tanh(logits * pre_tanh_scale) * post_tanh_scale; } diff --git a/tests/test_jit_example.py b/tests/test_jit_example.py index f65f3359..12b9dd8b 100644 --- a/tests/test_jit_example.py +++ b/tests/test_jit_example.py @@ -7,6 +7,7 @@ from flashinfer.jit.attention import ( gen_customize_single_decode_module, gen_customize_single_prefill_module, + gen_customize_single_prefill_sm90_module, single_decode_suffix, single_prefill_suffix, ) @@ -309,6 +310,82 @@ def test_debug_print_logits(): torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) +def test_sm90_debug_print_logits(): + torch.manual_seed(42) + variant_decl = r""" +template +struct DebugPrintLogits { + using ParamsT = ParamsT_; + using DTypeQ = typename ParamsT::DTypeQ; + using DTypeKV = typename ParamsT::DTypeKV; + using DTypeO = typename ParamsT::DTypeO; + using IdType = typename ParamsT::IdType; + + template + using Updater = OnlineSoftmaxWithoutScale; + + static constexpr auto use_softmax = true; + + int qo_len; + int kv_len; + float sm_scale_log2; + + // Init + __device__ __host__ DebugPrintLogits(const ParamsT& params) { + sm_scale_log2 = params.sm_scale * math::log2e; + qo_len = params.qo_len; + kv_len = params.kv_len; + } + + template + __device__ __forceinline__ T LogitsTransform(const MainloopParams& params, T logits, + uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, + uint32_t qo_head_idx, uint32_t kv_head_idx) { + if (qo_idx < qo_len && kv_idx < kv_len) { + printf( + "---> LOGITS DEBUG: " + "qo_idx=%-5d " + "kv_idx=%-5d " + "sm_scale_log2=%-12.5f " + "logits=%-12.5f " + "\n", + qo_idx, + kv_idx, + sm_scale_log2, + static_cast(logits)); + } + logits *= sm_scale_log2; + return logits; + } +}; +""" + jit_module = gen_customize_single_prefill_sm90_module( + module_name="sm90_debug_print_logits", + dtype_q=torch.float16, # dtype_q + dtype_kv=torch.float16, # dtype_kv + dtype_o=torch.float16, # dtype_o + head_dim=128, # hidden_dim + additional_input_tensor_var_names=[], # additional_input_tensor_var_names + additional_input_tensor_var_types=[], # additional_input_tensor_var_types + additional_input_scalar_var_names=["sm_scale"], # additional_input_scalar_var_names + additional_input_scalar_var_types=["float"], # additional_input_scalar_var_types + variant_name="DebugPrintLogits", + variant_decl=variant_decl, + ) + + f = functools.partial(single_prefill_with_kv_cache_with_jit_module, jit_module) + + q = torch.randn(16, 2, 128, dtype=torch.float16, device="cuda") + k = 
torch.randn(16, 1, 128, dtype=torch.float16, device="cuda")
+    v = torch.randn(16, 1, 128, dtype=torch.float16, device="cuda")
+    sm_scale = 1. / math.sqrt(128)
+    o = f(q, k, v, sm_scale, mask_mode=MaskMode.NON_CAUSAL.value)
+
+    p = torch.einsum("mhd,nhd->hmn", q.float(), k.float()) * sm_scale
+    o_ref = torch.einsum("hmn,nhd->mhd", torch.softmax(p, dim=-1), v.float()).half()
+    torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)
+
+
 if __name__ == "__main__":
     test_single_decode_mask()
     test_flash_sigmoid()
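A recurring detail in the rewritten template files above is the switch from raw strings with `{{ name }}` placeholders to f-strings with `{{{{ name }}}}`. The f-string consumes one level of braces at Python import time, so the quadruple braces survive as the double-brace placeholders that the JIT template renderer fills in later, while single-brace references such as `{single_prefill_sm90_func}` are expanded immediately from helpers defined in the same file. A minimal sketch of that rule (the `head_dim_decl` helper is a made-up stand-in for those helpers):

# Brace escaping used by the new f-string templates.
head_dim_decl = "unsigned int head_dim;"  # expanded immediately, like {single_prefill_sm90_func}

templ = f"""
{head_dim_decl}
using DTypeQ = cutlass_dtype_t<{{{{ dtype_q }}}}>;
"""

print(templ)
# unsigned int head_dim;
# using DTypeQ = cutlass_dtype_t<{{ dtype_q }}>;
#
# The surviving "{{ dtype_q }}" is what the JIT template engine substitutes at build time.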
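The prefill.py changes also let `BatchPrefillWithRaggedKVCacheWrapper` (and its paged counterpart) accept a user-supplied JIT module. A minimal usage sketch under stated assumptions: the builder `make_custom_ragged_prefill_module` and the trailing `sm_scale` run argument are hypothetical placeholders; only the `backend="jit..."` / `jit_module` constructor arguments, the `plan()` call, and the extra positional arguments forwarded from `run()` to `ragged_run` are taken from the diff.

import torch
import flashinfer

# Hypothetical builder: any object exposing .plan / .ragged_run / .paged_run works.
jit_module = make_custom_ragged_prefill_module()  # placeholder, not a real flashinfer API

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
    workspace,
    backend="jit_my_variant",  # any backend name starting with "jit" selects the JIT path
    jit_module=jit_module,
)

qo_indptr = torch.tensor([0, 16, 32], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 64, 128], dtype=torch.int32, device="cuda")
wrapper.plan(qo_indptr, kv_indptr, num_qo_heads=8, num_kv_heads=8, head_dim=128)

q = torch.randn(32, 8, 128, dtype=torch.float16, device="cuda")
k = torch.randn(128, 8, 128, dtype=torch.float16, device="cuda")
v = torch.randn(128, 8, 128, dtype=torch.float16, device="cuda")

# Positional arguments after (q, k, v) are appended to the standard argument list
# and forwarded verbatim to jit_module.ragged_run ({{ additional_func_params }}).
o = wrapper.run(q, k, v, 1.0 / 128 ** 0.5)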