
llama : refactor llama_kv_cache, llama_context and llm_build_context #11213

Draft · ggerganov wants to merge 51 commits into master from gg/llama-kv-cache
Conversation

@ggerganov ggerganov commented Jan 13, 2025

Overview

This PR is an intermediate step towards a more generic implementation that will support different underlying implementations of llama_kv_cache, llama_context and the graph building logic (a.k.a. llm_build_context). The llama_kv_cache is also introduced in the public API as an object, but its actual functionality is yet to be defined in follow-up PRs.

Currently, no functional changes have been introduced; the code has mainly been reorganized in a way that allows implementing the new abstractions. The main changes in the implementation are:

  • Avoid all explicit references to llama_kv_cache in llm_build_context. The goal is to construct the computation graphs only through the abstract llama_context interface, which hides the actual KV cache implementation and can therefore be specialized based on the parameters of the specific use case. More generally, the llama_context hides not only the KV cache implementation but all of the internal state (such as applied adapters, masks, etc., if any), with the exception of the model weights - these remain available to the llm_build_context so that it can construct the backbone graph of the various architectures.
  • Avoid all explicit references to llama_kv_cache in llama_decode/llama_encode. These are abstracted through a new object, llama_batch_manager, which is produced by the current llama_context (a rough interface sketch follows this list). Again, the goal is not to make explicit assumptions about the underlying KV cache implementation while processing batches, and to delegate this logic to the llama_context instead. The llama_batch_manager handles logic such as restoring the KV cache to a consistent state upon errors, splitting the input batch into micro-batches according to the internal processing logic, etc.
  • Add initial serialization primitives to llama_kv_cache. In the future, these will be overloaded for the specific KV cache implementations through a common abstract interface.
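
To make the llama_batch_manager idea above more concrete, here is a rough, purely illustrative C++ sketch; the interface name, methods and signatures are assumptions for illustration and do not necessarily match the code in this PR:

    // Illustrative sketch only - not the actual interface introduced by this PR.
    // The batch manager hides the KV cache details from llama_decode/llama_encode.
    struct llama_batch_manager_i {
        virtual ~llama_batch_manager_i() = default;

        // split the input batch into the next micro-batch according to the
        // internal processing logic of the owning llama_context
        virtual bool next_ubatch() = 0;

        // check whether the current micro-batch can be processed
        // (e.g. find a free KV cache slot) before building the graph
        virtual bool prepare() = 0;

        // restore the internal state (e.g. the KV cache) to a consistent
        // state if an error occurs while processing the batch
        virtual void restore() = 0;

        // commit the changes made while processing the current micro-batch
        virtual void update() = 0;
    };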

The modifications so far are quite substantial and touch a large number of lines. Even though the code is in a very intermediate state, with many members still publicly exposed and without a proper object-oriented implementation in place, it should still be mergeable.

The general class hierarchy that I have in mind is like this:

graph TD;
llama_kv_cache_unified --> llama_kv_cache;
llama_kv_cache_standard --> llama_kv_cache;
llama_kv_cache_mamba --> llama_kv_cache;
... --> llama_kv_cache;

Here, llama_kv_cache_unified is basically the llama_kv_cache implementation that we currently have. In the future, we will add more implementations that would be appropriate for multi-user scenarios (e.g. llama_kv_cache_standard) or for Mamba architectures (llama_kv_cache_mamba).
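
For illustration, the abstract llama_kv_cache base could look roughly like the sketch below, including the serialization primitives mentioned above; the member names and signatures are assumptions, not the actual class from this PR (llama_pos and llama_seq_id mirror the existing types from llama.h):

    // Illustrative sketch only - not the actual class from this PR.
    #include <cstddef>
    #include <cstdint>

    typedef int32_t llama_pos;     // as in llama.h
    typedef int32_t llama_seq_id;  // as in llama.h

    struct llama_kv_cache {
        virtual ~llama_kv_cache() = default;

        // sequence-level operations, mirroring the existing llama_kv_cache_seq_* API
        virtual void clear() = 0;
        virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0;
        virtual void seq_cp (llama_seq_id src, llama_seq_id dst, llama_pos p0, llama_pos p1) = 0;

        // serialization primitives, to be specialized per implementation
        virtual size_t state_size () const = 0;
        virtual size_t state_write(uint8_t * dst) const = 0;
        virtual size_t state_read (const uint8_t * src, size_t size) = 0;
    };

    // the current implementation would become one of several derived classes:
    // struct llama_kv_cache_unified  : llama_kv_cache { ... };
    // struct llama_kv_cache_standard : llama_kv_cache { ... };
    // struct llama_kv_cache_mamba    : llama_kv_cache { ... };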

graph TD;
llama_context --> llama_model;
llama_context --> llama_cparams;
llama_context --> llama_adapter;
llama_context --> etc..;

llama_context[<b>llama_context</b>];

llama_context_no_kv[<b>llama_context_no_kv</b><br><br>];
llama_context_unified[<b>llama_context_unified</b><br><br>llama_kv_cache_unified];
llama_context_standard[<b>llama_context_standard</b><br><br>llama_kv_cache_standard];
llama_context_mamba[<b>llama_context_mamba</b><br><br>llama_kv_cache_mamba];
llama_context_enc_dec[<b>llama_context_enc_dec</b><br><br>llama_kv_cache_standard];

llama_context_no_kv -.-> llama_context;
llama_context_unified -.-> llama_context;
llama_context_standard -.-> llama_context;
llama_context_mamba -.-> llama_context;
llama_context_enc_dec -.-> llama_context;
... -.-> llama_context;

The base llama_context class will implement common functionality such as low-level ggml buffer and backend management + adapters, without the notion of a KV cache. The derived classes will specialize the llama_context for different use-cases.

The llm_build_context would operate only through the llama_build_i interface, and the batch processing would, in turn, interact only with the llama_batch_manager_i interface. The type of llama_context to construct in functions such as llama_init_from_model() would be determined based on the model and the specified context parameters. For example, the user would be able to create both a llama_context_unified and a llama_context_standard for an LLM_ARCH_QWEN2 model, or a llama_context_no_kv for an encoding-only LLM_ARCH_BERT model, and so on.
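
As a hypothetical, self-contained sketch of that dispatch (all types below are stand-ins rather than the real llama.cpp classes, and the selection criteria are made up for illustration):

    // Hypothetical illustration of picking a context implementation
    // based on the model and the context parameters.
    #include <memory>

    struct llama_model_stub   { bool is_recurrent = false; };                      // stand-in for llama_model
    struct llama_cparams_stub { bool kv_enabled = true; bool kv_unified = true; };  // stand-in for context params

    struct llama_context          { virtual ~llama_context() = default; };
    struct llama_context_no_kv    : llama_context {};
    struct llama_context_unified  : llama_context {};
    struct llama_context_standard : llama_context {};
    struct llama_context_mamba    : llama_context {};

    static std::unique_ptr<llama_context> make_context(const llama_model_stub & model, const llama_cparams_stub & cp) {
        if (!cp.kv_enabled)     return std::make_unique<llama_context_no_kv>();    // e.g. encoder-only BERT
        if (model.is_recurrent) return std::make_unique<llama_context_mamba>();    // e.g. Mamba
        if (cp.kv_unified)      return std::make_unique<llama_context_unified>();  // current unified KV cache
        return std::make_unique<llama_context_standard>();                         // e.g. multi-user scenarios
    }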

API changes

The current changes are only necessary to make the API more consistent in following the naming convention. To migrate, simply replace the old API calls with the new ones.

  • Deprecate llama_kv_cache_... API
  • Add llama_kv_self_... API

In the future, the llama_kv_cache_... API will be changed to work with struct llama_kv_cache instead of struct llama_context and the functionality will be extended to support things like saving, copying, loading, etc.
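
For example, a typical call site would change like this (a minimal sketch; ctx is an existing llama_context * and the parameter values are placeholders):

    // before - still compiles, but reported as deprecated:
    llama_kv_cache_clear (ctx);
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, /*p0=*/0, /*p1=*/-1);

    // after:
    llama_kv_self_clear (ctx);
    llama_kv_self_seq_rm(ctx, /*seq_id=*/0, /*p0=*/0, /*p1=*/-1);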

Notes

  • Fix build_qwen2vl, inp_pos, lctx.n_pos_per_token hack
  • The worst-case values for n_outputs and n_outputs_enc in llm_build_context seem incorrect
  • Remove inp_s_seq - not used
  • fix
     // layers at the end of each sliding-window pattern use the full KQ mask,
     // all other layers use the sliding-window (SWA) mask
     const bool           is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
     struct ggml_tensor * KQ_mask_l = is_sliding ? KQ_mask_swa : KQ_mask;
    
  • Fix T5
  • Fix RWKV
  • Fix batch.pos == NULL - llama_context::pos_max() is used incorrectly
  • Dedup the reserve code

PRs to resolve

New features


ggerganov commented Jan 14, 2025

I am thinking about the following API change for this PR:

    // API on `master`
    DEPRECATED(LLAMA_API void llama_kv_cache_clear(ctx));
    DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm(ctx));
    DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp(ctx));
    DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep(ctx));
    DEPRECATED(LLAMA_API void llama_kv_cache_seq_add(ctx));
    DEPRECATED(LLAMA_API void llama_kv_cache_seq_div(ctx));
    DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max(ctx));
    DEPRECATED(LLAMA_API void llama_kv_cache_defrag(ctx));
    DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(ctx));
    DEPRECATED(LLAMA_API void llama_kv_cache_update(ctx));

    // works with `ctx.kv_self` - backwards compatible with `master`
    LLAMA_API void llama_kv_self_clear(ctx);
    LLAMA_API bool llama_kv_self_seq_rm(ctx);
    LLAMA_API void llama_kv_self_seq_cp(ctx);
    LLAMA_API void llama_kv_self_seq_keep(ctx);
    LLAMA_API void llama_kv_self_seq_add(ctx);
    LLAMA_API void llama_kv_self_seq_div(ctx);
    LLAMA_API llama_pos llama_kv_self_seq_pos_max(ctx);
    LLAMA_API void llama_kv_self_defrag(ctx);
    LLAMA_API bool llama_kv_self_can_shift(ctx);
    LLAMA_API void llama_kv_self_update(ctx);

    // TODO: llama_kv_cache API
    // can be implemented in a later PR
    // new API to access the KV cache instance
    struct llama_kv_cache;

    LLAMA_API struct llama_kv_cache * llama_get_kv_self(ctx);
    LLAMA_API void                    llama_set_kv_self(ctx, kv);
    // allow to clone, free, save, load the kv cache
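
A hypothetical usage sketch of the proposed accessors; none of this exists yet, and llama_kv_cache_clone/llama_kv_cache_free are placeholder names for the future clone/free functionality mentioned in the comment above:

    // hypothetical - placeholder names, nothing below is implemented yet
    struct llama_kv_cache * kv = llama_get_kv_self(ctx);

    // snapshot the cache before a speculative run
    struct llama_kv_cache * snapshot = llama_kv_cache_clone(kv);   // placeholder

    // ... llama_decode(...) a few times ...

    // roll back to the snapshot and discard the modified cache
    llama_set_kv_self(ctx, snapshot);
    llama_kv_cache_free(kv);                                       // placeholder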

@ggerganov ggerganov force-pushed the gg/llama-kv-cache branch 3 times, most recently from bcfda5c to fb74024 on January 14, 2025 11:22
@github-actions github-actions bot added the android (Issues specific to Android) label on Jan 14, 2025
Review comment on src/llama.cpp, lines 9111 to 9135 (outdated):
void llama_kv_self_update(llama_context * ctx) {
    llama_kv_self_update_impl(*ctx);
    const bool need_reserve = ctx->kv_self_update();

    // reserve a worst case graph again
    if (need_reserve) {
        // TODO: extract to a function
        const auto & cparams = ctx->cparams;
        const auto & model   = ctx->model;

        // build worst-case graph
        uint32_t n_seqs   = 1; // TODO: worst-case number of sequences
        uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
        llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};

        ggml_cgraph * gf = llama_build_graph(*ctx, ubatch, true);

        // initialize scheduler with the worst-case graph
        ggml_backend_sched_reset(ctx->sched.get());
        if (!ggml_backend_sched_reserve(ctx->sched.get(), gf)) {
            LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        }
    }
}
ggerganov (Owner, Author):

@slaren If we have a separate scheduler for the kv_self updates (such as K-shift and defrag), would this worst-case reservation be necessary?

Collaborator:
No, but that would increase memory usage.

@@ -460,8 +461,9 @@ extern "C" {

DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");

LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); // TODO: remove const?
Collaborator:
llama_model should always be immutable, otherwise it would be hard to guarantee the thread-safety when used in multiple contexts. So returning a const should be correct here.
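
A minimal usage sketch of the pattern this guarantees (one immutable model shared by several contexts, each of which could run on its own thread), assuming the current llama_model_load_from_file / llama_init_from_model API; the model path is a placeholder and error handling is omitted:

    #include "llama.h"

    int main() {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path

        llama_context_params cparams = llama_context_default_params();

        // the contexts never mutate the shared model, so they can be used concurrently
        llama_context * ctx0 = llama_init_from_model(model, cparams);
        llama_context * ctx1 = llama_init_from_model(model, cparams);

        // ... decode on ctx0 and ctx1 from separate threads ...

        llama_free(ctx0);
        llama_free(ctx1);
        llama_model_free(model);
        return 0;
    }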


Review thread on src/llama-context.cpp (outdated, resolved)
@ggerganov ggerganov marked this pull request as ready for review January 17, 2025 21:09
@ggerganov ggerganov requested a review from ngxson as a code owner January 17, 2025 21:09

ggerganov commented Jan 17, 2025

@slaren Coming back to your comment from earlier: #11110 (review)

  • At some point we should abstract everything needed to model an architecture into a single class (such that each architecture is a subclass of this class)
  • After that, llm_type should probably be removed entirely, and each architecture should have its own enum if needed, with a function to return the type as a string (which by default could be " ")

In the OP I have outlined a possible approach to make the implementation more abstract. I have focused primarily on the abstraction of the KV cache and the llama context.

If I understand your suggestion correctly, the idea is to have the compute graph build functions for each of the arches (e.g. build_llama()) become part of llama_model (e.g. implement derived classes llama_model_llama, llama_model_qwen, etc.), which would effectively eliminate the need for llm_build_context. This way, the llama_context would be able to simply call model->build(), instead of relying on the graph to come from "outside". Do I understand the idea correctly?
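
If that reading is right, a purely hypothetical sketch of the shape (names invented for illustration, not code from this PR):

    // Hypothetical illustration of "the graph comes from the model".
    struct ggml_cgraph;    // from ggml.h
    struct llama_context;  // the abstract context interface from the OP
    struct llama_ubatch;   // micro-batch

    struct llama_model_base {
        virtual ~llama_model_base() = default;

        // each architecture builds its own compute graph through the abstract
        // llama_context interface, with no direct access to the KV cache
        virtual ggml_cgraph * build(llama_context & lctx, const llama_ubatch & ubatch) const = 0;
    };

    // one derived class per architecture, e.g.:
    struct llama_model_llama : llama_model_base {
        ggml_cgraph * build(llama_context & lctx, const llama_ubatch & ubatch) const override {
            // what is currently llm_build_context::build_llama() would move here
            (void) lctx; (void) ubatch;
            return nullptr; // placeholder
        }
    };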

@ggerganov ggerganov changed the title llama : add struct llama_kv_cache llama : refactor llama_kv_cache, llama_context and llm_build_context Jan 17, 2025

slaren commented Jan 18, 2025

I haven't really thought about this enough to make specific suggestions, but I think the goal should be to have an interface that can be used to define everything necessary to implement a model architecture. Ideally, to add support for a new architecture, it should only be necessary to define a new class and create a mapping between the architecture name in the GGUF file and this class. There may of course be more classes in the interface, but there should be a single entry point. So this should include more than just the graph build function; it should also include all the functions to load a model, create a context, and everything else that may be necessary to run a model. This interface would also need to be supported by other interfaces, such as the KV cache abstraction and the graph building helper functions that are currently in llm_build_context and the other llm_build_* functions.

To do this, I think it would be better to create an abstract interface that contains everything necessary to define a model architecture. I think that's likely to result in a cleaner and more maintainable codebase than using llama_model as a base class. Instead, llama_model (and other classes like llama_context) should use this interface to implement the functionality in llama.cpp. It may also be convenient to have one or more base classes that implement some of the common functionality that is shared between multiple model architectures, but it should not be strictly necessary to use these base classes.

This is of course a very high-level suggestion; it will take a lot of work to define all the details.
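
To make this a bit more tangible, here is a very rough, self-contained sketch of such a single entry point; every name below is invented for illustration and is not a proposal for the actual interface:

    // Rough illustration of "one abstract interface per architecture".
    #include <functional>
    #include <map>
    #include <memory>
    #include <string>

    struct ggml_cgraph; // from ggml.h

    struct llm_arch_i {
        virtual ~llm_arch_i() = default;

        // everything needed to run a model of this architecture would hang off
        // this interface: loading, context creation, KV cache type, graph building, ...
        virtual void          load_tensors() = 0; // placeholder for the model loading hooks
        virtual ggml_cgraph * build_graph()  = 0; // placeholder for the graph building hook
    };

    struct llm_arch_llama : llm_arch_i {
        void          load_tensors() override {}
        ggml_cgraph * build_graph()  override { return nullptr; }
    };

    // adding a new architecture = one new class + one entry mapping the
    // architecture name from the GGUF file to its factory
    static const std::map<std::string, std::function<std::unique_ptr<llm_arch_i>()>> llm_arch_registry = {
        { "llama", [] { return std::make_unique<llm_arch_llama>(); } },
    };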

@ggerganov ggerganov marked this pull request as draft January 20, 2025 07:28
ggerganov (Owner, Author) commented:

Thanks for the suggestions. I'll aim to create the abstract model interface and restructure the implementation so that the llm_build_context is no longer needed and all model-specific code is behind the new abstract interface. Will keep hacking on this PR for a while and try to bring it to a more complete state before merging.

@github-actions github-actions bot added the python (python script changes) label on Feb 10, 2025
Labels: android (Issues specific to Android), examples, python (python script changes), server