Commit 172f616

cont : return important tensors
ggml-ci
ggerganov committed Feb 18, 2025
1 parent c235903 commit 172f616
Showing 5 changed files with 293 additions and 46 deletions.
29 changes: 19 additions & 10 deletions src/llama-context.cpp
@@ -255,7 +255,8 @@ void llama_context::init() {
     // reserve pp graph first so that buffers are only allocated once
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        auto res_pp = graph_build(ubatch_pp, true);
+        auto ctx = graph_init();
+        auto res_pp = graph_build(ctx, ubatch_pp, true);
         auto & gf_pp = res_pp.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -269,7 +270,8 @@ void llama_context::init() {
     // reserve with tg graph to get the number of splits and nodes
     {
         llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        auto res_tg = graph_build(ubatch_tg, true);
+        auto ctx = graph_init();
+        auto res_tg = graph_build(ctx, ubatch_tg, true);
         auto & gf_tg = res_tg.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_tg)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute tg buffers\n", __func__);
@@ -282,7 +284,8 @@ void llama_context::init() {
     // reserve again with pp graph to avoid ggml-alloc reallocations during inference
     {
         llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
-        auto res_pp = graph_build(ubatch_pp, true);
+        auto ctx = graph_init();
+        auto res_pp = graph_build(ctx, ubatch_pp, true);
         auto & gf_pp = res_pp.gf;
         if (!ggml_backend_sched_reserve(sched.get(), gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute pp buffers\n", __func__);
@@ -569,6 +572,13 @@ ggml_context_ptr llama_context::graph_init() {
     return ggml_context_ptr { ggml_init(params) };
 }
 
+llama_graph_result llama_context::graph_build(
+        ggml_context_ptr & ctx,
+        const llama_ubatch & ubatch,
+        bool worst_case) {
+    return model.build_graph(ctx, *this, cparams, ubatch, worst_case);
+}
+
 enum ggml_status llama_context::graph_compute(
         ggml_cgraph * graph,
         bool batched) {
@@ -907,10 +917,6 @@ void llama_context::build_cb(
     }
 }
 
-llama_graph_result llama_context::graph_build(const llama_ubatch & ubatch, bool worst_case) {
-    return model.build_graph(*this, cparams, ubatch, graph_init(), worst_case);
-}
-
 llama_perf_context_data llama_context::perf_get_data() const {
     llama_perf_context_data data = {};
 
@@ -1831,7 +1837,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
         llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
 
-        auto res = graph_build(ubatch, true);
+        auto ctx = graph_init();
+        auto res = graph_build(ctx, ubatch, true);
 
         // initialize scheduler with the worst-case graph
         ggml_backend_sched_reset(sched.get());
@@ -1845,7 +1852,8 @@ int llama_context_kv_self::decode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    auto res = graph_build(ubatch, false);
+    auto ctx = graph_init();
+    auto res = graph_build(ctx, ubatch, false);
 
     auto & gf = res.gf;
 
@@ -2092,7 +2100,8 @@ int llama_context_kv_self::encode(llama_batch & inp_batch) {
     ggml_backend_sched_reset(sched.get());
     ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
-    auto res = graph_build(ubatch, false);
+    auto ctx = graph_init();
+    auto res = graph_build(ctx, ubatch, false);
 
     auto & gf = res.gf;
 
5 changes: 4 additions & 1 deletion src/llama-context.h
@@ -96,7 +96,10 @@ struct llama_context : public llama_graph_i {
     virtual ggml_context_ptr graph_init();
 
     // TODO: add encode/decode graphs
-    virtual llama_graph_result graph_build(const llama_ubatch & ubatch, bool worst_case);
+    virtual llama_graph_result graph_build(
+            ggml_context_ptr & ctx,
+            const llama_ubatch & ubatch,
+            bool worst_case);
 
     // returns the result of ggml_backend_sched_graph_compute_async execution
     virtual enum ggml_status graph_compute(
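Taken together with the llama-context.cpp changes above, the caller now creates the ggml context itself via graph_init() and passes it into graph_build(). A minimal sketch of the intended call sequence, assuming it runs inside a llama_context member function; the helper name run_graph is hypothetical and not part of this commit:

    // Hypothetical sketch only -- assembled from the declarations above, not from the diff.
    llama_graph_result llama_context::run_graph(const llama_ubatch & ubatch, bool batched) {
        auto ctx = graph_init();                    // caller now owns the ggml context
        auto res = graph_build(ctx, ubatch, false); // build the graph; the result carries the key output tensors
        graph_compute(res.gf, batched);             // execute via the backend scheduler
        return res;                                 // t_logits / t_embd are filled in by the model's graph builder
    }
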
6 changes: 4 additions & 2 deletions src/llama-graph.h
@@ -13,8 +13,10 @@ struct llama_ubatch;
 struct llama_graph_result {
     ggml_cgraph * gf = nullptr;
 
-    ggml_tensor * t_logits = nullptr;
-    ggml_tensor * t_embd   = nullptr;
+    // important graph nodes
+    ggml_tensor * t_logits      = nullptr;
+    ggml_tensor * t_embd        = nullptr;
+    ggml_tensor * t_embd_pooled = nullptr;
 };
 
 // TODO: can become more granular in the future
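With the added t_embd_pooled node, the result struct now exposes the three output tensors a caller typically needs. A rough, illustrative sketch of how they might be read back once the graph has been computed; the variable names are illustrative, while ggml_nelements, ggml_nbytes and ggml_backend_tensor_get are existing ggml APIs:

    // Illustrative only: extracting the "important" tensors after graph execution.
    // `res` is the llama_graph_result returned by graph_build(); the graph has already run.
    std::vector<float> logits;
    if (res.t_logits != nullptr) {
        logits.resize(ggml_nelements(res.t_logits));
        ggml_backend_tensor_get(res.t_logits, logits.data(), 0, ggml_nbytes(res.t_logits));
    }

    std::vector<float> embd_pooled;
    if (res.t_embd_pooled != nullptr) {
        embd_pooled.resize(ggml_nelements(res.t_embd_pooled));
        ggml_backend_tensor_get(res.t_embd_pooled, embd_pooled.data(), 0, ggml_nbytes(res.t_embd_pooled));
    }
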
(2 remaining changed files not shown)
