
BatchPrefillWithPagedKVCacheDispatched for sm90 #745

Open
tangcy98 opened this issue Jan 21, 2025 · 1 comment
tangcy98 commented Jan 21, 2025

Hello😀

I am currently implementing batch decode computation using the BatchPrefillHandler and BatchPrefillWithPagedKVCacheWrapper in FlashInfer.

```cpp
template <typename DTypeQ, typename DTypeKV, typename DTypeO, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapper(
    BatchPrefillHandler* handler, DTypeQ* q, IdType* qo_indptr, IdType* q_offset,
    paged_kv_t<DTypeKV, IdType> paged_kv, DTypeO* o, float* lse, uint32_t num_qo_heads,
    bool causal = true, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
    bool allow_fp16_qk_reduction = false, std::optional<float> maybe_sm_scale = std::nullopt,
    float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) {
```

And I mainly refer to src/bench_batch_decode.cu.

```cpp
template <typename T, typename TKV>
void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
  constexpr size_t head_dim = 128;
  constexpr auto pos_encoding_mode = PosEncodingMode::kNone;
  size_t seqlen = state.get_int64("seqlen");
  size_t batch_size = state.get_int64("batch_size");
  size_t page_size = state.get_int64("page_size");
  size_t num_qo_heads = state.get_int64("num_qo_heads");
  size_t num_kv_heads = state.get_int64("num_kv_heads");
  // KV cache:
  auto pages_per_seq = (seqlen + page_size - 1) / page_size;
  auto num_pages = pages_per_seq * batch_size;
  std::vector<int32_t> kv_indptr_host{0};
  std::vector<int32_t> kv_indicies_host;
  std::vector<int32_t> kv_last_page_len_host;
  for (size_t i = 0; i < batch_size; ++i) {
    for (size_t p = 0; p < pages_per_seq; ++p) {
      kv_indicies_host.push_back(i * pages_per_seq + p);
    }
    kv_indptr_host.push_back(kv_indptr_host.back() + pages_per_seq);
    kv_last_page_len_host.push_back((seqlen - 1) % page_size + 1);
  }
  thrust::device_vector<TKV> k_data(num_pages * num_kv_heads * page_size * head_dim);
  thrust::device_vector<TKV> v_data(num_pages * num_kv_heads * page_size * head_dim);
  thrust::device_vector<int32_t> kv_indptr(kv_indptr_host);
  thrust::device_vector<int32_t> kv_indices(kv_indicies_host);
  thrust::device_vector<int32_t> kv_last_page_len(kv_last_page_len_host);
  paged_kv_t<TKV, int32_t> paged_kv(
      num_kv_heads, page_size, head_dim, batch_size, kv_layout,
      thrust::raw_pointer_cast(k_data.data()), thrust::raw_pointer_cast(v_data.data()),
      thrust::raw_pointer_cast(kv_indices.data()), thrust::raw_pointer_cast(kv_indptr.data()),
      thrust::raw_pointer_cast(kv_last_page_len.data()));
  // Allocate input data:
  thrust::device_vector<T> q(batch_size * num_qo_heads * head_dim);
  thrust::device_vector<T> o(batch_size * num_qo_heads * head_dim);
  std::vector<int32_t> qo_indptr_h{0};
  for (uint32_t i = 0; i < batch_size; ++i) {
    qo_indptr_h.push_back(qo_indptr_h.back() + 1);
  }
  thrust::device_vector<int32_t> qo_indptr_d(qo_indptr_h);
  state.add_global_memory_reads<uint8_t>(
      vec_bytes(q) + (num_pages * 2 * num_kv_heads * page_size * head_dim) * sizeof(TKV) +
          vec_bytes(kv_indptr) + vec_bytes(kv_indices) + vec_bytes(kv_last_page_len),
      "Read");
  state.add_global_memory_writes<uint8_t>(vec_bytes(o), "Write");
  BatchPrefillHandler handler;
  size_t float_workspace_size_in_bytes = 128 * 1024 * 1024;
  thrust::device_vector<char> float_buffer(float_workspace_size_in_bytes);
  size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
  thrust::device_vector<char> int_buffer(int_workspace_size_in_bytes);
  handler.Plan<T, int32_t>(
      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
      qo_indptr_h.data(), kv_indptr_host.data(), /*total_num_rows=*/batch_size, batch_size,
      num_qo_heads, num_kv_heads, head_dim, page_size);
  state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
    cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<T, TKV, T, int32_t>(
        &handler, thrust::raw_pointer_cast(q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()),
        /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o.data()),
        /*lse=*/nullptr, num_qo_heads,
        /*causal=*/false, pos_encoding_mode);
  });
}
```

🤔Recently, I noticed that FlashInfer has been updated to support the sm90 Hopper architecture. The BatchPrefillWithPagedKVCacheDispatched (prefill.cuh) called within BatchPrefillWithPagedKVCacheWrapper has a counterpart implementation for sm90 (prefill_sm90.cuh).

Although they share the same function name, their parameters and templates are not identical.

See

```cpp
template <uint32_t CTA_TILE_Q, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
          bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename AttentionVariant>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT params,
                                                   typename AttentionVariant::DTypeO* tmp_v,
                                                   float* tmp_s, cudaStream_t stream) {
```

and

```cpp
template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
          bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::ParamsT& params,
                                                   cudaStream_t stream) {
```

Based on my observations, the sm90 version does not appear to be integrated into BatchPrefillWithPagedKVCacheWrapper, and no new tests or benchmarks have been added for batch decode.

😊I am very interested in knowing whether you have tried using BatchPrefillWithPagedKVCacheDispatched for batch decode computation so far. In the future, will it be possible to enable the sm90 implementation without modifying my existing code, or with only minor changes? Thank you very much for your time and effort. I look forward to your response.

Best regards😇

yzh119 (Collaborator) commented Jan 21, 2025

@tangcy98 the sm90 version was integrated into BatchPrefillWithPagedKVCacheWrapper but not BatchDecodeWithPagedKVCacheWrapper, because the minimal wgmma size on the query dimension is 64, while the unpacked query length for decoding (after head-group fusion, see Appendix 1) is at most num_qo_heads / num_kv_heads, so most of the flops are wasted. As a comparison, the minimal tile size of our fa2 template is 16. So I keep using the fa2 template for decoding; using the fa3 template might improve things a little, but I haven't verified it yet.
