feat: sync llama.cpp (#108)
* feat: sync llama.cpp

* fix: api changes

* fix(temp): modelInfo

* feat: sync llama.cpp

* fix(cpp): keep common_init_result

* fix(ios): revert modelInfo

* fix(cpp): use gguf_kv_to_str from llama-impl

* fix(cpp): llama_free_model -> llama_model_free

* feat: move lora init into rn-llama

* fix(cpp): remove unnecessary free

* feat: sync llama.cpp
jhen0409 authored Jan 9, 2025
1 parent 4cf10a7 commit b539012
Showing 62 changed files with 23,315 additions and 21,771 deletions.
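The commit messages above revolve around a handful of llama.cpp API renames that recur throughout the diff, most visibly llama_load_model_from_file -> llama_model_load_from_file and llama_free_model -> llama_model_free. The snippet below is a minimal sketch of the renamed model lifecycle calls; the parameter setup is illustrative and not taken from this repository.

// --- illustrative sketch, not part of the diff ---
// Model load/free with the renamed API used throughout this sync.
#include "llama.h"

static void load_and_release(const char * model_path) {
    llama_model_params mparams = llama_model_default_params();

    // was: llama_load_model_from_file(model_path, mparams)
    llama_model * model = llama_model_load_from_file(model_path, mparams);
    if (model == NULL) {
        return;
    }

    // was: llama_free_model(model)
    llama_model_free(model);
}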
14 changes: 14 additions & 0 deletions android/src/main/CMakeLists.txt
@@ -21,11 +21,25 @@ set(
${RNLLAMA_LIB_DIR}/ggml-opt.cpp
${RNLLAMA_LIB_DIR}/ggml-threading.cpp
${RNLLAMA_LIB_DIR}/ggml-quants.c
${RNLLAMA_LIB_DIR}/gguf.cpp
${RNLLAMA_LIB_DIR}/log.cpp
${RNLLAMA_LIB_DIR}/llama-impl.cpp
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
${RNLLAMA_LIB_DIR}/llama-adapter.cpp
${RNLLAMA_LIB_DIR}/llama-chat.cpp
${RNLLAMA_LIB_DIR}/llama-context.cpp
${RNLLAMA_LIB_DIR}/llama-kv-cache.cpp
${RNLLAMA_LIB_DIR}/llama-arch.cpp
${RNLLAMA_LIB_DIR}/llama-batch.cpp
${RNLLAMA_LIB_DIR}/llama-cparams.cpp
${RNLLAMA_LIB_DIR}/llama-hparams.cpp
${RNLLAMA_LIB_DIR}/llama.cpp
${RNLLAMA_LIB_DIR}/llama-model.cpp
${RNLLAMA_LIB_DIR}/llama-model-loader.cpp
${RNLLAMA_LIB_DIR}/llama-mmap.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
${RNLLAMA_LIB_DIR}/sampling.cpp
${RNLLAMA_LIB_DIR}/unicode-data.cpp
${RNLLAMA_LIB_DIR}/unicode.cpp
25 changes: 8 additions & 17 deletions android/src/main/jni.cpp
@@ -19,7 +19,7 @@

#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)

#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
static inline int min(int a, int b) {
return (a < b) ? a : b;
}
@@ -198,7 +198,7 @@ Java_com_rnllama_LlamaContext_modelInfo(
continue;
}

const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
const std::string value = lm_gguf_kv_to_str(ctx, i);
putString(env, info, key, value.c_str());
}
}
@@ -336,13 +336,13 @@ Java_com_rnllama_LlamaContext_initContext(
llama_free(llama->ctx);
}

std::vector<common_lora_adapter_info> lora_adapters;
std::vector<common_lora_adapter_info> lora;
const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
if (lora_chars != nullptr && lora_chars[0] != '\0') {
common_lora_adapter_info la;
la.path = lora_chars;
la.scale = lora_scaled;
lora_adapters.push_back(la);
lora.push_back(la);
}

if (lora_list != nullptr) {
@@ -356,13 +356,13 @@ Java_com_rnllama_LlamaContext_initContext(
common_lora_adapter_info la;
la.path = path_chars;
la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
lora_adapters.push_back(la);
lora.push_back(la);
env->ReleaseStringUTFChars(path, path_chars);
}
}
}
env->ReleaseStringUTFChars(lora_str, lora_chars);
int result = llama->applyLoraAdapters(lora_adapters);
int result = llama->applyLoraAdapters(lora);
if (result != 0) {
LOGI("[RNLlama] Failed to apply lora adapters");
llama_free(llama->ctx);
@@ -946,7 +946,7 @@ Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
auto llama = context_map[(long) context_ptr];
auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
auto result = createWritableArray(env);
for (common_lora_adapter_container &la : loaded_lora_adapters) {
for (common_lora_adapter_info &la : loaded_lora_adapters) {
auto map = createWriteableMap(env);
putString(env, map, "path", la.path.c_str());
putDouble(env, map, "scaled", la.scale);
@@ -961,17 +961,8 @@ Java_com_rnllama_LlamaContext_freeContext(
UNUSED(env);
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];
if (llama->model) {
llama_free_model(llama->model);
}
if (llama->ctx) {
llama_free(llama->ctx);
}
if (llama->ctx_sampling != nullptr)
{
common_sampler_free(llama->ctx_sampling);
}
context_map.erase((long) llama->ctx);
delete llama;
}

JNIEXPORT void JNICALL
66 changes: 42 additions & 24 deletions cpp/common.cpp
@@ -2,6 +2,9 @@
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif

#include "ggml.h"
#include "gguf.h"

#include "common.h"
#include "log.h"
#include "llama.h"
@@ -14,6 +17,7 @@
#include <cstdarg>
#include <cstring>
#include <ctime>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
@@ -64,7 +68,9 @@ char const *LLAMA_BUILD_TARGET = "unknown";
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
#define PATH_MAX MAX_PATH
# if !defined(PATH_MAX)
# define PATH_MAX MAX_PATH
# endif
#else
#include <sys/syslimits.h>
#endif
@@ -843,7 +849,7 @@ struct common_init_result common_init_from_params(common_params & params) {
} else if (!params.model_url.empty()) {
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
model = llama_model_load_from_file(params.model.c_str(), mparams);
}

if (model == NULL) {
@@ -870,7 +876,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}

if (!ok) {
llama_free_model(model);
llama_model_free(model);

return iparams;
}
@@ -881,14 +887,13 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
llama_model_free(model);
return iparams;
}

if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
llama_free_model(model);
return iparams;
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}

if (!params.control_vectors.empty()) {
@@ -898,7 +903,7 @@ struct common_init_result common_init_from_params(common_params & params) {
const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_free_model(model);
llama_model_free(model);

return iparams;
}
@@ -911,28 +916,29 @@ struct common_init_result common_init_from_params(common_params & params) {
params.control_vector_layer_end);
if (err) {
llama_free(lctx);
llama_free_model(model);
llama_model_free(model);

return iparams;
}
}

// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
common_lora_adapter_container loaded_la;
loaded_la.path = la.path;
loaded_la.scale = la.scale;
loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
if (loaded_la.adapter == nullptr) {
llama_lora_adapter_ptr lora;
lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_free_model(model);
llama_model_free(model);
return iparams;
}
iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters

la.ptr = lora.get();
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
}

if (!params.lora_init_without_apply) {
common_lora_adapters_apply(lctx, iparams.lora_adapters);
common_lora_adapters_apply(lctx, params.lora_adapters);
}

if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
@@ -979,7 +985,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = bos;
}
tmp.clear();
@@ -993,17 +999,17 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_perf_context_reset(lctx);
}

iparams.model = model;
iparams.context = lctx;
iparams.model.reset(model);
iparams.context.reset(lctx);

return iparams;
}

void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters) {
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
llama_lora_adapter_clear(ctx);
for (auto & la : lora_adapters) {
for (auto & la : lora) {
if (la.scale != 0.0f) {
llama_lora_adapter_set(ctx, la.adapter, la.scale);
llama_lora_adapter_set(ctx, la.ptr, la.scale);
}
}
}
@@ -1201,7 +1207,7 @@ struct llama_model * common_load_model_from_url(
}
}

return llama_load_model_from_file(local_path.c_str(), params);
return llama_model_load_from_file(local_path.c_str(), params);
}

struct llama_model * common_load_model_from_hf(
@@ -1404,6 +1410,18 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
// Chat template utils
//

std::string common_get_builtin_chat_template(const struct llama_model * model) {
static const char * template_key = "tokenizer.chat_template";
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key, NULL, 0);
if (res > 0) {
std::vector<char> model_template(res + 1, 0);
llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size());
return std::string(model_template.data(), model_template.size() - 1);
}
return "";
}

bool common_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
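A possible caller of the common_get_builtin_chat_template helper added above, paired with the existing common_chat_verify_template check. The wrapper function, its name, and its arguments are hypothetical; only the two common_* calls come from the diff.

// --- illustrative sketch, not part of the diff ---
// Pick a chat template: a verified user override wins, otherwise fall back
// to the template embedded in the GGUF (empty string if the model has none).
#include "common.h"

static std::string pick_chat_template(const llama_model * model,
                                      const std::string & override_tmpl) {
    if (!override_tmpl.empty() && common_chat_verify_template(override_tmpl)) {
        return override_tmpl;
    }
    return common_get_builtin_chat_template(model);
}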
29 changes: 18 additions & 11 deletions cpp/common.h
@@ -2,7 +2,7 @@

#pragma once

#include "llama.h"
#include "llama-cpp.h"

#include <string>
#include <vector>
@@ -27,10 +27,8 @@
struct common_lora_adapter_info {
std::string path;
float scale;
};

struct common_lora_adapter_container : common_lora_adapter_info {
struct llama_lora_adapter * adapter;
struct llama_lora_adapter * ptr;
};

using llama_tokens = std::vector<llama_token>;
@@ -493,10 +491,12 @@ std::string fs_get_cache_file(const std::string & filename);
// Model utils
//

// note: defines object's lifetime
struct common_init_result {
struct llama_model * model = nullptr;
struct llama_context * context = nullptr;
std::vector<common_lora_adapter_container> lora_adapters;
llama_model_ptr model;
llama_context_ptr context;

std::vector<llama_lora_adapter_ptr> lora;
};

struct common_init_result common_init_from_params(common_params & params);
@@ -518,7 +518,7 @@ struct llama_model * common_load_model_from_hf(
const struct llama_model_params & params);

// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);

//
// Batch utils
@@ -586,6 +586,9 @@ struct common_chat_msg {
std::string content;
};

// Get the built-in chat template for the model. Return empty string if not present.
std::string common_get_builtin_chat_template(const struct llama_model * model);

// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl);

@@ -652,6 +655,10 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
// Split utils
//

static const char * const LLM_KV_SPLIT_NO = "split.no";
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
namespace {

const char * const LLM_KV_SPLIT_NO = "split.no";
const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

}
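With common_init_result now holding smart pointers (llama_model_ptr, llama_context_ptr, llama_lora_adapter_ptr), callers no longer free the model and context by hand. A rough sketch of the intended usage under that assumption, with params taken as already populated:

// --- illustrative sketch, not part of the diff ---
// common_init_result owns the model, context and LoRA adapters; everything is
// released automatically when `init` goes out of scope.
#include "common.h"

static void run_once(common_params & params) {
    common_init_result init = common_init_from_params(params);
    llama_model   * model = init.model.get();   // raw handles, still owned by init
    llama_context * lctx  = init.context.get();
    if (model == nullptr || lctx == nullptr) {
        return;
    }

    // (re)apply the adapters recorded during init; each la.ptr was set there
    common_lora_adapters_apply(lctx, params.lora_adapters);

    // ... tokenize / decode here ...
}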
