context : add llama_context_rwkv
ggml-ci
ggerganov committed Feb 19, 2025
1 parent 5f11a55 commit cacb57f
Showing 4 changed files with 217 additions and 84 deletions.
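In short, the recurrent-state inputs (inp_s_copy, inp_s_mask) and the state/RWKV graph builders move out of llama_context_kv_self into a new derived class, llama_context_rwkv. A condensed sketch of the split (illustration only; hypothetical *_sketch names, with the llama/ggml types stripped so it compiles standalone — the real declarations are in src/llama-context.h below):

struct llama_context_kv_self_sketch {
    virtual ~llama_context_kv_self_sketch() = default;
    // attention KV-cache inputs/builders and state save/load stay on this class
    virtual void graph_init() { /* resets the attention inputs */ }
    virtual void input_set() { /* fills the attention inputs   */ }
};

struct llama_context_rwkv_sketch : llama_context_kv_self_sketch {
    // new in this commit: the recurrent-state inputs live on the derived class
    void * inp_s_copy = nullptr; // I32 [kv_size] in the real code
    void * inp_s_mask = nullptr; // F32 [1, n_kv]  in the real code

    void graph_init() override {
        inp_s_copy = nullptr;
        inp_s_mask = nullptr;
        llama_context_kv_self_sketch::graph_init(); // chain to the base
    }

    void input_set() override {
        llama_context_kv_self_sketch::input_set();  // base fills the attention inputs
        // ... then fill inp_s_mask / inp_s_copy from the recurrent KV cells
    }
};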
136 changes: 75 additions & 61 deletions src/llama-context.cpp
@@ -1700,8 +1700,6 @@ ggml_cgraph * llama_context_kv_self::graph_init() {
inp_KQ_mask_swa_cnv = nullptr;
inp_KQ_mask_cross = nullptr;
inp_k_shift = nullptr;
inp_s_copy = nullptr;
inp_s_mask = nullptr;
inp_embd_enc = nullptr;
inp_pos_bucket = nullptr;

@@ -2381,53 +2379,6 @@ void llama_context_kv_self::input_set(const llama_ubatch & ubatch) {
}
}

if (kv_self.recurrent) {
const int64_t n_kv = kv_self.n;

if (inp_s_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
float * data = (float *) inp_s_mask->data;

// clear unused states
for (int i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id];

data[i] = (float) (kv_cell.src >= 0);

// TODO: do not mutate the KV cache
// only clear once
if (kv_cell.src < 0) {
kv_cell.src = cell_id;
}
}
}

if (inp_s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
int32_t * data = (int32_t *) inp_s_copy->data;

// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id];

// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
kv_cell.src = cell_id;
}

data[i] = kv_cell.src;

// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
}
}
}

if (inp_pos_bucket) {
const int64_t n_tokens = ubatch.n_tokens;

@@ -2614,7 +2565,7 @@ void llama_context_kv_self::build_attn_inp(

void llama_context_kv_self::build_attn_kv_store(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_cgraph * gf,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
int32_t n_tokens,
@@ -2635,7 +2586,7 @@ void llama_context_kv_self::build_attn_kv_store(
//cb(k_cache_view, "k_cache_view", il);

// note: storing RoPE-ed version of K in the KV cache
ggml_build_forward_expand(graph, ggml_cpy(ctx0, k_cur, k_cache_view));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));

assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

@@ -2653,12 +2604,12 @@
}
//cb(v_cache_view, "v_cache_view", il);

ggml_build_forward_expand(graph, ggml_cpy(ctx0, v_cur, v_cache_view));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
}

ggml_tensor * llama_context_kv_self::build_attn_qkv(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -2791,7 +2742,7 @@ ggml_tensor * llama_context_kv_self::build_attn_qkv(
}
}

ggml_build_forward_expand(graph, cur);
ggml_build_forward_expand(gf, cur);

if (wo) {
cur = build_lora_mm(ctx0, wo, cur);
@@ -3152,7 +3103,70 @@ ggml_tensor * llama_context_kv_self::build_inp_KQ_mask_cross(
return inp_KQ_mask_cross;
}

ggml_tensor * llama_context_kv_self::build_inp_s_copy(
//
// llama_context_rwkv
//

ggml_cgraph * llama_context_rwkv::graph_init() {
inp_s_copy = nullptr;
inp_s_mask = nullptr;

return llama_context_kv_self::graph_init();
}

void llama_context_rwkv::input_set(const llama_ubatch & ubatch) {
// call base functionality
llama_context_kv_self::input_set(ubatch);

GGML_ASSERT(kv_self.recurrent);

const int64_t n_kv = kv_self.n;

if (inp_s_mask) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_mask->buffer));
float * data = (float *) inp_s_mask->data;

// clear unused states
for (int i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id];

data[i] = (float) (kv_cell.src >= 0);

// TODO: do not mutate the KV cache
// only clear once
if (kv_cell.src < 0) {
kv_cell.src = cell_id;
}
}
}

if (inp_s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(inp_s_copy->buffer));
int32_t * data = (int32_t *) inp_s_copy->data;

// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = kv_self.cells[cell_id];

// prevent out-of-bound sources
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) {
kv_cell.src = cell_id;
}

data[i] = kv_cell.src;

// TODO: do not mutate the KV cache
// ensure copy only happens once
if (kv_cell.src != (int32_t) cell_id) {
kv_cell.src = cell_id;
}
}
}
}
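
// Worked example (illustration only; hypothetical cache contents, not part of this commit):
// suppose kv_self.head == 4 and n_kv == 3, with
//   cells[4].src = -1 (no state yet), cells[5].src = 2 (copy from cell 2), cells[6].src = 6 (already in place).
// The mask pass above writes inp_s_mask = { 0.0f, 1.0f, 1.0f } and resets cells[4].src to 4;
// the copy pass then writes inp_s_copy = { 4, 2, 6 } and resets cells[5].src to 5,
// so each state is cleared/copied at most once on subsequent calls.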

ggml_tensor * llama_context_rwkv::build_inp_s_copy(
ggml_context * ctx0,
bool worst_case) {
const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3163,7 +3177,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_copy(
return inp_s_copy;
}

ggml_tensor * llama_context_kv_self::build_inp_s_mask(
ggml_tensor * llama_context_rwkv::build_inp_s_mask(
ggml_context * ctx0,
bool worst_case) {
const auto n_kv = worst_case ? kv_self.size : kv_self.n;
@@ -3173,7 +3187,7 @@ ggml_tensor * llama_context_kv_self::build_inp_s_mask(
return inp_s_mask;
}

ggml_tensor * llama_context_kv_self::build_copy_mask_state(
ggml_tensor * llama_context_rwkv::build_copy_mask_state(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * s,
@@ -3208,7 +3222,7 @@ ggml_tensor * llama_context_kv_self::build_copy_mask_state(
}

// TODO: split
ggml_tensor * llama_context_kv_self::build_mamba_layer(
ggml_tensor * llama_context_rwkv::build_mamba_layer(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
@@ -3344,7 +3358,7 @@ ggml_tensor * llama_context_kv_self::build_mamba_layer(
}


ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
ggml_tensor * llama_context_rwkv::build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * state_copy,
@@ -3371,7 +3385,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_load(
}


ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
ggml_tensor * llama_context_rwkv::build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
@@ -3395,7 +3409,7 @@ ggml_tensor * llama_context_kv_self::build_rwkv_token_shift_store(
}


ggml_tensor * llama_context_kv_self::build_rwkv6_time_mix(
ggml_tensor * llama_context_rwkv::build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
23 changes: 14 additions & 9 deletions src/llama-context.h
@@ -433,13 +433,25 @@ class llama_context_kv_self : public llama_context {
int32_t n_tokens,
bool worst_case) override;

// === recurrent ===
protected:
virtual size_t state_get_data(llama_io_write_i & io) override;
virtual size_t state_set_data(llama_io_read_i & io) override;

virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
};

// TODO: temporary reuse kv_self, but in the future, implement specific context
class llama_context_rwkv : public llama_context_kv_self {
public:
virtual ggml_cgraph * graph_init() override;

virtual void input_set(const llama_ubatch & ubatch) override;

struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]

// TODO: add recurrent cache
// TODO: add mamba-specific llama_context

// TODO: change these to build_mamba_inp and hide `state_copy` and `state_mask` inside the llama_context impl
virtual ggml_tensor * build_inp_s_copy(
@@ -497,13 +509,6 @@ class llama_context_kv_self : public llama_context {
const llama_ubatch & ubatch,
int il,
bool worst_case) override;

protected:
virtual size_t state_get_data(llama_io_write_i & io) override;
virtual size_t state_set_data(llama_io_read_i & io) override;

virtual size_t state_seq_get_data(llama_io_write_i & io, llama_seq_id seq_id) override;
virtual size_t state_seq_set_data(llama_io_read_i & io, llama_seq_id seq_id) override;
};

// For internal test use
114 changes: 114 additions & 0 deletions src/llama-graph.cpp
@@ -1 +1,115 @@
#include "llama-graph.h"

#include "ggml.h"

ggml_tensor * llama_graph_i::build_inp_s_copy (
ggml_context * ctx0,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(worst_case);
GGML_ABORT("not implemented");
}

ggml_tensor * llama_graph_i::build_inp_s_mask(
ggml_context * ctx0,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(worst_case);
GGML_ABORT("not implemented");
}

ggml_tensor * llama_graph_i::build_copy_mask_state(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
int32_t n_tokens,
int32_t n_state,
int32_t n_seqs,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(s);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(n_tokens);
GGML_UNUSED(n_state);
GGML_UNUSED(n_seqs);
GGML_UNUSED(worst_case);
GGML_ABORT("not implemented");
}

ggml_tensor * llama_graph_i::build_mamba_layer(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(cur);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
GGML_UNUSED(worst_case);
GGML_ABORT("not implemented");
}

ggml_tensor * llama_graph_i::build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
GGML_UNUSED(worst_case);
GGML_ABORT("not implemented");
}

ggml_tensor * llama_graph_i::build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(token_shift);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
GGML_UNUSED(worst_case);
GGML_ABORT("not implemented");
}

ggml_tensor * llama_graph_i::build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
GGML_UNUSED(ctx0);
GGML_UNUSED(gf);
GGML_UNUSED(cur);
GGML_UNUSED(x_prev);
GGML_UNUSED(state_copy);
GGML_UNUSED(state_mask);
GGML_UNUSED(ubatch);
GGML_UNUSED(il);
GGML_UNUSED(worst_case);
GGML_ABORT("not implemented");
}
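
All of the new llama-graph.cpp entries follow the same pattern: the base graph interface provides stubs that mark their arguments as unused and abort at run time, so only contexts that actually support recurrent/RWKV state need to override them. A minimal self-contained analogue of that pattern (hypothetical names, not from the commit):

#include <cstdio>
#include <cstdlib>

// stand-ins for GGML_UNUSED / GGML_ABORT, for this sketch only
#define SKETCH_UNUSED(x) (void) (x)
#define SKETCH_ABORT(msg) do { std::fprintf(stderr, "%s\n", msg); std::abort(); } while (0)

struct graph_iface_sketch {
    virtual ~graph_iface_sketch() = default;

    // default: not implemented - only recurrent/RWKV contexts override this
    virtual int build_state_input(int n_kv) {
        SKETCH_UNUSED(n_kv);
        SKETCH_ABORT("not implemented");
    }
};

struct rwkv_graph_sketch : graph_iface_sketch {
    // a real override would build an I32 input tensor of size n_kv
    int build_state_input(int n_kv) override { return n_kv; }
};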
