feat: sync llama.cpp #108

Merged · 13 commits · Jan 9, 2025
14 changes: 14 additions & 0 deletions android/src/main/CMakeLists.txt
@@ -21,11 +21,25 @@ set(
${RNLLAMA_LIB_DIR}/ggml-opt.cpp
${RNLLAMA_LIB_DIR}/ggml-threading.cpp
${RNLLAMA_LIB_DIR}/ggml-quants.c
${RNLLAMA_LIB_DIR}/gguf.cpp
${RNLLAMA_LIB_DIR}/log.cpp
${RNLLAMA_LIB_DIR}/llama-impl.cpp
${RNLLAMA_LIB_DIR}/llama-grammar.cpp
${RNLLAMA_LIB_DIR}/llama-sampling.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
${RNLLAMA_LIB_DIR}/llama-adapter.cpp
${RNLLAMA_LIB_DIR}/llama-chat.cpp
${RNLLAMA_LIB_DIR}/llama-context.cpp
${RNLLAMA_LIB_DIR}/llama-kv-cache.cpp
${RNLLAMA_LIB_DIR}/llama-arch.cpp
${RNLLAMA_LIB_DIR}/llama-batch.cpp
${RNLLAMA_LIB_DIR}/llama-cparams.cpp
${RNLLAMA_LIB_DIR}/llama-hparams.cpp
${RNLLAMA_LIB_DIR}/llama.cpp
${RNLLAMA_LIB_DIR}/llama-model.cpp
${RNLLAMA_LIB_DIR}/llama-model-loader.cpp
${RNLLAMA_LIB_DIR}/llama-mmap.cpp
${RNLLAMA_LIB_DIR}/llama-vocab.cpp
${RNLLAMA_LIB_DIR}/sampling.cpp
${RNLLAMA_LIB_DIR}/unicode-data.cpp
${RNLLAMA_LIB_DIR}/unicode.cpp
25 changes: 8 additions & 17 deletions android/src/main/jni.cpp
@@ -19,7 +19,7 @@

#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)

#define LOGE(...) __android_log_print(ANDROID_LOG_ERROR, TAG, __VA_ARGS__)
static inline int min(int a, int b) {
return (a < b) ? a : b;
}
@@ -198,7 +198,7 @@ Java_com_rnllama_LlamaContext_modelInfo(
continue;
}

const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
const std::string value = lm_gguf_kv_to_str(ctx, i);
putString(env, info, key, value.c_str());
}
}
@@ -336,13 +336,13 @@ Java_com_rnllama_LlamaContext_initContext(
llama_free(llama->ctx);
}

std::vector<common_lora_adapter_info> lora_adapters;
std::vector<common_lora_adapter_info> lora;
const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
if (lora_chars != nullptr && lora_chars[0] != '\0') {
common_lora_adapter_info la;
la.path = lora_chars;
la.scale = lora_scaled;
lora_adapters.push_back(la);
lora.push_back(la);
}

if (lora_list != nullptr) {
@@ -356,13 +356,13 @@
common_lora_adapter_info la;
la.path = path_chars;
la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
lora_adapters.push_back(la);
lora.push_back(la);
env->ReleaseStringUTFChars(path, path_chars);
}
}
}
env->ReleaseStringUTFChars(lora_str, lora_chars);
int result = llama->applyLoraAdapters(lora_adapters);
int result = llama->applyLoraAdapters(lora);
if (result != 0) {
LOGI("[RNLlama] Failed to apply lora adapters");
llama_free(llama->ctx);
@@ -946,7 +946,7 @@ Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
auto llama = context_map[(long) context_ptr];
auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
auto result = createWritableArray(env);
for (common_lora_adapter_container &la : loaded_lora_adapters) {
for (common_lora_adapter_info &la : loaded_lora_adapters) {
auto map = createWriteableMap(env);
putString(env, map, "path", la.path.c_str());
putDouble(env, map, "scaled", la.scale);
@@ -961,17 +961,8 @@ Java_com_rnllama_LlamaContext_freeContext(
UNUSED(env);
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];
if (llama->model) {
llama_free_model(llama->model);
}
if (llama->ctx) {
llama_free(llama->ctx);
}
if (llama->ctx_sampling != nullptr)
{
common_sampler_free(llama->ctx_sampling);
}
context_map.erase((long) llama->ctx);
delete llama;
}

JNIEXPORT void JNICALL
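The freeContext hunk above drops the manual llama_free_model / llama_free / common_sampler_free calls in favor of a single delete, which only works if the context object now frees those resources in its destructor. A minimal sketch of that assumption, with hypothetical type and member names (the real rnllama class carries more state):

#include "common.h"   // llama.h pulled in transitively
#include "sampling.h" // assumed header for common_sampler_free

// Hypothetical RAII wrapper that `delete llama` relies on.
struct llama_rn_context_sketch {
    llama_model    * model        = nullptr;
    llama_context  * ctx          = nullptr;
    common_sampler * ctx_sampling = nullptr;

    ~llama_rn_context_sketch() {
        if (ctx_sampling != nullptr) {
            common_sampler_free(ctx_sampling); // sampler references ctx, free it first
        }
        if (ctx != nullptr) {
            llama_free(ctx);                   // context before model
        }
        if (model != nullptr) {
            llama_model_free(model);           // renamed from llama_free_model in this sync
        }
    }
};
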
66 changes: 42 additions & 24 deletions cpp/common.cpp
@@ -2,6 +2,9 @@
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif

#include "ggml.h"
#include "gguf.h"

#include "common.h"
#include "log.h"
#include "llama.h"
@@ -14,6 +17,7 @@
#include <cstdarg>
#include <cstring>
#include <ctime>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
@@ -64,7 +68,9 @@ char const *LLAMA_BUILD_TARGET = "unknown";
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
#define PATH_MAX MAX_PATH
# if !defined(PATH_MAX)
# define PATH_MAX MAX_PATH
# endif
#else
#include <sys/syslimits.h>
#endif
@@ -843,7 +849,7 @@ struct common_init_result common_init_from_params(common_params & params) {
} else if (!params.model_url.empty()) {
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
} else {
model = llama_load_model_from_file(params.model.c_str(), mparams);
model = llama_model_load_from_file(params.model.c_str(), mparams);
}

if (model == NULL) {
@@ -870,7 +876,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}

if (!ok) {
llama_free_model(model);
llama_model_free(model);

return iparams;
}
@@ -881,14 +887,13 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_free_model(model);
llama_model_free(model);
return iparams;
}

if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
llama_free_model(model);
return iparams;
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}

if (!params.control_vectors.empty()) {
@@ -898,7 +903,7 @@
const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_free_model(model);
llama_model_free(model);

return iparams;
}
@@ -911,28 +916,29 @@
params.control_vector_layer_end);
if (err) {
llama_free(lctx);
llama_free_model(model);
llama_model_free(model);

return iparams;
}
}

// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
common_lora_adapter_container loaded_la;
loaded_la.path = la.path;
loaded_la.scale = la.scale;
loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
if (loaded_la.adapter == nullptr) {
llama_lora_adapter_ptr lora;
lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_free_model(model);
llama_model_free(model);
return iparams;
}
iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters

la.ptr = lora.get();
iparams.lora.emplace_back(std::move(lora)); // move into list of loaded adapters
}

if (!params.lora_init_without_apply) {
common_lora_adapters_apply(lctx, iparams.lora_adapters);
common_lora_adapters_apply(lctx, params.lora_adapters);
}

if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
@@ -979,7 +985,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = bos;
}
tmp.clear();
@@ -993,17 +999,17 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_perf_context_reset(lctx);
}

iparams.model = model;
iparams.context = lctx;
iparams.model.reset(model);
iparams.context.reset(lctx);

return iparams;
}

void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters) {
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
llama_lora_adapter_clear(ctx);
for (auto & la : lora_adapters) {
for (auto & la : lora) {
if (la.scale != 0.0f) {
llama_lora_adapter_set(ctx, la.adapter, la.scale);
llama_lora_adapter_set(ctx, la.ptr, la.scale);
}
}
}
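
Taken together, the hunks above change LoRA ownership: common_init_from_params now keeps each loaded adapter alive in iparams.lora as a llama_lora_adapter_ptr, while the caller's params.lora_adapters entries carry only a non-owning ptr plus its scale, so adapters can be re-applied (for example at a new scale) without reloading the file. A usage sketch assembled from the names in this diff; the paths are illustrative:

#include "common.h"

static void lora_example() {
    common_params params;
    params.model = "model.gguf";   // illustrative path
    common_lora_adapter_info la;
    la.path  = "adapter.gguf";     // illustrative path
    la.scale = 1.0f;
    params.lora_adapters.push_back(la);

    common_init_result init = common_init_from_params(params);
    if (init.context == nullptr) {
        return;
    }

    // init.lora owns the adapter; params.lora_adapters[0].ptr now points at it,
    // so the scale can be changed and re-applied without reloading.
    params.lora_adapters[0].scale = 0.5f;
    common_lora_adapters_apply(init.context.get(), params.lora_adapters);
}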
@@ -1201,7 +1207,7 @@ struct llama_model * common_load_model_from_url(
}
}

return llama_load_model_from_file(local_path.c_str(), params);
return llama_model_load_from_file(local_path.c_str(), params);
}

struct llama_model * common_load_model_from_hf(
@@ -1404,6 +1410,18 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
// Chat template utils
//

std::string common_get_builtin_chat_template(const struct llama_model * model) {
static const char * template_key = "tokenizer.chat_template";
// call with NULL buffer to get the total size of the string
int32_t res = llama_model_meta_val_str(model, template_key, NULL, 0);
if (res > 0) {
std::vector<char> model_template(res + 1, 0);
llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size());
return std::string(model_template.data(), model_template.size() - 1);
}
return "";
}

bool common_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
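The new common_get_builtin_chat_template above uses the usual two-pass GGUF metadata read: a first llama_model_meta_val_str call with a NULL buffer returns the value's length, and a second call fills a buffer of that size. A small caller sketch; the fallback to params.chat_template is an assumption for illustration, not part of this diff:

// Prefer the template embedded in the model's GGUF metadata; otherwise fall
// back to a user-supplied one (empty string means llama.cpp's default).
std::string tmpl = common_get_builtin_chat_template(model);
if (tmpl.empty()) {
    tmpl = params.chat_template; // assumed common_params field
}
if (!tmpl.empty() && !common_chat_verify_template(tmpl)) {
    LOG_ERR("unsupported chat template, falling back to default\n");
    tmpl.clear();
}
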
29 changes: 18 additions & 11 deletions cpp/common.h
@@ -2,7 +2,7 @@

#pragma once

#include "llama.h"
#include "llama-cpp.h"

#include <string>
#include <vector>
@@ -27,10 +27,8 @@
struct common_lora_adapter_info {
std::string path;
float scale;
};

struct common_lora_adapter_container : common_lora_adapter_info {
struct llama_lora_adapter * adapter;
struct llama_lora_adapter * ptr;
};

using llama_tokens = std::vector<llama_token>;
@@ -493,10 +491,12 @@ std::string fs_get_cache_file(const std::string & filename);
// Model utils
//

// note: defines object's lifetime
struct common_init_result {
struct llama_model * model = nullptr;
struct llama_context * context = nullptr;
std::vector<common_lora_adapter_container> lora_adapters;
llama_model_ptr model;
llama_context_ptr context;

std::vector<llama_lora_adapter_ptr> lora;
};

struct common_init_result common_init_from_params(common_params & params);
@@ -518,7 +518,7 @@ struct llama_model * common_load_model_from_hf(
const struct llama_model_params & params);

// clear LoRA adapters from context, then apply new list of adapters
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);

//
// Batch utils
@@ -586,6 +586,9 @@ struct common_chat_msg {
std::string content;
};

// Get the built-in chat template for the model. Return empty string if not present.
std::string common_get_builtin_chat_template(const struct llama_model * model);

// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
bool common_chat_verify_template(const std::string & tmpl);

@@ -652,6 +655,10 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
// Split utils
//

static const char * const LLM_KV_SPLIT_NO = "split.no";
static const char * const LLM_KV_SPLIT_COUNT = "split.count";
static const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
namespace {

const char * const LLM_KV_SPLIT_NO = "split.no";
const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";

}
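
The common.h side of the change makes common_init_result own its resources: llama_model_ptr and llama_context_ptr are std::unique_ptr aliases with custom deleters, presumably supplied by the newly included llama-cpp.h, and the "defines object's lifetime" note means everything is released when the result goes out of scope. A sketch of the resulting call pattern, under those assumptions:

#include "common.h"

static void init_example() {
    common_params params;
    params.model = "model.gguf"; // illustrative path

    common_init_result init = common_init_from_params(params);
    if (init.model == nullptr || init.context == nullptr) {
        return; // load failed; nothing to clean up manually
    }

    llama_context * ctx = init.context.get();
    (void) ctx; // ... run inference ...

    // On return, members are destroyed in reverse declaration order:
    // the lora adapters first, then the context, then the model.
}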