[KVCache] TIR attention kernel support for MLA
This PR introduces the MLA attention kernels written in TIR.
It also implements the KV cache MLA computation logic.

A new unit test file is added to ensure the correctness of the
TIR kernels.

This PR also fixes the tile size initialization in a few TIR prefill kernels.
MasterJH5574 committed Feb 2, 2025
1 parent 8b4df72 commit 96e3b0e
Showing 8 changed files with 1,981 additions and 654 deletions.
1,638 changes: 1,275 additions & 363 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py

Large diffs are not rendered by default.
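Since the kv_cache.py diff is not rendered here, the following NumPy reference summarizes, for orientation, what the weight-absorbed MLA attention computed by the new TIR kernels and cache logic looks like. This is a sketch only: it handles a single sequence, omits causal masking, and the shape names kv_lora_rank / qk_rope_head_dim are assumptions following the usual MLA layout, not the exact kernel interface.

import numpy as np

def mla_absorbed_reference(q, compressed_kv, k_pe, scale):
    # q:             (q_len, num_heads, kv_lora_rank + qk_rope_head_dim)
    # compressed_kv: (kv_len, kv_lora_rank), shared across all heads
    # k_pe:          (kv_len, qk_rope_head_dim), rotary part of the keys
    kv_lora_rank = compressed_kv.shape[1]
    # Keys are the latent KV vectors concatenated with the RoPE component.
    k = np.concatenate([compressed_kv, k_pe], axis=1)
    scores = np.einsum("qhd,kd->hqk", q, k) * scale
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # Values are the latent vectors themselves; the up-projection to the
    # output head dimension is applied outside the kernel.
    return np.einsum("hqk,kd->qhd", p, compressed_kv)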

24 changes: 22 additions & 2 deletions python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -113,7 +113,17 @@ def tree_attn(

bdx = 32
num_warps = 4
tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
tile_x, tile_y, tile_z = (
64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
d,
64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
)
original_tile_y = tile_y
original_tile_z = tile_z
while (tile_x * tile_z) % (bdx * num_warps) != 0:
tile_z += original_tile_z
while (tile_x * tile_y) % (bdx * num_warps) != 0:
tile_y += original_tile_y

# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
@@ -476,7 +486,17 @@ def tree_attn_with_paged_kv_cache(

bdx = 32
num_warps = 4
tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
tile_x, tile_y, tile_z = (
64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
d,
64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1),
)
original_tile_y = tile_y
original_tile_z = tile_z
while (tile_x * tile_z) % (bdx * num_warps) != 0:
tile_z += original_tile_z
while (tile_x * tile_y) % (bdx * num_warps) != 0:
tile_y += original_tile_y

# Otherwise we would exceed maxComputeWorkgroupStorageSize
if (
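The change in both hunks above replaces the fixed tile_z = 16 with a value derived from the dtype width and head dimension, then pads tile_y / tile_z until each tile area divides evenly by the thread count. A standalone restatement of that selection logic, using a hypothetical helper name and a plain integer bit-width in place of the kernel's DataType handling:

def select_tile_sizes(dtype_bits, d, bdx=32, num_warps=4):
    # Base tile extent: 64 bytes' worth of elements, shrunk for large head dims.
    base = 64 // ((dtype_bits + 7) // 8) // max(d // 128, 1)
    tile_x, tile_y, tile_z = base, d, base
    original_tile_y, original_tile_z = tile_y, tile_z
    # Grow the tiles in multiples of their initial values until the x*z and
    # x*y areas are divisible by the thread count (bdx * num_warps), so the
    # cooperative loads assign a whole number of elements to every thread.
    while (tile_x * tile_z) % (bdx * num_warps) != 0:
        tile_z += original_tile_z
    while (tile_x * tile_y) % (bdx * num_warps) != 0:
        tile_y += original_tile_y
    return tile_x, tile_y, tile_z

For example, float16 (16 bits) with d = 128 yields (32, 128, 32).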
9 changes: 9 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
@@ -74,12 +74,21 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv_mla")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKVMLA);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
attn_score_scaling_factor);
});
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray q_data, NDArray compressed_kv_data,
NDArray k_pe_data, NDArray o_data) {
kv_cache->MLAAbsorbed(layer_id, std::move(q_data), std::move(compressed_kv_data),
std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
});

// RNN State methods
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
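The two new registrations expose DebugGetKVMLA and MLAAbsorbed to the VM as packed functions. A hedged sketch of how the MLA attention entry point could be looked up and called from Python follows; the kv_cache object, layer id, scaling factor, and NDArray arguments are placeholders assumed to be prepared elsewhere (the unit tests added in this PR cover the real setup).

import tvm

mla_absorbed = tvm.get_global_func("vm.builtin.attention_kv_cache_mla_absorbed")
# Assumed to exist: kv_cache (the paged MLA cache object), layer_id (int),
# attn_score_scaling_factor (float), and device NDArrays q_nd,
# compressed_kv_nd, k_pe_nd, o_nd with the layouts documented in kv_state.h.
# mla_absorbed(kv_cache, layer_id, attn_score_scaling_factor,
#              q_nd, compressed_kv_nd, k_pe_nd, o_nd)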
24 changes: 10 additions & 14 deletions src/runtime/relax_vm/kv_state.h
@@ -181,20 +181,6 @@ class AttentionKVCacheObj : public KVStateObj {
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
NDArray o_data, double attn_score_scaling_factor) = 0;

/*!
* \brief Compute attention with Q/K/V data.
* \param layer_id The model layer where the attention compute happens.
* \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`
* \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`
* \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`
* \param mask The input mask data, in layout `(total_sqr_length)`.
* \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
* \param attn_score_scaling_factor The additional attention scaling factor.
*/
virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray k_data,
NDArray v_data, Optional<NDArray> mask, NDArray o_data,
double attn_score_scaling_factor) = 0;

/*!
* \brief Compute multi-head latent attention after applying weight absorption.
* \param layer_id The model layer where the attention compute happens.
@@ -275,6 +261,16 @@
virtual void DebugGetKV(int64_t seq_id, //
int64_t start_pos, int64_t end_pos, NDArray k_data, NDArray v_data) = 0;

/*!
* \brief Fetch the compact K/V data of the given sequence for MLA cache.
* \param seq_id The sequence whose K/V data is to be fetched.
* \param start_pos The start position (inclusive) of the K/V data to fetch.
* \param end_pos The end position (exclusive) of the K/V data to fetch.
* \param kv_data The output KV data of the given sequence in layout elaborated above.
*/
virtual void DebugGetKVMLA(int64_t seq_id, int64_t start_pos, int64_t end_pos,
NDArray kv_data) = 0;

/*!
* \brief Set the K/V data of the given sequence from input K/V data.
* `start_pos` (inclusive) controls starting position of K/V data
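The new DebugGetKVMLA hook reads the compact latent KV data of a sequence back out of the cache, which is what a correctness test needs in order to compare against a reference. A hedged usage sketch via the registered global; the destination buffer shape, dtype, and surrounding objects are assumptions.

import tvm

debug_get_kv_mla = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv_mla")
# Assumed: kv_cache, seq_id, start_pos, end_pos, and a device `dev`; the
# destination buffer holds one compact KV vector (latent + RoPE dims) per token.
# kv_dst = tvm.nd.empty((end_pos - start_pos, kv_lora_rank + qk_rope_head_dim), "float16", dev)
# debug_get_kv_mla(kv_cache, seq_id, start_pos, end_pos, kv_dst)
# Compare kv_dst.numpy() against a NumPy reference of the cached latent KV.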