diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index 399e418c464b..844b237381a0 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -30,7 +30,16 @@ from tvm.target import Target
 
 from .position_embedding import llama_rope_with_position_map, switch_rope_freq_func
-from .tree_attn import tree_attn, tree_attn_with_paged_kv_cache
+from .tree_attn import (
+    tree_attn,
+    tree_attn_cpu,
+    tree_attn_with_paged_kv_cache,
+    tree_attn_with_paged_kv_cache_cpu,
+)
+
+
+def _var_cpu(dtype):
+    return T.alloc_buffer((1,), dtype)
 
 
 def get_max_num_threads_per_block(target: Target) -> int:
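Review note: `_var_cpu` gives the serial CPU kernels a scalar "register" in the form of a one-element buffer, since TVMScript blocks have no plain scalar locals. A minimal sketch of the pattern (illustrative only; `running_max` is a hypothetical kernel, not part of this patch):

```python
from tvm.script import tir as T

@T.prim_func
def running_max(x: T.Buffer((128,), "float32"), out: T.Buffer((1,), "float32")):
    with T.block("root"):
        acc = T.alloc_buffer((1,), "float32")  # scalar "register", as _var_cpu returns
        acc[0] = T.float32(-5e4)  # the same -inf sentinel the kernels below use
        for i in T.serial(128):
            acc[0] = T.max(acc[0], x[i])
        out[0] = acc[0]
```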
@@ -371,23 +380,230 @@ def __init__(  # pylint: disable=too-many-locals
             # pylint: disable=line-too-long
             # fmt: off
             bb.add_func(_kv_cache_transpose_append(num_key_value_heads, head_dim, dtype), "kv_cache_transpose_append"),
-            bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill"),
-            bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, False, rope_scaling, target), "tir_attention_decode"),
-            bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window"),
-            bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window"),
-            bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged"),
-            bb.add_func(_merge_state_inplace(num_attention_heads, head_dim, dtype, target), "tir_attention_merge_state"),
-            bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-            bb.add_func(_copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target), "kv_cache_copy_single_page"),
-            bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
-            bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
-            bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
-            bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
-            rope_ext_factors,
-            rx.PrimValue(enable_disaggregation),
             # fmt: on
             # pylint: enable=line-too-long
         ]
+
+        if str(target.kind) == "llvm":
+            args.extend(
+                [
+                    bb.add_func(
+                        _attention_prefill_cpu(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            False,
+                            rope_scaling,
+                        ),
+                        "tir_attention_prefill_cpu",
+                    ),
+                    bb.add_func(
+                        _attention_decode_cpu(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            False,
+                            rope_scaling,
+                        ),
+                        "tir_attention_decode_cpu",
+                    ),
+                    bb.add_func(
+                        _attention_prefill_cpu(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            True,
+                            rope_scaling,
+                        ),
+                        "tir_attention_prefill_cpu_sliding_window",
+                    ),
+                    bb.add_func(
+                        _attention_decode_cpu(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            True,
+                            rope_scaling,
+                        ),
+                        "tir_attention_decode_cpu_sliding_window",
+                    ),
+                    bb.add_func(
+                        _attention_prefill_ragged_cpu(
+                            num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling
+                        ),
+                        "tir_attention_prefill_ragged_cpu",
+                    ),
+                    bb.add_func(
+                        _merge_state_inplace_cpu(dtype),
+                        "tir_attention_merge_state_cpu",
+                    ),
+                    bb.add_func(
+                        llama_rope_with_position_map(
+                            rope_theta,
+                            rope_scale,
+                            head_dim,
+                            num_attention_heads,
+                            num_key_value_heads,
+                            dtype,
+                            rope_scaling,
+                            rotary_dim,
+                        ),
+                        "tir_split_rotary",
+                    ),
+                    bb.add_func(
+                        _copy_single_page_cpu(num_key_value_heads, page_size, head_dim, dtype),
+                        "kv_cache_copy_single_page_cpu",
+                    ),
+                    bb.add_func(
+                        _kv_cache_debug_get_kv(
+                            num_hidden_layers, num_key_value_heads, head_dim, dtype
+                        ),
+                        "kv_cache_debug_get_kv",
+                    ),
+                    bb.add_func(
+                        _compact_kv_copy_cpu(num_key_value_heads, head_dim, dtype),
+                        "kv_cache_compact_kv_copy_cpu",
+                    ),
+                    bb.add_func(
+                        tree_attn_cpu(
+                            num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling
+                        ),
+                        "tir_attention_prefill_with_tree_mask_cpu",
+                    ),
+                    bb.add_func(
+                        tree_attn_with_paged_kv_cache_cpu(
+                            num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling
+                        ),
+                        "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu",
+                    ),
+                    rope_ext_factors,
+                    rx.PrimValue(enable_disaggregation),
+                ]
+            )
+        else:
+            args.extend(
+                [
+                    bb.add_func(
+                        _attention_prefill(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            False,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill",
+                    ),
+                    bb.add_func(
+                        _attention_decode(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            False,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_decode",
+                    ),
+                    bb.add_func(
+                        _attention_prefill(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            True,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill_sliding_window",
+                    ),
+                    bb.add_func(
+                        _attention_decode(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            True,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_decode_sliding_window",
+                    ),
+                    bb.add_func(
+                        _attention_prefill_ragged(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill_ragged",
+                    ),
+                    bb.add_func(
+                        _merge_state_inplace(num_attention_heads, head_dim, dtype, target),
+                        "tir_attention_merge_state",
+                    ),
+                    bb.add_func(
+                        llama_rope_with_position_map(
+                            rope_theta,
+                            rope_scale,
+                            head_dim,
+                            num_attention_heads,
+                            num_key_value_heads,
+                            dtype,
+                            rope_scaling,
+                            rotary_dim,
+                        ),
+                        "tir_split_rotary",
+                    ),
+                    bb.add_func(
+                        _copy_single_page(num_key_value_heads, page_size, head_dim, dtype, target),
+                        "kv_cache_copy_single_page",
+                    ),
+                    bb.add_func(
+                        _kv_cache_debug_get_kv(
+                            num_hidden_layers, num_key_value_heads, head_dim, dtype
+                        ),
+                        "kv_cache_debug_get_kv",
+                    ),
+                    bb.add_func(
+                        _compact_kv_copy(num_key_value_heads, head_dim, dtype, target),
+                        "kv_cache_compact_kv_copy",
+                    ),
+                    bb.add_func(
+                        tree_attn(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill_with_tree_mask",
+                    ),
+                    bb.add_func(
+                        tree_attn_with_paged_kv_cache(
+                            num_key_value_heads,
+                            num_attention_heads,
+                            head_dim,
+                            dtype,
+                            rope_scaling,
+                            target,
+                        ),
+                        "tir_attention_prefill_with_tree_mask_with_paged_kv_cache",
+                    ),
+                    rope_ext_factors,
+                    rx.PrimValue(enable_disaggregation),
+                ]
+            )
+
         super().__init__(
             _expr=rx.call_pure_packed(
                 "vm.builtin.paged_attention_kv_cache_create_reduced",
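Review note: kernel selection is keyed on the target kind. `llvm` targets get the serial `*_cpu` TIR kernels; every other target keeps the GPU-scheduled kernels, which require thread/block binding. A sketch of the dispatch idiom (illustrative, not an actual helper in this patch):

```python
from tvm.target import Target

def pick_prefill(target: Target) -> str:
    # CPU targets cannot run kernels scheduled with GPU thread bindings,
    # so they take the serial TIR path.
    return "tir_attention_prefill_cpu" if str(target.kind) == "llvm" else "tir_attention_prefill"
```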
@@ -553,6 +769,161 @@ def _get_seq_offset(pos, seq_id, length_info, sliding_window):
     )
 
 
+def _attention_prefill_cpu(h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any]):
+    global_symbol = "batch_prefill_paged_kv_cpu"
+    if sliding_window:
+        global_symbol += "_sliding_window"
+
+    group_size = h_q // h_kv
+    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+    # pylint: disable=line-too-long,too-many-branches
+    # fmt: off
+    @T.prim_func
+    def batch_prefill_paged_kv_cpu(
+        _0: T.int32,  # pylint: disable=unused-argument
+        var_q: T.handle, # [total_len, h_q, d]
+        var_q_indptr: T.handle, # [batch_size + 1]
+        var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d]
+        var_page_indptr: T.handle, # [batch_size + 1]
+        var_page_values: T.handle, # [nnz_pages]
+        var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b]
+        var_k_rope_pos_offset: T.handle, # [b]
+        var_q_rope_position: T.handle, # [total_len]
+        var_output: T.handle, # [total_len, h_q, d]
+        var_lse: T.handle, # [total_len, h_q]
+        causal: T.int32,
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+    ):
+        T.func_attr({"global_symbol": global_symbol})
+        batch_size = T.int32(is_size_var=True)
+        total_len = T.int32(is_size_var=True)
+        nnz_pages = T.int32(is_size_var=True)
+        max_num_pages = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        page_indptr_elem_offset = T.int32(is_size_var=True)
+        page_values_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        length_info_elem_offset = T.int32(is_size_var=True)
+
+        q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
+        pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
+        page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset)
+        page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
+        k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
+        q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset)
+        output = T.match_buffer(var_output, (total_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (total_len, h_q), "float32")  # pylint: disable=unused-variable
+        # The length information of the sequences.
+        # - It is in shape `(3, batch_size)` when sliding window is enabled.
+        #   For a sequence "i", location
+        #   - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
+        #   - "(1, i)" is the starting offset of the sliding window in the seq,
+        #   - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        #   denoting the "last_page_len".
+        length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset)
+
+        for h_qo in T.serial(h_q):
+            for b_idx in T.serial(batch_size):
+                with T.block("attn"):
+                    O_local = T.alloc_buffer((d,), "float32")
+                    Q_local = T.alloc_buffer((d,), "float32")
+                    K_local = T.alloc_buffer((d,), "float32")
+                    V_local = T.alloc_buffer((d,), "float32")
+
+                    kv_chunk_len = T.alloc_buffer((1,), "int32")
+
+                    m_val = T.alloc_buffer((1,), "float32")
+                    new_m = T.alloc_buffer((1,), "float32")
+                    d_val = T.alloc_buffer((1,), "float32")
+                    S_val = T.alloc_buffer((1,), "float32")
+                    scale_O = T.alloc_buffer((1,), "float32")
+                    factor = T.alloc_buffer((1,), "float32")
+                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
+                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
+                    kv_chunk_len[0] = T.if_then_else(
+                        cur_page_indptr_begin != cur_page_indptr_end,
+                        _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window),
+                        0
+                    )
+
+                    for q_idx in T.serial(q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                        # init m, d, O
+                        m_val[0] = -5e4
+                        d_val[0] = 1.0
+                        for d_idx in T.serial(d):
+                            O_local[d_idx] = 0.0
+                        curl_q: T.int32 = q_indptr[b_idx] + q_idx
+
+                        for d_idx in T.serial(d):
+                            Q_local[d_idx] = T.if_then_else(
+                                rotary_mode == 1,
+                                _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling),
+                                q[curl_q, h_qo, d_idx]
+                            )
+                        for row_idx in T.serial(max_num_pages * 16):
+                            if row_idx < kv_chunk_len[0]:
+                                page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)]
+                                page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16
+
+                                # Load KV
+                                for d_idx in T.serial(d):
+                                    K_local[d_idx] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(pages, k_rope_pos_offset[b_idx] + row_idx, d, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d_idx), dtype, rope_scaling),
+                                        pages[page_no, 0, h_qo // group_size, page_offset, d_idx]
+                                    )
+                                    V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx]
+
+                                # Compute S = Q[i] * K[i] * attn_score_scaling_factor * sm_scale
+                                S_val[0] = 0.0
+                                for d_idx in T.serial(d):
+                                    S_val[0] += Q_local[d_idx] * K_local[d_idx]
+                                S_val[0] *= attn_score_scaling_factor * sm_scale
+
+                                # update m_val, d_val, O_local
+                                if _causal_mask(causal,
+                                                row=q_idx,
+                                                col=row_idx,
+                                                kv_len=kv_chunk_len[0],
+                                                qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                                    new_m[0] = T.max(m_val[0], S_val[0])
+                                else:
+                                    S_val[0] = -5e4
+                                # update d_val
+                                d_val[0] *= T.exp2(m_val[0] - new_m[0])
+                                d_val[0] += T.exp2(S_val[0] - new_m[0])
+
+                                # rescale O_local, then accumulate the new value
+                                scale_O[0] = T.exp2(m_val[0] - new_m[0])
+                                m_val[0] = new_m[0]
+                                factor[0] = T.exp2(S_val[0] - m_val[0])
+                                for d_idx in T.serial(d):
+                                    O_local[d_idx] = O_local[d_idx] * scale_O[0]
+                                for d_idx in T.serial(d):
+                                    O_local[d_idx] += V_local[d_idx] * factor[0]
+                        # Store Output
+                        for d_idx in T.serial(d):
+                            O_local[d_idx] = O_local[d_idx] / d_val[0]
+                            output[curl_q, h_qo, d_idx] = O_local[d_idx]
+                        lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0])
+    return batch_prefill_paged_kv_cpu
+
+
 def _attention_prefill(
     h_kv, h_q, d, dtype, sliding_window: bool, rope_scaling: Dict[str, Any], target: Target
 ):
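Review note: `_attention_prefill_cpu` is a serial flash-attention-style kernel. It keeps a running max `m`, denominator `d`, and unnormalized output `O`, rescaling them whenever a new score raises the max. A NumPy reference of the per-score recurrence (illustrative; base-2 exponentials match the kernel, which folds log2(e) into `sm_scale`):

```python
def online_softmax_step(m, d, o, s, v):
    """Fold one score s (with value row v) into the running state (m, d, o)."""
    m_new = max(m, s)
    scale = 2.0 ** (m - m_new)   # rescales the old accumulator
    p = 2.0 ** (s - m_new)       # weight of the incoming value row
    return m_new, d * scale + p, o * scale + p * v

# Epilogue, as in the kernel: output = o / d and lse = m + log2(d).
```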
@@ -920,6 +1291,189 @@ def apply_to_md(sch, block):
         return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
+def _attention_decode_cpu(
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    qkv_dtype,
+    sliding_window: bool,
+    rope_scaling: Dict[str, Any],
+):
+    log2e = math.log2(math.exp(1))
+    H_qo = num_qo_heads
+    H_kv = num_kv_heads
+    D = head_dim
+    group_size = num_qo_heads // num_kv_heads
+
+    global_symbol = "batch_decode_paged_kv_cpu"
+    if sliding_window:
+        global_symbol += "_sliding_window"
+
+    @T.prim_func(check_well_formed=False)
+    def batch_decode_paged_kv(
+        _0: T.int32,  # pylint: disable=unused-argument
+        Q_handle: T.handle,
+        pages_handle: T.handle,
+        page_table_indptr_handle: T.handle,
+        page_table_values_handle: T.handle,
+        var_length_info: T.handle,  # [b] when sliding window = False, or otherwise [3, b]
+        k_rope_pos_offset_handle: T.handle,
+        q_rope_position_handle: T.handle,
+        output_handle: T.handle,
+        lse_handle: T.handle,
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+    ):
+        T.func_attr({"tir.is_scheduled": 1, "global_symbol": global_symbol})
+        B = T.int32(is_size_var=True)
+        nnz_pages = T.int32(is_size_var=True)
+        max_num_pages = T.int32(is_size_var=True)
+        page_indptr_elem_offset = T.int32(is_size_var=True)
+        page_values_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        length_info_elem_offset = T.int32(is_size_var=True)
+
+        Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)  # query values
+        pages = T.match_buffer(pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype)
+        page_table_indptr = T.match_buffer(
+            page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset
+        )
+        page_table_values = T.match_buffer(
+            page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset
+        )
+        k_rope_pos_offset = T.match_buffer(
+            k_rope_pos_offset_handle, (B,), "int32", elem_offset=k_rope_pos_offset_elem_offset
+        )
+        q_rope_position = T.match_buffer(
+            q_rope_position_handle, (B,), "int32", elem_offset=q_rope_position_elem_offset
+        )
+        output = T.match_buffer(output_handle, (B, H_qo, D), qkv_dtype)
+        lse = T.match_buffer(lse_handle, (B, H_qo), "float32")  # pylint: disable=unused-variable
+        # The length information of the sequences.
+        # - It is in shape `(3, batch_size)` when sliding window is enabled.
+        #   For a sequence "i", location
+        #   - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"),
+        #   - "(1, i)" is the starting offset of the sliding window in the seq,
+        #   - "(2, i)" is the attn sink length of the sequence.
+        # - It is in shape `(batch_size,)` when sliding window is disabled,
+        #   denoting the "last_page_len".
+        length_info = _declare_length_info(
+            var_length_info, B, sliding_window, length_info_elem_offset
+        )
+
+        sm_scale = 1.0 / math.sqrt(float(D)) * log2e
+
+        for b in T.serial(B):
+            with T.block("attn"):
+                O_local = T.alloc_buffer((D,), "float32")
+                Q_local = T.alloc_buffer((D,), "float32")
+                K_local = T.alloc_buffer((D,), "float32")
+                V_local = T.alloc_buffer((D,), "float32")
+
+                kv_chunk_len = T.alloc_buffer((1,), "int32")
+
+                m_val = T.alloc_buffer((1,), "float32")
+                new_m = T.alloc_buffer((1,), "float32")
+                d_val = T.alloc_buffer((1,), "float32")
+                S_val = T.alloc_buffer((1,), "float32")
+                scale_O = T.alloc_buffer((1,), "float32")
+                factor = T.alloc_buffer((1,), "float32")
+
+                cur_page_indptr_begin: T.int32 = page_table_indptr[b]
+                cur_page_indptr_end: T.int32 = page_table_indptr[b + 1]
+
+                kv_chunk_len[0] = T.if_then_else(
+                    cur_page_indptr_begin != cur_page_indptr_end,
+                    _get_kv_chunk_len(
+                        cur_page_indptr_end - cur_page_indptr_begin,
+                        16,
+                        b,
+                        length_info,
+                        sliding_window,
+                    ),
+                    0,
+                )
+
+                for h_qo in T.serial(H_qo):
+                    m_val[0] = -5e4
+                    d_val[0] = 1.0
+
+                    for d in T.serial(D):
+                        O_local[d] = 0.0
+
+                    for d in T.serial(D):
+                        Q_local[d] = T.if_then_else(
+                            rotary_mode == 1,
+                            _rope(
+                                Q,
+                                q_rope_position[b],
+                                head_dim,
+                                rope_theta,
+                                rope_scale,
+                                (b, h_qo, d),
+                                qkv_dtype,
+                                rope_scaling,
+                            ),
+                            Q[b, h_qo, d],
+                        )
+
+                    for row_idx in T.serial(kv_chunk_len[0]):
+                        seq_offset: T.int32(is_size_var=True) = _get_seq_offset(
+                            row_idx, b, length_info, sliding_window
+                        )
+                        page_no: T.int32(is_size_var=True) = page_table_values[
+                            cur_page_indptr_begin + (seq_offset // 16)
+                        ]
+                        page_offset: T.int32(is_size_var=True) = seq_offset % 16
+
+                        for d in T.serial(D):
+                            K_local[d] = T.if_then_else(
+                                rotary_mode == 1,
+                                _rope(
+                                    pages,
+                                    k_rope_pos_offset[b] + row_idx,
+                                    head_dim,
+                                    rope_theta,
+                                    rope_scale,
+                                    (page_no, 0, h_qo // group_size, page_offset, d),
+                                    qkv_dtype,
+                                    rope_scaling,
+                                ),
+                                pages[page_no, 0, h_qo // group_size, page_offset, d],
+                            )
+                        S_val[0] = 0.0
+                        for d in T.serial(D):
+                            S_val[0] += Q_local[d] * K_local[d]
+                        S_val[0] *= attn_score_scaling_factor * sm_scale
+
+                        new_m[0] = T.max(m_val[0], S_val[0])
+                        d_val[0] = (d_val[0] * T.exp2(m_val[0] - new_m[0])) + T.exp2(
+                            S_val[0] - new_m[0]
+                        )
+
+                        scale_O[0] = T.exp2(m_val[0] - new_m[0])
+
+                        for d in T.serial(D):
+                            O_local[d] = O_local[d] * scale_O[0]
+
+                        m_val[0] = new_m[0]
+                        for d in T.serial(D):
+                            V_local[d] = pages[page_no, 1, h_qo // group_size, page_offset, d]
+
+                        factor[0] = T.exp2(S_val[0] - m_val[0])
+                        for d in T.serial(D):
+                            O_local[d] = O_local[d] + V_local[d] * factor[0]
+                    for d in T.serial(D):
+                        O_local[d] = O_local[d] / d_val[0]
+                        output[b, h_qo, d] = O_local[d]
+                    lse[b, h_qo] = m_val[0] + T.log2(d_val[0])
+
+    return batch_decode_paged_kv
+
+
 def _attention_decode(
     num_kv_heads,
     num_qo_heads,
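Review note: the paged kernels locate a logical KV position through the page table: the sequence offset is split into a page index (resolved through `page_table_values`) and an offset within the 16-slot page. A sketch of the address arithmetic (illustrative):

```python
PAGE_SIZE = 16  # fixed page size assumed throughout these kernels

def locate_kv(page_values, page_indptr_begin, seq_offset):
    page_no = page_values[page_indptr_begin + seq_offset // PAGE_SIZE]
    page_offset = seq_offset % PAGE_SIZE
    # K lives at pages[page_no, 0, head, page_offset, :], V at pages[page_no, 1, ...]
    return page_no, page_offset
```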
@@ -1179,6 +1733,47 @@ def batch_decode_paged_kv(
     return batch_decode_paged_kv
 
 
+def _merge_state_inplace_cpu(v_dtype):
+    @T.prim_func
+    def merge_state_inplace_cpu(
+        v: T.handle,
+        s: T.handle,
+        v_other: T.handle,
+        s_other: T.handle,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        N = T.int32(is_size_var=True)
+        H = T.int32(is_size_var=True)
+        D = T.int32(is_size_var=True)
+
+        V = T.match_buffer(v, (N, H, D), v_dtype)
+        S = T.match_buffer(s, (N, H), "float32")
+        V_other = T.match_buffer(v_other, (N, H, D), v_dtype)
+        S_other = T.match_buffer(s_other, (N, H), "float32")
+
+        for n in T.serial(N):
+            for h in T.serial(H):
+                with T.block("merge"):
+                    s_val = _var_cpu("float32")
+                    s_other_val = _var_cpu("float32")
+                    s_max = _var_cpu("float32")
+                    scale = _var_cpu("float32")
+                    other_scale = _var_cpu("float32")
+
+                    s_val[0] = S[n, h]
+                    s_other_val[0] = S_other[n, h]
+                    s_max[0] = T.max(s_val[0], s_other_val[0])
+                    s_val[0] = T.exp2(s_val[0] - s_max[0])
+                    s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
+                    scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
+                    other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
+                    for d in T.serial(D):
+                        V[n, h, d] = V[n, h, d] * scale[0] + V_other[n, h, d] * other_scale[0]
+                    S[n, h] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]
+
+    return merge_state_inplace_cpu
+
+
 def _merge_state_inplace(num_heads, head_dim, v_dtype, target: Target):
     v_dtype_bytes = 2
     VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4)
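Review note: `_merge_state_inplace_cpu` combines two partial attention results whose `S`/`S_other` entries are base-2 log-sum-exp values, weighting each `V` by its share of the combined denominator. A NumPy reference (illustrative):

```python
import numpy as np

def merge_states(v, s, v_other, s_other):
    s_max = np.maximum(s, s_other)
    a = 2.0 ** (s - s_max)
    b = 2.0 ** (s_other - s_max)
    v_merged = v * (a / (a + b))[..., None] + v_other * (b / (a + b))[..., None]
    s_merged = np.log2(a + b) + s_max
    return v_merged, s_merged
```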
@@ -1577,6 +2172,175 @@ def apply_schedule(sch):
         return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
+def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]):
+    group_size = h_q // h_kv
+    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+
+    @T.prim_func
+    def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
+        var_q: T.handle,  # [total_len, h_q, d]
+        var_q_indptr: T.handle,  # [batch_size + 1]
+        var_k: T.handle,  # [total_len, h_kv, d]
+        var_v: T.handle,  # [total_len, h_kv, d]
+        var_kv_indptr: T.handle,  # [batch_size + 1]
+        var_q_rope_position: T.handle,  # [total_q_len]
+        var_k_rope_pos_offset: T.handle,  # [b]
+        var_output: T.handle,  # [total_len, h_q, d]
+        var_lse: T.handle,  # [total_len, h_q]
+        causal: T.int32,
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+    ):
+        batch_size = T.int32(is_size_var=True)
+        qo_len = T.int32(is_size_var=True)
+        kv_len = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        kv_indptr_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
+
+        q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(
+            var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset
+        )
+        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        kv_indptr = T.match_buffer(
+            var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset
+        )
+        q_rope_position = T.match_buffer(
+            var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset
+        )
+        k_rope_pos_offset = T.match_buffer(
+            var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset
+        )
+        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
+
+        for b in T.serial(batch_size):
+            with T.block("attn"):
+                softmax_sum = T.alloc_buffer([h_q], "float32")
+                m_prev = T.alloc_buffer([h_q], "float32")
+                m_new = T.alloc_buffer([h_q], "float32")
+                d_prev = T.alloc_buffer([h_q], "float32")
+                d_new = T.alloc_buffer([h_q], "float32")
+                p_sum = T.alloc_buffer([d], "float32")
+                max_score = T.alloc_buffer([h_q], "float32")
+                attention_scores = T.alloc_buffer([kv_len, h_q], "float32")
+                exp_scores = T.alloc_buffer([kv_len, h_q], "float32")
+                attention_score = T.alloc_buffer([1], "float32")
+                query_val = T.alloc_buffer([1], "float32")
+                key_val = T.alloc_buffer([1], "float32")
+                result = T.alloc_buffer([1], "float32")
+
+                for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]):
+                    for i in T.serial(h_q):
+                        max_score[i] = -5e4
+                        m_prev[i] = -5e4
+                        d_prev[i] = 1.0
+
+                    for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                        for h in T.serial(h_q):
+                            h_kv_idx = h // group_size
+
+                            if _causal_mask(
+                                causal,
+                                row=q_idx,
+                                col=k_idx,
+                                kv_len=kv_indptr[b + 1] - kv_indptr[b],
+                                qo_len=q_indptr[b + 1] - q_indptr[b],
+                            ):
+                                result[0] = 0.0
+                                for d_idx in T.serial(d):
+                                    query_val[0] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(
+                                            q,
+                                            q_rope_position[q_indptr[b] + q_idx],
+                                            d,
+                                            rope_theta,
+                                            rope_scale,
+                                            (q_indptr[b] + q_idx, h, d_idx),
+                                            dtype,
+                                            rope_scaling,
+                                        ),
+                                        q[q_indptr[b] + q_idx, h, d_idx],
+                                    )
+
+                                    key_val[0] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(
+                                            k,
+                                            k_rope_pos_offset[b] + k_idx,
+                                            d,
+                                            rope_theta,
+                                            rope_scale,
+                                            (kv_indptr[b] + k_idx, h_kv_idx, d_idx),
+                                            dtype,
+                                            rope_scaling,
+                                        ),
+                                        k[kv_indptr[b] + k_idx, h_kv_idx, d_idx],
+                                    )
+
+                                    result[0] += query_val[0] * key_val[0]
+                                attention_score[0] = (
+                                    result[0] * sm_scale * attn_score_scaling_factor
+                                )
+                            else:
+                                attention_score[0] = -5e4 * sm_scale * attn_score_scaling_factor
+                            attention_scores[k_idx, h] = attention_score[0]
+                            max_score[h] = T.max(max_score[h], attention_score[0])
+                            m_new[h] = T.max(m_prev[h], max_score[h])
+
+                    for h in T.serial(h_q):
+                        d_new[h] = d_prev[h] * T.exp2(m_prev[h] - m_new[h])
+
+                    for h in T.serial(h_q):
+                        softmax_sum[h] = 0.0
+                        for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                            exp_scores[k_idx, h] = T.exp2(attention_scores[k_idx, h] - m_new[h])
+                            softmax_sum[h] += exp_scores[k_idx, h]
+                        d_new[h] += softmax_sum[h]
+                    # rebind: d_prev/m_prev alias d_new/m_new, so later reads see final values
+                    d_prev = d_new
+                    m_prev = m_new
+
+                    for h in T.serial(h_q):
+                        h_kv_idx = h // group_size
+                        for i in T.serial(d):
+                            p_sum[i] = 0.0
+                        for v_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]):
+                            weight = exp_scores[v_idx, h] / d_new[h]
+                            for i in T.serial(d):
+                                p_sum[i] += v[kv_indptr[b] + v_idx, h_kv_idx, i] * weight
+                        for i in T.serial(d):
+                            output[q_indptr[b] + q_idx, h, i] = p_sum[i]
+                        lse[q_indptr[b] + q_idx, h] = m_prev[h] + T.log2(d_prev[h])
+
+    return batch_prefill_ragged_kv
+
+
 def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target):
     # pylint: disable=line-too-long
     NUM_BLKS = 16
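Review note: the ragged kernel reuses `_causal_mask` from this file. Queries are right-aligned against the keys, so query `row` may attend to key `col` only while `col` stays within the first `kv_len - qo_len + row + 1` positions (assumed convention here; `_causal_mask` in this file is authoritative). As a sketch:

```python
def causal_allows(causal: bool, row: int, col: int, kv_len: int, qo_len: int) -> bool:
    if not causal:
        return col < kv_len
    return col < kv_len - qo_len + row + 1
```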
_var_cpu("float32") + scale = _var_cpu("float32") + other_scale = _var_cpu("float32") + + s_val[0] = S[n, h] + s_other_val[0] = S_other[n, h] + s_max[0] = T.max(s_val[0], s_other_val[0]) + s_val[0] = T.exp2(s_val[0] - s_max[0]) + s_other_val[0] = T.exp2(s_other_val[0] - s_max[0]) + scale[0] = s_val[0] / (s_val[0] + s_other_val[0]) + other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0]) + for d in T.serial(D): + V[n, h, d] = V[n, h, d] * scale[0] + V_other[n, h, d] * other_scale[0] + S[n, h] = T.log2(s_val[0] + s_other_val[0]) + s_max[0] + + return merge_state_inplace_cpu + + def _merge_state_inplace(num_heads, head_dim, v_dtype, target: Target): v_dtype_bytes = 2 VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4) @@ -1577,6 +2172,175 @@ def apply_schedule(sch): return sch.mod["main"].with_attr("tir.is_scheduled", 1) +def _attention_prefill_ragged_cpu(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any]): + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + @T.prim_func + def batch_prefill_ragged_kv( # pylint: disable=too-many-branches + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1] + var_q_rope_position: T.handle, # [total_q_len] + var_k_rope_pos_offset: T.handle, # [b] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + causal: T.int32, + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + ): + batch_size = T.int32(is_size_var=True) + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer( + var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset + ) + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer( + var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset + ) + q_rope_position = T.match_buffer( + var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset + ) + k_rope_pos_offset = T.match_buffer( + var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset + ) + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + for b in T.serial(batch_size): + with T.block("attn"): + softmax_sum = T.alloc_buffer([h_q], "float32") + m_prev = T.alloc_buffer([h_q], "float32") + m_new = T.alloc_buffer([h_q], "float32") + d_prev = T.alloc_buffer([h_q], "float32") + d_new = T.alloc_buffer([h_q], "float32") + p_sum = T.alloc_buffer([d], "float32") + max_score = T.alloc_buffer([h_q], "float32") + attention_scores = T.alloc_buffer([kv_len, h_q], "float32") + exp_scores = T.alloc_buffer([kv_len, h_q], "float32") + attention_score = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + query_val = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + key_val = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + result = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + + for q_idx in T.serial(q_indptr[b + 1] - 
@@ -1996,6 +2799,55 @@ copy_single_page(
     return copy_single_page
 
 
+def _compact_kv_copy_cpu(num_heads, head_dim, dtype):
+    tx = 8
+
+    @T.prim_func
+    def compact_kv_copy_cpu(
+        var_pages: T.handle,
+        var_copy_length_indptr: T.handle,
+        var_copy_src_dst_pos: T.handle,
+        batch_size: T.int32,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        num_pages = T.int32()
+        total_copy_length = T.int32()
+        copy_length_indptr_elem_offset = T.int32()
+        copy_src_dst_pos_elem_offset = T.int32()
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
+        copy_length_indptr = T.match_buffer(
+            var_copy_length_indptr,
+            (batch_size + 1,),
+            "int32",
+            elem_offset=copy_length_indptr_elem_offset,
+        )
+        copy_src_dst_pos = T.match_buffer(
+            var_copy_src_dst_pos,
+            (2, total_copy_length),
+            "int32",
+            elem_offset=copy_src_dst_pos_elem_offset,
+        )
+
+        with T.block("root"):
+            for bhd_o in T.serial((batch_size * num_heads * head_dim + tx - 1) // tx):
+                for bhd_i in T.serial(tx):
+                    b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim)
+                    h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads
+                    d: T.int32 = (bhd_o * tx + bhd_i) % head_dim
+                    if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim:
+                        for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]):
+                            src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
+                            dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
+                            pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[
+                                src_pos // 16, 0, h, src_pos % 16, d
+                            ]
+                            pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[
+                                src_pos // 16, 1, h, src_pos % 16, d
+                            ]
+
+    return compact_kv_copy_cpu
+
+
 def _compact_kv_copy(num_heads, head_dim, dtype, target: Target):
     tx = get_max_num_threads_per_block(target)
diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py
index 9e4a7ed97e71..fed63b5424ff 100644
--- a/python/tvm/relax/frontend/nn/llm/tree_attn.py
+++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py
@@ -82,6 +82,213 @@ def _check_tree_order(tree_order_indptr, tree_order, batch, row, col, kv_len, qo
     )
 
 
+def _declare_length_info(var_length_info, batch_size, sliding_window, elem_offset):
+    return (
+        T.match_buffer(var_length_info, (3, batch_size), "int32", elem_offset=elem_offset)
+        if sliding_window
+        else T.match_buffer(var_length_info, (batch_size,), "int32", elem_offset=elem_offset)
+    )
+
+
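Review note: `_declare_length_info` (duplicated here from kv_cache.py so either module can declare the buffer) encodes two layouts. A sketch of what the runtime passes in each mode (illustrative values):

```python
import numpy as np

# sliding window disabled: one "last_page_len" per sequence, shape (batch,)
length_info = np.array([7, 16, 3], dtype="int32")

# sliding window enabled: shape (3, batch); per sequence i the rows hold
# last_page_len, sliding-window start offset, and attention-sink length
length_info_sw = np.array([[7, 16, 3], [0, 64, 0], [4, 4, 4]], dtype="int32")
```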
+ """ + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + + # fmt: off + @T.prim_func + def batch_tree_attn( # pylint: disable=too-many-branches,line-too-long + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_k: T.handle, # [total_len, h_kv, d] + var_v: T.handle, # [total_len, h_kv, d] + var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case + var_q_rope_position: T.handle, # [total_q_len] + var_mn_indptr: T.handle, # [batch_size + 1] + var_mask: T.handle, # [mn_indptr[batch_size]] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + batch_size: T.int32, + ): + qo_len = T.int32(is_size_var=True) + kv_len = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + kv_indptr_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + mn_indptr_elem_offset = T.int32(is_size_var=True) + mask_elem_offset = T.int32(is_size_var=True) + tree_size = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (qo_len, h_q, d), dtype) + q_indptr = T.match_buffer( + var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset + ) + k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype) + v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype) + kv_indptr = T.match_buffer( + var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset + ) + q_rope_position = T.match_buffer( + var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset + ) + mn_indptr = T.match_buffer( + var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset + ) + mask = T.match_buffer(var_mask, (tree_size, 2), "int32", elem_offset=mask_elem_offset) + output = T.match_buffer(var_output, (qo_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (qo_len, h_q), "float32") # pylint: disable=unused-variable + + for b in T.serial(batch_size): + with T.block("attn"): + + softmax_sum = T.alloc_buffer([h_q], "float32") + m_prev = T.alloc_buffer([h_q], "float32") + m_new = T.alloc_buffer([h_q], "float32") + d_prev = T.alloc_buffer([h_q], "float32") + d_new = T.alloc_buffer([h_q], "float32") + sum = T.alloc_buffer([d], "float32") + + max_score = T.alloc_buffer([h_q], "float32") + attention_scores = T.alloc_buffer([kv_len, h_q], "float32") + exp_scores = T.alloc_buffer([kv_len, h_q], "float32") + attention_score = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + query_val = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + key_val = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + result = T.alloc_buffer( + [ + 1, + ], + "float32", + ) + + for q_idx in T.serial(q_indptr[b + 1] - q_indptr[b]): + for i in T.serial(h_q): + max_score[i] = -5e4 + m_prev[i] = -5e4 + d_prev[i] = 1.0 + + for k_idx in T.serial(kv_indptr[b + 1] - kv_indptr[b]): + for h in T.serial(h_q): + h_kv_idx = h // group_size + + if _check_tree_order( + row=q_idx, + col=k_idx, + batch=b, + tree_order=mask, + tree_order_indptr=mn_indptr, + kv_len=kv_indptr[b + 1] - kv_indptr[b], + qo_len=q_indptr[b + 1] - q_indptr[b], + ): + result[0] = 0.0 + for d_idx in T.serial(d): + query_val[0] = T.if_then_else( + rotary_mode == 1, + _rope( + q, + q_rope_position[q_indptr[b] + q_idx], + d, + rope_theta, + rope_scale, + (q_indptr[b] + q_idx, h, d_idx), + dtype, + rope_scaling, + ), + q[q_indptr[b] + q_idx, h, 
+
+
 def tree_attn(
     h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target
 ):  # pylint: disable=unused-argument
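Review note: the paged-KV tree kernel below pulls `_declare_length_info`, `_get_kv_chunk_len`, and `_get_seq_offset` from kv_cache.py inside the function body: kv_cache.py already imports this module at load time, so importing back at the top level would create a cycle. The pattern, in isolation (illustrative sketch):

```python
def tree_kernel_factory():
    # Deferred import breaks the kv_cache.py <-> tree_attn.py cycle; by the
    # time this function runs, both modules are fully initialized.
    from .kv_cache import _get_seq_offset  # pylint: disable=import-outside-toplevel
    return _get_seq_offset
```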
+ """ + # pylint: disable=import-outside-toplevel + from .kv_cache import ( + _declare_length_info, + _get_kv_chunk_len, + _get_seq_offset, + ) + + global_symbol = "tree_attn_paged_kv_cpu" + sliding_window = False + group_size = h_q // h_kv + sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1)) + # pylint: disable=line-too-long,too-many-branches + # fmt: off + @T.prim_func(check_well_formed=False) + def tree_attn_paged_kv_cpu( + _0: T.int32, # pylint: disable=unused-argument + var_q: T.handle, # [total_len, h_q, d] + var_q_indptr: T.handle, # [batch_size + 1] + var_pages: T.handle, # [max_num_pages, 2, h_kv, page_size, d] + var_page_indptr: T.handle, # [batch_size + 1] + var_page_values: T.handle, # [nnz_pages] + var_length_info: T.handle, # [b] when sliding window = False, or otherwise [3, b] + var_k_rope_pos_offset: T.handle, # [b] + var_q_rope_position: T.handle, # [total_len] + var_output: T.handle, # [total_len, h_q, d] + var_lse: T.handle, # [total_len, h_q] + rotary_mode: T.int32, + rope_scale: T.float32, + rope_theta: T.float32, + attn_score_scaling_factor: T.float32, + tree_order_indptr_handle: T.handle, # [batch_size + 1] + tree_order_handle: T.handle, # [total_len, 2] + ): + T.func_attr({"global_symbol": global_symbol}) + batch_size = T.int32(is_size_var=True) + total_len = T.int32(is_size_var=True) + nnz_pages = T.int32(is_size_var=True) + max_num_pages = T.int32(is_size_var=True) + q_indptr_elem_offset = T.int32(is_size_var=True) + page_indptr_elem_offset = T.int32(is_size_var=True) + page_values_elem_offset = T.int32(is_size_var=True) + k_rope_pos_offset_elem_offset = T.int32(is_size_var=True) + q_rope_position_elem_offset = T.int32(is_size_var=True) + length_info_elem_offset = T.int32(is_size_var=True) + tree_order_elem_offset = T.int32(is_size_var=True) + tree_order_indptr_elem_offset = T.int32(is_size_var=True) + + q = T.match_buffer(var_q, (total_len, h_q, d), dtype) + q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset) + pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype) + page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset) + page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset) + k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset) + q_rope_position = T.match_buffer(var_q_rope_position, (total_len,), "int32", elem_offset=q_rope_position_elem_offset) + output = T.match_buffer(var_output, (total_len, h_q, d), dtype) + lse = T.match_buffer(var_lse, (total_len, h_q), "float32") # pylint: disable=unused-variable + tree_order_indptr = T.match_buffer( + tree_order_indptr_handle, + (batch_size + 1,), + "int32", + elem_offset=tree_order_indptr_elem_offset, + ) + total_tree_order_len = T.int32(is_size_var=True) + tree_order = T.match_buffer( + tree_order_handle, + (total_tree_order_len, 2), + "int32", + elem_offset=tree_order_elem_offset, + ) + # The length information of the sequences. + # - It is in shape `(3, batch_size)` when sliding window is enabled. + # For a sequence "i", location + # - "(0, i)" is the number of KV slots used in the last page of the seq ("last_page_len"), + # - "(1, i)" is the starting offset of the sliding window in the seq, + # - "(2, i)" is the attn sink length of the sequence. + # - It is in shape `(batch_size,)` when sliding window is disabled, + # denoting the "last_page_len". 
+        length_info = _declare_length_info(var_length_info, batch_size, sliding_window, length_info_elem_offset)
+
+        T.Assert(
+            rotary_mode == T.int32(0), "Inline rotary mode is not supported in tree attention."
+        )
+
+        for h_qo in T.serial(h_q):
+            for b_idx in T.serial(batch_size):
+                with T.block("attn"):
+                    O_local = T.alloc_buffer((d,), "float32")
+                    Q_local = T.alloc_buffer((d,), "float32")
+                    K_local = T.alloc_buffer((d,), "float32")
+                    V_local = T.alloc_buffer((d,), "float32")
+
+                    kv_chunk_len = T.alloc_buffer((1,), "int32")
+
+                    m_val = T.alloc_buffer((1,), "float32")
+                    new_m = T.alloc_buffer((1,), "float32")
+                    d_val = T.alloc_buffer((1,), "float32")
+                    S_val = T.alloc_buffer((1,), "float32")
+                    scale_O = T.alloc_buffer((1,), "float32")
+                    factor = T.alloc_buffer((1,), "float32")
+                    cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
+                    cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
+                    kv_chunk_len[0] = T.if_then_else(
+                        cur_page_indptr_begin != cur_page_indptr_end,
+                        _get_kv_chunk_len(cur_page_indptr_end - cur_page_indptr_begin, 16, b_idx, length_info, sliding_window),
+                        0
+                    )
+
+                    for q_idx in T.serial(q_indptr[b_idx + 1] - q_indptr[b_idx]):
+                        # init m, d, O
+                        m_val[0] = -5e4
+                        d_val[0] = 1.0
+                        for d_idx in T.serial(d):
+                            O_local[d_idx] = 0.0
+                        curl_q: T.int32 = q_indptr[b_idx] + q_idx
+
+                        for d_idx in T.serial(d):
+                            Q_local[d_idx] = T.if_then_else(
+                                rotary_mode == 1,
+                                _rope(q, q_rope_position[curl_q], d, rope_theta, rope_scale, (curl_q, h_qo, d_idx), dtype, rope_scaling),
+                                q[curl_q, h_qo, d_idx]
+                            )
+                        for row_idx in T.serial(max_num_pages * 16):
+                            if row_idx < kv_chunk_len[0]:
+                                page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + (_get_seq_offset(row_idx, b_idx, length_info, sliding_window) // 16)]
+                                page_offset: T.int32(is_size_var=True) = _get_seq_offset(row_idx, b_idx, length_info, sliding_window) % 16
+
+                                # Load KV
+                                for d_idx in T.serial(d):
+                                    K_local[d_idx] = T.if_then_else(
+                                        rotary_mode == 1,
+                                        _rope(pages, k_rope_pos_offset[b_idx] + row_idx, d, rope_theta, rope_scale, (page_no, 0, h_qo // group_size, page_offset, d_idx), dtype, rope_scaling),
+                                        pages[page_no, 0, h_qo // group_size, page_offset, d_idx]
+                                    )
+                                    V_local[d_idx] = pages[page_no, 1, h_qo // group_size, page_offset, d_idx]
+
+                                # Compute S
+                                S_val[0] = 0.0
+                                for d_idx in T.serial(d):
+                                    S_val[0] += Q_local[d_idx] * K_local[d_idx]
+                                S_val[0] *= attn_score_scaling_factor * sm_scale
+
+                                # update m_val, d_val, O_local
+                                if _check_tree_order(
+                                    tree_order_indptr=tree_order_indptr,
+                                    tree_order=tree_order,
+                                    batch=b_idx,
+                                    row=q_idx,
+                                    col=row_idx,
+                                    kv_len=kv_chunk_len[0],
+                                    qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                                ):
+                                    new_m[0] = T.max(m_val[0], S_val[0])
+                                else:
+                                    S_val[0] = -5e4
+                                # update d_val
+                                d_val[0] *= T.exp2(m_val[0] - new_m[0])
+                                d_val[0] += T.exp2(S_val[0] - new_m[0])
+
+                                # rescale O_local, then accumulate the new value
+                                scale_O[0] = T.exp2(m_val[0] - new_m[0])
+                                m_val[0] = new_m[0]
+                                factor[0] = T.exp2(S_val[0] - m_val[0])
+                                for d_idx in T.serial(d):
+                                    O_local[d_idx] = O_local[d_idx] * scale_O[0]
+                                for d_idx in T.serial(d):
+                                    O_local[d_idx] += V_local[d_idx] * factor[0]
+                        # Store Output
+                        for d_idx in T.serial(d):
+                            O_local[d_idx] = O_local[d_idx] / d_val[0]
+                            output[curl_q, h_qo, d_idx] = O_local[d_idx]
+                        lse[curl_q, h_qo] = m_val[0] + T.log2(d_val[0])
+    return tree_attn_paged_kv_cpu
+
+
 def tree_attn_with_paged_kv_cache(
     h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target
 ):
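Review note: the runtime change below teaches the CPU `DeviceAPI` to answer `kTotalGlobalMemory`, reporting physical RAM in bytes (or -1 when the platform query fails). Assuming this routes through the usual `Device` attribute properties on the Python side, usage would look like:

```python
import tvm

dev = tvm.cpu(0)
# Total physical memory in bytes, or -1 if the platform query failed.
print(dev.total_global_memory)
```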
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc
index ccd726a6ece6..71c8025207de 100644
--- a/src/runtime/cpu_device_api.cc
+++ b/src/runtime/cpu_device_api.cc
@@ -34,6 +34,19 @@
 #include <android/api-level.h>
 #endif
 
+#if defined(__linux__) || defined(__ANDROID__)
+#include <sys/sysinfo.h>
+#endif
+
+#ifdef _WIN32
+#include <windows.h>
+#endif
+
+#if defined(__APPLE__)
+#include <sys/sysctl.h>
+#include <sys/types.h>
+#endif
+
 namespace tvm {
 namespace runtime {
 class CPUDeviceAPI final : public DeviceAPI {
@@ -43,6 +56,42 @@ class CPUDeviceAPI final : public DeviceAPI {
     if (kind == kExist) {
       *rv = 1;
     }
+
+    switch (kind) {
+      case kExist:
+        break;
+      case kTotalGlobalMemory: {
+#if defined(__linux__) || defined(__ANDROID__)
+        struct sysinfo info;
+        if (sysinfo(&info) == 0) {
+          *rv = static_cast<int64_t>(info.totalram) * info.mem_unit;  // convert to bytes
+        } else {
+          *rv = -1;
+        }
+#elif defined(_WIN32)
+        MEMORYSTATUSEX statex;
+        statex.dwLength = sizeof(statex);
+        if (GlobalMemoryStatusEx(&statex)) {
+          *rv = static_cast<int64_t>(statex.ullTotalPhys);  // total physical memory in bytes
+        } else {
+          *rv = -1;
+        }
+#elif defined(__APPLE__)
+        int64_t mem;
+        size_t size = sizeof(mem);
+        if (sysctlbyname("hw.memsize", &mem, &size, nullptr, 0) == 0) {
+          *rv = mem;
+        } else {
+          *rv = -1;
+        }
+#else
+        *rv = -1;
+#endif
+        break;
+      }
+      default:
+        break;
+    }
   }
   void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final {
     void* ptr;