diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0f52ece..8c1f91f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,7 +61,21 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(LLAMA_STATIC ON CACHE BOOL "Build llama as static library")
 add_subdirectory("src/llama.cpp")

-file(GLOB SOURCE_FILES "src/addons.cpp")
+file(
+  GLOB SOURCE_FILES
+  "src/addons.cc"
+  "src/common.hpp"
+  "src/DisposeWorker.cpp"
+  "src/DisposeWorker.h"
+  "src/LlamaCompletionWorker.cpp"
+  "src/LlamaCompletionWorker.h"
+  "src/LlamaContext.cpp"
+  "src/LlamaContext.h"
+  "src/LoadSessionWorker.cpp"
+  "src/LoadSessionWorker.h"
+  "src/SaveSessionWorker.cpp"
+  "src/SaveSessionWorker.h"
+)

 add_library(${PROJECT_NAME} SHARED ${SOURCE_FILES} ${CMAKE_JS_SRC})
 set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" SUFFIX ".node")
diff --git a/src/DisposeWorker.cpp b/src/DisposeWorker.cpp
new file mode 100644
index 0000000..2bfabe5
--- /dev/null
+++ b/src/DisposeWorker.cpp
@@ -0,0 +1,11 @@
+#include "DisposeWorker.h"
+
+DisposeWorker::DisposeWorker(const Napi::CallbackInfo &info,
+                             LlamaSessionPtr sess)
+    : AsyncWorker(info.Env()), Deferred(info.Env()), sess_(std::move(sess)) {}
+
+void DisposeWorker::Execute() { sess_->dispose(); }
+
+void DisposeWorker::OnOK() { Resolve(AsyncWorker::Env().Undefined()); }
+
+void DisposeWorker::OnError(const Napi::Error &err) { Reject(err.Value()); }
diff --git a/src/DisposeWorker.h b/src/DisposeWorker.h
new file mode 100644
index 0000000..48adb2c
--- /dev/null
+++ b/src/DisposeWorker.h
@@ -0,0 +1,14 @@
+#include "common.hpp"
+
+class DisposeWorker : public Napi::AsyncWorker, public Napi::Promise::Deferred {
+public:
+  DisposeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr sess);
+
+protected:
+  void Execute();
+  void OnOK();
+  void OnError(const Napi::Error &err);
+
+private:
+  LlamaSessionPtr sess_;
+};
diff --git a/src/LlamaCompletionWorker.cpp b/src/LlamaCompletionWorker.cpp
new file mode 100644
index 0000000..9895f45
--- /dev/null
+++ b/src/LlamaCompletionWorker.cpp
@@ -0,0 +1,163 @@
+#include "LlamaCompletionWorker.h"
+#include "LlamaContext.h"
+
+size_t common_part(const std::vector<llama_token> &a,
+                   const std::vector<llama_token> &b) {
+  size_t i = 0;
+  while (i < a.size() && i < b.size() && a[i] == b[i]) {
+    i++;
+  }
+  return i;
+}
+
+size_t findStoppingStrings(const std::string &text,
+                           const size_t last_token_size,
+                           const std::vector<std::string> &stop_words) {
+  size_t stop_pos = std::string::npos;
+
+  for (const std::string &word : stop_words) {
+    size_t pos;
+
+    const size_t tmp = word.size() + last_token_size;
+    const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
+
+    pos = text.find(word, from_pos);
+
+    if (pos != std::string::npos &&
+        (stop_pos == std::string::npos || pos < stop_pos)) {
+      stop_pos = pos;
+    }
+  }
+
+  return stop_pos;
+}
+
+LlamaCompletionWorker::LlamaCompletionWorker(
+    const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
+    Napi::Function callback, gpt_params params,
+    std::vector<std::string> stop_words)
+    : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
+      _params(params), _stop_words(stop_words) {
+  if (!callback.IsEmpty()) {
+    _tsfn = Napi::ThreadSafeFunction::New(info.Env(), callback,
+                                          "LlamaCompletionCallback", 0, 1);
+    _has_callback = true;
+  }
+}
+
+LlamaCompletionWorker::~LlamaCompletionWorker() {
+  if (_has_callback) {
+    _tsfn.Release();
+  }
+}
+
+void LlamaCompletionWorker::Execute() {
+  _sess->get_mutex().lock();
+  const auto t_main_start = ggml_time_us();
+  const size_t n_ctx = _params.n_ctx;
+  const auto n_keep = _params.n_keep;
+  size_t n_cur = 0;
+  size_t n_input = 0;
+  const auto model = llama_get_model(_sess->context());
+  const bool add_bos = llama_should_add_bos_token(model);
+  auto ctx = _sess->context();
+
+  llama_set_rng_seed(ctx, _params.seed);
+
+  LlamaCppSampling sampling{llama_sampling_init(_params.sparams),
+                            llama_sampling_free};
+
+  std::vector<llama_token> prompt_tokens =
+      ::llama_tokenize(ctx, _params.prompt, add_bos);
+  n_input = prompt_tokens.size();
+  if (_sess->tokens_ptr()->size() > 0) {
+    n_cur = common_part(*(_sess->tokens_ptr()), prompt_tokens);
+    if (n_cur == n_input) {
+      --n_cur;
+    }
+    n_input -= n_cur;
+    llama_kv_cache_seq_rm(ctx, 0, n_cur, -1);
+  }
+  _sess->set_tokens(std::move(prompt_tokens));
+
+  const int max_len = _params.n_predict < 0 ? 0 : _params.n_predict;
+  _sess->tokens_ptr()->reserve(_sess->tokens_ptr()->size() + max_len);
+
+  auto embd = _sess->tokens_ptr();
+  for (int i = 0; i < max_len || _stop; i++) {
+    // check if we need to remove some tokens
+    if (embd->size() >= _params.n_ctx) {
+      const int n_left = n_cur - n_keep - 1;
+      const int n_discard = n_left / 2;
+
+      llama_kv_cache_seq_rm(ctx, 0, n_keep + 1, n_keep + n_discard + 1);
+      llama_kv_cache_seq_add(ctx, 0, n_keep + 1 + n_discard, n_cur, -n_discard);
+
+      // shift the tokens
+      embd->insert(embd->begin() + n_keep + 1,
+                   embd->begin() + n_keep + 1 + n_discard, embd->end());
+      embd->resize(embd->size() - n_discard);
+
+      n_cur -= n_discard;
+      _result.truncated = true;
+    }
+    int ret = llama_decode(
+        ctx, llama_batch_get_one(embd->data() + n_cur, n_input, n_cur, 0));
+    if (ret < 0) {
+      SetError("Failed to decode token, code: " + std::to_string(ret));
+      break;
+    }
+    // sample the next token
+    const llama_token new_token_id =
+        llama_sampling_sample(sampling.get(), ctx, nullptr);
+    // prepare the next batch
+    embd->emplace_back(new_token_id);
+    auto token = llama_token_to_piece(ctx, new_token_id);
+    _result.text += token;
+    n_cur += n_input;
+    _result.tokens_evaluated += n_input;
+    _result.tokens_predicted += 1;
+    n_input = 1;
+    if (_has_callback) {
+      const char *c_token = strdup(token.c_str());
+      _tsfn.BlockingCall(c_token, [](Napi::Env env, Napi::Function jsCallback,
+                                     const char *value) {
+        auto obj = Napi::Object::New(env);
+        obj.Set("token", Napi::String::New(env, value));
+        delete value;
+        jsCallback.Call({obj});
+      });
+    }
+    // is it an end of generation?
+    if (llama_token_is_eog(model, new_token_id)) {
+      break;
+    }
+    // check for stop words
+    if (!_stop_words.empty()) {
+      const size_t stop_pos =
+          findStoppingStrings(_result.text, token.size(), _stop_words);
+      if (stop_pos != std::string::npos) {
+        break;
+      }
+    }
+  }
+  const auto t_main_end = ggml_time_us();
+  _sess->get_mutex().unlock();
+}
+
+void LlamaCompletionWorker::OnOK() {
+  auto result = Napi::Object::New(Napi::AsyncWorker::Env());
+  result.Set("tokens_evaluated", Napi::Number::New(Napi::AsyncWorker::Env(),
+                                                   _result.tokens_evaluated));
+  result.Set("tokens_predicted", Napi::Number::New(Napi::AsyncWorker::Env(),
+                                                   _result.tokens_predicted));
+  result.Set("truncated",
+             Napi::Boolean::New(Napi::AsyncWorker::Env(), _result.truncated));
+  result.Set("text",
+             Napi::String::New(Napi::AsyncWorker::Env(), _result.text.c_str()));
+  Napi::Promise::Deferred::Resolve(result);
+}
+
+void LlamaCompletionWorker::OnError(const Napi::Error &err) {
+  Napi::Promise::Deferred::Reject(err.Value());
+}
diff --git a/src/LlamaCompletionWorker.h b/src/LlamaCompletionWorker.h
new file mode 100644
index 0000000..3ca7377
--- /dev/null
+++ b/src/LlamaCompletionWorker.h
@@ -0,0 +1,34 @@
+#include "common.hpp"
+
+struct CompletionResult {
+  std::string text = "";
+  bool truncated = false;
+  size_t tokens_predicted = 0;
+  size_t tokens_evaluated = 0;
+};
+
+class LlamaCompletionWorker : public Napi::AsyncWorker,
+                              public Napi::Promise::Deferred {
+public:
+  LlamaCompletionWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
+                        Napi::Function callback, gpt_params params,
+                        std::vector<std::string> stop_words = {});
+
+  ~LlamaCompletionWorker();
+
+  inline void Stop() { _stop = true; }
+
+protected:
+  void Execute();
+  void OnOK();
+  void OnError(const Napi::Error &err);
+
+private:
+  LlamaSessionPtr _sess;
+  gpt_params _params;
+  std::vector<std::string> _stop_words;
+  Napi::ThreadSafeFunction _tsfn;
+  bool _has_callback = false;
+  bool _stop = false;
+  CompletionResult _result;
+};
diff --git a/src/LlamaContext.cpp b/src/LlamaContext.cpp
new file mode 100644
index 0000000..d905de0
--- /dev/null
+++ b/src/LlamaContext.cpp
@@ -0,0 +1,200 @@
+#include "LlamaContext.h"
+#include "DisposeWorker.h"
+#include "LlamaCompletionWorker.h"
+#include "LoadSessionWorker.h"
+#include "SaveSessionWorker.h"
+
+void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
+  Napi::Function func = DefineClass(
+      env, "LlamaContext",
+      {InstanceMethod<&LlamaContext::GetSystemInfo>(
+           "getSystemInfo",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::Completion>(
+           "completion",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::StopCompletion>(
+           "stopCompletion",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::SaveSession>(
+           "saveSession",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::LoadSession>(
+           "loadSession",
+           static_cast<napi_property_attributes>(napi_enumerable)),
+       InstanceMethod<&LlamaContext::Release>(
+           "release",
+           static_cast<napi_property_attributes>(napi_enumerable))});
+  Napi::FunctionReference *constructor = new Napi::FunctionReference();
+  *constructor = Napi::Persistent(func);
+#if NAPI_VERSION > 5
+  env.SetInstanceData(constructor);
+#endif
+  exports.Set("LlamaContext", func);
+}
+
+// construct({ model, embedding, n_ctx, n_batch, n_threads, n_gpu_layers,
+// use_mlock, use_mmap }): LlamaContext throws error
+LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
+    : Napi::ObjectWrap<LlamaContext>(info) {
+  Napi::Env env = info.Env();
+  if (info.Length() < 1 || !info[0].IsObject()) {
+    Napi::TypeError::New(env, "Object expected").ThrowAsJavaScriptException();
+  }
+  auto options = info[0].As<Napi::Object>();
+
+  gpt_params params;
+  params.model = get_option<std::string>(options, "model", "");
+  if (params.model.empty()) {
+    Napi::TypeError::New(env, "Model is required").ThrowAsJavaScriptException();
+  }
+  params.embedding = get_option<bool>(options, "embedding", false);
+  params.n_ctx = get_option<int32_t>(options, "n_ctx", 512);
+  params.n_batch = get_option<int32_t>(options, "n_batch", 2048);
+  params.n_threads =
+      get_option<int32_t>(options, "n_threads", get_math_cpu_count() / 2);
+  params.n_gpu_layers = get_option<int32_t>(options, "n_gpu_layers", -1);
+  params.use_mlock = get_option<bool>(options, "use_mlock", false);
+  params.use_mmap = get_option<bool>(options, "use_mmap", true);
+  params.numa =
+      static_cast<ggml_numa_strategy>(get_option<uint32_t>(options, "numa", 0));
+
+  llama_backend_init();
+  llama_numa_init(params.numa);
+
+  llama_model *model;
+  llama_context *ctx;
+  std::tie(model, ctx) = llama_init_from_gpt_params(params);
+
+  if (model == nullptr || ctx == nullptr) {
+    Napi::TypeError::New(env, "Failed to load model")
+        .ThrowAsJavaScriptException();
+  }
+
+  _sess = std::make_shared<LlamaSession>(ctx, params);
+  _info = get_system_info(params);
+}
+
+// getSystemInfo(): string
+Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
+  return Napi::String::New(info.Env(), _info);
+}
+
+// completion(options: LlamaCompletionOptions, onToken?: (token: string) =>
+// void): Promise
+Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
+  Napi::Env env = info.Env();
+  if (info.Length() < 1 || !info[0].IsObject()) {
+    Napi::TypeError::New(env, "Object expected").ThrowAsJavaScriptException();
+  }
+  if (info.Length() >= 2 && !info[1].IsFunction()) {
+    Napi::TypeError::New(env, "Function expected").ThrowAsJavaScriptException();
+  }
+  if (_sess == nullptr) {
+    Napi::TypeError::New(env, "Context is disposed")
+        .ThrowAsJavaScriptException();
+  }
+  auto options = info[0].As<Napi::Object>();
+
+  gpt_params params = _sess->params();
+  params.prompt = get_option<std::string>(options, "prompt", "");
+  if (params.prompt.empty()) {
+    Napi::TypeError::New(env, "Prompt is required")
+        .ThrowAsJavaScriptException();
+  }
+  params.n_predict = get_option<int32_t>(options, "n_predict", -1);
+  params.sparams.temp = get_option<float>(options, "temperature", 0.80f);
+  params.sparams.top_k = get_option<int32_t>(options, "top_k", 40);
+  params.sparams.top_p = get_option<float>(options, "top_p", 0.95f);
+  params.sparams.min_p = get_option<float>(options, "min_p", 0.05f);
+  params.sparams.tfs_z = get_option<float>(options, "tfs_z", 1.00f);
+  params.sparams.mirostat = get_option<int32_t>(options, "mirostat", 0.00f);
+  params.sparams.mirostat_tau =
+      get_option<float>(options, "mirostat_tau", 5.00f);
+  params.sparams.mirostat_eta =
+      get_option<float>(options, "mirostat_eta", 0.10f);
+  params.sparams.penalty_last_n =
+      get_option<int32_t>(options, "penalty_last_n", 64);
+  params.sparams.penalty_repeat =
+      get_option<float>(options, "penalty_repeat", 1.00f);
+  params.sparams.penalty_freq =
+      get_option<float>(options, "penalty_freq", 0.00f);
+  params.sparams.penalty_present =
+      get_option<float>(options, "penalty_present", 0.00f);
+  params.sparams.penalize_nl = get_option<bool>(options, "penalize_nl", false);
+  params.sparams.typical_p = get_option<float>(options, "typical_p", 1.00f);
+  params.ignore_eos = get_option<bool>(options, "ignore_eos", false);
+  params.sparams.grammar = get_option<std::string>(options, "grammar", "");
+  params.n_keep = get_option<int32_t>(options, "n_keep", 0);
+  params.seed = get_option<uint32_t>(options, "seed", LLAMA_DEFAULT_SEED);
+  std::vector<std::string> stop_words;
+  if (options.Has("stop") && options.Get("stop").IsArray()) {
+    auto stop_words_array = options.Get("stop").As<Napi::Array>();
+    for (size_t i = 0; i < stop_words_array.Length(); i++) {
+      stop_words.push_back(stop_words_array.Get(i).ToString().Utf8Value());
+    }
+  }
+
+  Napi::Function callback;
+  if (info.Length() >= 2) {
+    callback = info[1].As<Napi::Function>();
+  }
+
+  auto *worker =
+      new LlamaCompletionWorker(info, _sess, callback, params, stop_words);
+  worker->Queue();
+  _wip = worker;
+  return worker->Promise();
+}
+
+// stopCompletion(): void
+void LlamaContext::StopCompletion(const Napi::CallbackInfo &info) {
+  if (_wip != nullptr) {
+    _wip->Stop();
+  }
+}
+
+// saveSession(path: string): Promise throws error
+Napi::Value LlamaContext::SaveSession(const Napi::CallbackInfo &info) {
+  Napi::Env env = info.Env();
+  if (info.Length() < 1 || !info[0].IsString()) {
+    Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
+  }
+  if (_sess == nullptr) {
+    Napi::TypeError::New(env, "Context is disposed")
+        .ThrowAsJavaScriptException();
+  }
+  auto *worker = new SaveSessionWorker(info, _sess);
+  worker->Queue();
+  return worker->Promise();
+}
+
+// loadSession(path: string): Promise<{ count }> throws error
+Napi::Value LlamaContext::LoadSession(const Napi::CallbackInfo &info) {
+  Napi::Env env = info.Env();
+  if (info.Length() < 1 || !info[0].IsString()) {
+    Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
+  }
+  if (_sess == nullptr) {
+    Napi::TypeError::New(env, "Context is disposed")
+        .ThrowAsJavaScriptException();
+  }
+  auto *worker = new LoadSessionWorker(info, _sess);
+  worker->Queue();
+  return worker->Promise();
+}
+
+// release(): Promise
+Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
+  auto env = info.Env();
+  if (_wip != nullptr) {
+    _wip->Stop();
+  }
+  if (_sess == nullptr) {
+    auto promise = Napi::Promise::Deferred(env);
+    promise.Resolve(env.Undefined());
+    return promise.Promise();
+  }
+  auto *worker = new DisposeWorker(info, std::move(_sess));
+  worker->Queue();
+  return worker->Promise();
+}
diff --git a/src/LlamaContext.h b/src/LlamaContext.h
new file mode 100644
index 0000000..37323df
--- /dev/null
+++ b/src/LlamaContext.h
@@ -0,0 +1,21 @@
+#include "common.hpp"
+
+class LlamaCompletionWorker;
+
+class LlamaContext : public Napi::ObjectWrap<LlamaContext> {
+public:
+  LlamaContext(const Napi::CallbackInfo &info);
+  static void Init(Napi::Env env, Napi::Object &exports);
+
+private:
+  Napi::Value GetSystemInfo(const Napi::CallbackInfo &info);
+  Napi::Value Completion(const Napi::CallbackInfo &info);
+  void StopCompletion(const Napi::CallbackInfo &info);
+  Napi::Value SaveSession(const Napi::CallbackInfo &info);
+  Napi::Value LoadSession(const Napi::CallbackInfo &info);
+  Napi::Value Release(const Napi::CallbackInfo &info);
+
+  std::string _info;
+  LlamaSessionPtr _sess = nullptr;
+  LlamaCompletionWorker *_wip = nullptr;
+};
diff --git a/src/LoadSessionWorker.cpp b/src/LoadSessionWorker.cpp
new file mode 100644
index 0000000..70a0a97
--- /dev/null
+++ b/src/LoadSessionWorker.cpp
@@ -0,0 +1,24 @@
+#include "LoadSessionWorker.h"
+#include "LlamaContext.h"
+
+LoadSessionWorker::LoadSessionWorker(const Napi::CallbackInfo &info,
+                                     LlamaSessionPtr &sess)
+    : AsyncWorker(info.Env()), Deferred(info.Env()), _path(info[0].ToString()),
+      _sess(sess) {}
+
+void LoadSessionWorker::Execute() {
+  _sess->get_mutex().lock();
+  // reserve the maximum number of tokens for capacity
+  std::vector<llama_token> tokens;
+  tokens.reserve(_sess->params().n_ctx);
+  if (!llama_state_load_file(_sess->context(), _path.c_str(), tokens.data(),
+                             tokens.capacity(), &count)) {
+    SetError("Failed to load session");
+  }
+  _sess->set_tokens(std::move(tokens));
+  _sess->get_mutex().unlock();
+}
+
+void LoadSessionWorker::OnOK() { Resolve(AsyncWorker::Env().Undefined()); }
+
+void LoadSessionWorker::OnError(const Napi::Error &err) { Reject(err.Value()); }
diff --git a/src/LoadSessionWorker.h b/src/LoadSessionWorker.h
new file mode 100644
index 0000000..4acbbf4
--- /dev/null
+++ b/src/LoadSessionWorker.h
@@ -0,0 +1,17 @@
+#include "common.hpp"
+
+class LoadSessionWorker : public Napi::AsyncWorker,
+                          public Napi::Promise::Deferred {
+public:
+  LoadSessionWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess);
+
+protected:
+  void Execute();
+  void OnOK();
+  void OnError(const Napi::Error &err);
+
+private:
+  std::string _path;
+  LlamaSessionPtr _sess;
+  size_t count = 0;
+};
diff --git a/src/SaveSessionWorker.cpp b/src/SaveSessionWorker.cpp
new file mode 100644
index 0000000..6a8a890
--- /dev/null
+++ b/src/SaveSessionWorker.cpp
@@ -0,0 +1,21 @@
+#include "SaveSessionWorker.h"
+#include "LlamaContext.h"
+
+SaveSessionWorker::SaveSessionWorker(const Napi::CallbackInfo &info,
+                                     LlamaSessionPtr &sess)
+    : AsyncWorker(info.Env()), Deferred(info.Env()), _path(info[0].ToString()),
+      _sess(sess) {}
+
+void SaveSessionWorker::Execute() {
+  _sess->get_mutex().lock();
+  auto tokens = _sess->tokens_ptr();
+  if (!llama_state_save_file(_sess->context(), _path.c_str(), tokens->data(),
+                             tokens->size())) {
+    SetError("Failed to save session");
+  }
+  _sess->get_mutex().unlock();
+}
+
+void SaveSessionWorker::OnOK() { Resolve(AsyncWorker::Env().Undefined()); }
+
+void SaveSessionWorker::OnError(const Napi::Error &err) { Reject(err.Value()); }
diff --git a/src/SaveSessionWorker.h b/src/SaveSessionWorker.h
new file mode 100644
index 0000000..cb41205
--- /dev/null
+++ b/src/SaveSessionWorker.h
@@ -0,0 +1,16 @@
+#include "common.hpp"
+
+class SaveSessionWorker : public Napi::AsyncWorker,
+                          public Napi::Promise::Deferred {
+public:
+  SaveSessionWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess);
+
+protected:
+  void Execute();
+  void OnOK();
+  void OnError(const Napi::Error &err);
+
+private:
+  std::string _path;
+  LlamaSessionPtr _sess;
+};
diff --git a/src/addons.cc b/src/addons.cc
new file mode 100644
index 0000000..270040a
--- /dev/null
+++ b/src/addons.cc
@@ -0,0 +1,9 @@
+#include "LlamaContext.h"
+#include <napi.h>
+
+Napi::Object Init(Napi::Env env, Napi::Object exports) {
+  LlamaContext::Init(env, exports);
+  return exports;
+}
+
+NODE_API_MODULE(addons, Init)
diff --git a/src/addons.cpp b/src/addons.cpp
deleted file mode 100644
index d2515f5..0000000
--- a/src/addons.cpp
+++ /dev/null
@@ -1,537 +0,0 @@
-#include "common/common.h"
-#include "llama.h"
-#include <memory>
-#include <mutex>
-#include <napi.h>
-#include <string>
-#include <thread>
-#include <tuple>
-#include <vector>
-
-typedef std::unique_ptr<llama_model, decltype(&llama_free_model)> LlamaCppModel;
-typedef std::unique_ptr<llama_context, decltype(&llama_free)> LlamaCppContext;
-typedef std::unique_ptr<llama_sampling_context, decltype(&llama_sampling_free)>
-    LlamaCppSampling;
-typedef std::unique_ptr<llama_batch, decltype(&llama_batch_free)> LlamaCppBatch;
-
-size_t common_part(const std::vector<llama_token> &a,
-                   const std::vector<llama_token> &b) {
-  size_t i = 0;
-  while (i < a.size() && i < b.size() && a[i] == b[i]) {
-    i++;
-  }
-  return i;
-}
-
-template <typename T>
-constexpr T get_option(const Napi::Object &options, const std::string &name,
-                       const T default_value) {
-  if (options.Has(name) && !options.Get(name).IsUndefined() &&
-      !options.Get(name).IsNull()) {
-    if constexpr (std::is_same<T, std::string>::value) {
-      return options.Get(name).ToString().operator T();
-    } else if constexpr (std::is_same<T, int32_t>::value ||
-                         std::is_same<T, uint32_t>::value ||
-                         std::is_same<T, float>::value ||
-                         std::is_same<T, double>::value) {
-      return options.Get(name).ToNumber().operator T();
-    } else if constexpr (std::is_same<T, bool>::value) {
-      return options.Get(name).ToBoolean().operator T();
-    } else {
-      static_assert(std::is_same<T, std::string>::value ||
-                        std::is_same<T, int32_t>::value ||
-                        std::is_same<T, uint32_t>::value ||
-                        std::is_same<T, float>::value ||
-                        std::is_same<T, double>::value ||
-                        std::is_same<T, bool>::value,
-                    "Unsupported type");
-    }
-  } else {
-    return default_value;
-  }
-}
-
-class LlamaCompletionWorker;
-
-class LlamaContext : public Napi::ObjectWrap<LlamaContext> {
-public:
-  // construct({ model, embedding, n_ctx, n_batch, n_threads, n_gpu_layers,
-  // use_mlock, use_mmap }): LlamaContext throws error
-  LlamaContext(const Napi::CallbackInfo &info)
-      : Napi::ObjectWrap<LlamaContext>(info) {
-    Napi::Env env = info.Env();
-    if (info.Length() < 1 || !info[0].IsObject()) {
-      Napi::TypeError::New(env, "Object expected").ThrowAsJavaScriptException();
-    }
-    auto options = info[0].As<Napi::Object>();
-
-    params.model = get_option<std::string>(options, "model", "");
-    if (params.model.empty()) {
-      Napi::TypeError::New(env, "Model is required")
-          .ThrowAsJavaScriptException();
-    }
-    params.embedding = get_option<bool>(options, "embedding", false);
-    params.n_ctx = get_option<int32_t>(options, "n_ctx", 512);
-    params.n_batch = get_option<int32_t>(options, "n_batch", 2048);
-    params.n_threads =
-        get_option<int32_t>(options, "n_threads", get_math_cpu_count() / 2);
-    params.n_gpu_layers = get_option<int32_t>(options, "n_gpu_layers", -1);
-    params.use_mlock = get_option<bool>(options, "use_mlock", false);
-    params.use_mmap = get_option<bool>(options, "use_mmap", true);
-    params.numa = static_cast<ggml_numa_strategy>(
-        get_option<uint32_t>(options, "numa", 0));
-
-    llama_backend_init();
-    llama_numa_init(params.numa);
-
-    auto tuple = llama_init_from_gpt_params(params);
-    model.reset(std::get<0>(tuple));
-    ctx.reset(std::get<1>(tuple));
-
-    if (model == nullptr || ctx == nullptr) {
-      Napi::TypeError::New(env, "Failed to load model")
-          .ThrowAsJavaScriptException();
-    }
-  }
-
-  static void Export(Napi::Env env, Napi::Object &exports) {
-    Napi::Function func = DefineClass(
-        env, "LlamaContext",
-        {InstanceMethod<&LlamaContext::GetSystemInfo>(
-             "getSystemInfo",
-             static_cast<napi_property_attributes>(napi_enumerable)),
-         InstanceMethod<&LlamaContext::Completion>(
-             "completion",
-             static_cast<napi_property_attributes>(napi_enumerable)),
-         InstanceMethod<&LlamaContext::StopCompletion>(
-             "stopCompletion",
-             static_cast<napi_property_attributes>(napi_enumerable)),
-         InstanceMethod<&LlamaContext::SaveSession>(
-             "saveSession",
-             static_cast<napi_property_attributes>(napi_enumerable)),
-         InstanceMethod<&LlamaContext::LoadSession>(
-             "loadSession",
-             static_cast<napi_property_attributes>(napi_enumerable)),
-         InstanceMethod<&LlamaContext::Release>(
-             "release",
-             static_cast<napi_property_attributes>(napi_enumerable))});
-    Napi::FunctionReference *constructor = new Napi::FunctionReference();
-    *constructor = Napi::Persistent(func);
-#if NAPI_VERSION > 5
-    env.SetInstanceData(constructor);
-#endif
-    exports.Set("LlamaContext", func);
-  }
-
-  llama_context *getContext() { return ctx.get(); }
-  llama_model *getModel() { return model.get(); }
-
-  std::vector<llama_token> *getTokens() { return tokens.get(); }
-
-  const gpt_params &getParams() const { return params; }
-
-  void ensureTokens() {
-    if (tokens == nullptr) {
-      tokens = std::make_unique<std::vector<llama_token>>();
-    }
-  }
-
-  void setTokens(std::vector<llama_token> tokens) {
-    this->tokens.reset(new std::vector<llama_token>(std::move(tokens)));
-  }
-
-  std::mutex &getMutex() { return mutex; }
-
-  void Dispose() {
-    std::lock_guard<std::mutex> lock(mutex);
-    compl_worker = nullptr;
-    ctx.reset();
-    tokens.reset();
-    model.reset();
-  }
-
-private:
-  Napi::Value GetSystemInfo(const Napi::CallbackInfo &info);
-  Napi::Value Completion(const Napi::CallbackInfo &info);
-  void StopCompletion(const Napi::CallbackInfo &info);
-  Napi::Value SaveSession(const Napi::CallbackInfo &info);
-  Napi::Value LoadSession(const Napi::CallbackInfo &info);
-  Napi::Value Release(const Napi::CallbackInfo &info);
-
-  gpt_params params;
-  LlamaCppModel model{nullptr, llama_free_model};
-  LlamaCppContext ctx{nullptr, llama_free};
-  std::unique_ptr<std::vector<llama_token>> tokens;
-  std::mutex mutex;
-  LlamaCompletionWorker *compl_worker = nullptr;
-};
-
-class LlamaCompletionWorker : public Napi::AsyncWorker,
-                              public Napi::Promise::Deferred {
-  LlamaContext *_ctx;
-  gpt_params _params;
-  std::vector<std::string> _stop_words;
-  std::string generated_text = "";
-  Napi::ThreadSafeFunction _tsfn;
-  bool _has_callback = false;
-  bool _stop = false;
-  size_t tokens_predicted = 0;
-  size_t tokens_evaluated = 0;
-  bool truncated = false;
-
-public:
-  LlamaCompletionWorker(const Napi::CallbackInfo &info, LlamaContext *ctx,
-                        Napi::Function callback, gpt_params params,
-                        std::vector<std::string> stop_words = {})
-      : AsyncWorker(info.Env()), Deferred(info.Env()), _ctx(ctx),
-        _params(params), _stop_words(stop_words) {
-    _ctx->Ref();
-    if (!callback.IsEmpty()) {
-      _tsfn = Napi::ThreadSafeFunction::New(info.Env(), callback,
-                                            "LlamaCompletionCallback", 0, 1);
-      _has_callback = true;
-    }
-  }
-
-  ~LlamaCompletionWorker() {
-    _ctx->Unref();
-    if (_has_callback) {
-      _tsfn.Abort();
-      _tsfn.Release();
-    }
-  }
-
-  void Stop() { _stop = true; }
-
-protected:
-  size_t findStoppingStrings(const std::string &text,
-                             const size_t last_token_size) {
-    size_t stop_pos = std::string::npos;
-
-    for (const std::string &word : _stop_words) {
-      size_t pos;
-
-      const size_t tmp = word.size() + last_token_size;
-      const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
-
-      pos = text.find(word, from_pos);
-
-      if (pos != std::string::npos &&
-          (stop_pos == std::string::npos || pos < stop_pos)) {
-        stop_pos = pos;
-      }
-    }
-
-    return stop_pos;
-  }
-
-  void Execute() {
-    _ctx->getMutex().lock();
-    _ctx->ensureTokens();
-    const auto t_main_start = ggml_time_us();
-    const size_t n_ctx = _params.n_ctx;
-    auto n_keep = _params.n_keep;
-    auto n_predict = _params.n_predict;
-    size_t n_cur = 0;
-    size_t n_input = 0;
-    const bool add_bos = llama_should_add_bos_token(_ctx->getModel());
-    auto *ctx = _ctx->getContext();
-
-    llama_set_rng_seed(ctx, _params.seed);
-
-    LlamaCppSampling sampling{llama_sampling_init(_params.sparams),
-                              llama_sampling_free};
-
-    std::vector<llama_token> prompt_tokens =
-        ::llama_tokenize(ctx, _params.prompt, add_bos);
-    n_input = prompt_tokens.size();
-    if (_ctx->getTokens() != nullptr) {
-      n_cur = common_part(*_ctx->getTokens(), prompt_tokens);
-      if (n_cur == n_input) {
-        --n_cur;
-      }
-      n_input -= n_cur;
-      llama_kv_cache_seq_rm(ctx, 0, n_cur, -1);
-    }
-    _ctx->setTokens(std::move(prompt_tokens));
-
-    const int max_len = _params.n_predict < 0 ? 0 : _params.n_predict;
-
-    for (int i = 0; i < max_len || _stop; i++) {
-      auto *embd = _ctx->getTokens();
-      // check if we need to remove some tokens
-      if (embd->size() >= n_ctx) {
-        const int n_left = n_cur - n_keep - 1;
-        const int n_discard = n_left / 2;
-
-        llama_kv_cache_seq_rm(ctx, 0, n_keep + 1, n_keep + n_discard + 1);
-        llama_kv_cache_seq_add(ctx, 0, n_keep + 1 + n_discard, n_cur,
-                               -n_discard);
-
-        for (size_t i = n_keep + 1 + n_discard; i < embd->size(); i++) {
-          (*embd)[i - n_discard] = (*embd)[i];
-        }
-        embd->resize(embd->size() - n_discard);
-
-        n_cur -= n_discard;
-        truncated = true;
-      }
-      int ret = llama_decode(
-          ctx, llama_batch_get_one(embd->data() + n_cur, n_input, n_cur, 0));
-      if (ret < 0) {
-        SetError("Failed to decode token, code: " + std::to_string(ret));
-        break;
-      }
-      // sample the next token
-      const llama_token new_token_id =
-          llama_sampling_sample(sampling.get(), ctx, nullptr);
-      // prepare the next batch
-      embd->push_back(new_token_id);
-      auto token = llama_token_to_piece(ctx, new_token_id);
-      generated_text += token;
-      n_cur += n_input;
-      tokens_evaluated += n_input;
-      tokens_predicted += 1;
-      n_input = 1;
-      if (_has_callback) {
-        const char *c_token = strdup(token.c_str());
-        _tsfn.BlockingCall(c_token, [](Napi::Env env, Napi::Function jsCallback,
-                                       const char *value) {
-          auto obj = Napi::Object::New(env);
-          obj.Set("token", Napi::String::New(env, value));
-          jsCallback.Call({obj});
-        });
-      }
-      // is it an end of generation?
-      if (llama_token_is_eog(_ctx->getModel(), new_token_id)) {
-        break;
-      }
-      // check for stop words
-      if (!_stop_words.empty()) {
-        const size_t stop_pos =
-            findStoppingStrings(generated_text, token.size());
-        if (stop_pos != std::string::npos) {
-          break;
-        }
-      }
-    }
-    const auto t_main_end = ggml_time_us();
-    _ctx->getMutex().unlock();
-  }
-
-  void OnOK() {
-    auto result = Napi::Object::New(Napi::AsyncWorker::Env());
-    result.Set("tokens_evaluated",
-               Napi::Number::New(Napi::AsyncWorker::Env(), tokens_evaluated));
-    result.Set("tokens_predicted",
-               Napi::Number::New(Napi::AsyncWorker::Env(), tokens_predicted));
-    result.Set("truncated",
-               Napi::Boolean::New(Napi::AsyncWorker::Env(), truncated));
-    result.Set("text",
-               Napi::String::New(Napi::AsyncWorker::Env(), generated_text));
-    Napi::Promise::Deferred::Resolve(result);
-  }
-
-  void OnError(const Napi::Error &err) {
-    Napi::Promise::Deferred::Reject(err.Value());
-  }
-};
-
-class SaveSessionWorker : public Napi::AsyncWorker,
-                          public Napi::Promise::Deferred {
-  std::string _path;
-  LlamaContext *_ctx;
-
-public:
-  SaveSessionWorker(const Napi::CallbackInfo &info, LlamaContext *ctx)
-      : AsyncWorker(info.Env()), Deferred(info.Env()),
-        _path(info[0].ToString()), _ctx(ctx) {
-    _ctx->Ref();
-  }
-
-protected:
-  void Execute() {
-    _ctx->getMutex().lock();
-    if (_ctx->getTokens() == nullptr) {
-      SetError("Failed to save session");
-      return;
-    }
-    if (!llama_state_save_file(_ctx->getContext(), _path.c_str(),
-                               _ctx->getTokens()->data(),
-                               _ctx->getTokens()->size())) {
-      SetError("Failed to save session");
-    }
-    _ctx->getMutex().unlock();
-  }
-
-  void OnOK() { Resolve(AsyncWorker::Env().Undefined()); }
-
-  void OnError(const Napi::Error &err) { Reject(err.Value()); }
-};
-
-class LoadSessionWorker : public Napi::AsyncWorker,
-                          public Napi::Promise::Deferred {
-  std::string _path;
-  LlamaContext *_ctx;
-  size_t count = 0;
-
-public:
-  LoadSessionWorker(const Napi::CallbackInfo &info, LlamaContext *ctx)
-      : AsyncWorker(info.Env()), Deferred(info.Env()),
-        _path(info[0].ToString()), _ctx(ctx) {
-    _ctx->Ref();
-  }
-
-protected:
-  void Execute() {
-    _ctx->getMutex().lock();
-    _ctx->ensureTokens();
-    // reserve the maximum number of tokens for capacity
-    _ctx->getTokens()->reserve(_ctx->getParams().n_ctx);
-    if (!llama_state_load_file(_ctx->getContext(), _path.c_str(),
-                               _ctx->getTokens()->data(),
-                               _ctx->getTokens()->capacity(), &count)) {
-      SetError("Failed to load session");
-    }
-    _ctx->getMutex().unlock();
-  }
-
-  void OnOK() { Resolve(AsyncWorker::Env().Undefined()); }
-
-  void OnError(const Napi::Error &err) { Reject(err.Value()); }
-};
-
-class DisposeWorker : public Napi::AsyncWorker, public Napi::Promise::Deferred {
-public:
-  DisposeWorker(Napi::Env env, LlamaContext *ctx)
-      : AsyncWorker(env), Deferred(env), ctx_(ctx) {
-    ctx_->Ref();
-  }
-
-  ~DisposeWorker() { ctx_->Unref(); }
-
-protected:
-  void Execute() override { ctx_->Dispose(); }
-
-  void OnOK() override { Resolve(AsyncWorker::Env().Undefined()); }
-
-  void OnError(const Napi::Error &err) override { Reject(err.Value()); }
-
-private:
-  LlamaContext *ctx_;
-};
-
-// getSystemInfo(): string
-Napi::Value LlamaContext::GetSystemInfo(const Napi::CallbackInfo &info) {
-  return Napi::String::New(info.Env(), get_system_info(params).c_str());
-}
-
-// completion(options: LlamaCompletionOptions, onToken?: (token: string) =>
-// void): Promise
-Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
-  Napi::Env env = info.Env();
-  if (info.Length() < 1 || !info[0].IsObject()) {
-    Napi::TypeError::New(env, "Object expected").ThrowAsJavaScriptException();
-  }
-  if (info.Length() >= 2 && !info[1].IsFunction()) {
-    Napi::TypeError::New(env, "Function expected").ThrowAsJavaScriptException();
-  }
-  auto options = info[0].As<Napi::Object>();
-
-  gpt_params params;
-  params.prompt = get_option<std::string>(options, "prompt", "");
-  if (params.prompt.empty()) {
-    Napi::TypeError::New(env, "Prompt is required")
-        .ThrowAsJavaScriptException();
-  }
-  params.n_predict = get_option<int32_t>(options, "n_predict", -1);
-  params.sparams.temp = get_option<float>(options, "temperature", 0.80f);
-  params.sparams.top_k = get_option<int32_t>(options, "top_k", 40);
-  params.sparams.top_p = get_option<float>(options, "top_p", 0.95f);
-  params.sparams.min_p = get_option<float>(options, "min_p", 0.05f);
-  params.sparams.tfs_z = get_option<float>(options, "tfs_z", 1.00f);
-  params.sparams.mirostat = get_option<int32_t>(options, "mirostat", 0.00f);
-  params.sparams.mirostat_tau =
-      get_option<float>(options, "mirostat_tau", 5.00f);
-  params.sparams.mirostat_eta =
-      get_option<float>(options, "mirostat_eta", 0.10f);
-  params.sparams.penalty_last_n =
-      get_option<int32_t>(options, "penalty_last_n", 64);
-  params.sparams.penalty_repeat =
-      get_option<float>(options, "penalty_repeat", 1.00f);
-  params.sparams.penalty_freq =
-      get_option<float>(options, "penalty_freq", 0.00f);
-  params.sparams.penalty_present =
-      get_option<float>(options, "penalty_present", 0.00f);
-  params.sparams.penalize_nl = get_option<bool>(options, "penalize_nl", false);
-  params.sparams.typical_p = get_option<float>(options, "typical_p", 1.00f);
-  params.ignore_eos = get_option<bool>(options, "ignore_eos", false);
-  params.sparams.grammar = get_option<std::string>(options, "grammar", "");
-  params.n_keep = get_option<int32_t>(options, "n_keep", 0);
-  params.seed = get_option<uint32_t>(options, "seed", LLAMA_DEFAULT_SEED);
-  std::vector<std::string> stop_words;
-  if (options.Has("stop") && options.Get("stop").IsArray()) {
-    auto stop_words_array = options.Get("stop").As<Napi::Array>();
-    for (size_t i = 0; i < stop_words_array.Length(); i++) {
-      stop_words.push_back(stop_words_array.Get(i).ToString().Utf8Value());
-    }
-  }
-
-  // options.on_sample
-  Napi::Function callback;
-  if (info.Length() >= 2) {
-    callback = info[1].As<Napi::Function>();
-  }
-
-  auto worker =
-      new LlamaCompletionWorker(info, this, callback, params, stop_words);
-  worker->Queue();
-  compl_worker = worker;
-  return worker->Promise();
-}
-
-// stopCompletion(): void
-void LlamaContext::StopCompletion(const Napi::CallbackInfo &info) {
-  if (compl_worker != nullptr) {
-    compl_worker->Stop();
-  }
-}
-
-// saveSession(path: string): Promise throws error
-Napi::Value LlamaContext::SaveSession(const Napi::CallbackInfo &info) {
-  Napi::Env env = info.Env();
-  if (info.Length() < 1 || !info[0].IsString()) {
-    Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
-  }
-  auto *worker = new SaveSessionWorker(info, this);
-  worker->Queue();
-  return worker->Promise();
-}
-
-// loadSession(path: string): Promise<{ count }> throws error
-Napi::Value LlamaContext::LoadSession(const Napi::CallbackInfo &info) {
-  Napi::Env env = info.Env();
-  if (info.Length() < 1 || !info[0].IsString()) {
-    Napi::TypeError::New(env, "String expected").ThrowAsJavaScriptException();
-  }
-  auto *worker = new LoadSessionWorker(info, this);
-  worker->Queue();
-  return worker->Promise();
-}
-
-// release(): Promise
-Napi::Value LlamaContext::Release(const Napi::CallbackInfo &info) {
-  if (compl_worker != nullptr) {
-    compl_worker->Stop();
-  }
-  auto *worker = new DisposeWorker(info.Env(), this);
-  worker->Queue();
-  return worker->Promise();
-}
-
-Napi::Object Init(Napi::Env env, Napi::Object exports) {
-  LlamaContext::Export(env, exports);
-  return exports;
-}
-
-NODE_API_MODULE(addons, Init)
diff --git a/src/common.hpp b/src/common.hpp
new file mode 100644
index 0000000..7dfdacd
--- /dev/null
+++ b/src/common.hpp
@@ -0,0 +1,81 @@
+#pragma once
+
+#include "common/common.h"
+#include "llama.h"
+#include <memory>
+#include <mutex>
+#include <napi.h>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <vector>
+
+typedef std::unique_ptr<llama_model, decltype(&llama_free_model)> LlamaCppModel;
+typedef std::unique_ptr<llama_context, decltype(&llama_free)> LlamaCppContext;
+typedef std::unique_ptr<llama_sampling_context, decltype(&llama_sampling_free)>
+    LlamaCppSampling;
+typedef std::unique_ptr<llama_batch, decltype(&llama_batch_free)> LlamaCppBatch;
+
+template <typename T>
+constexpr T get_option(const Napi::Object &options, const std::string &name,
+                       const T default_value) {
+  if (options.Has(name) && !options.Get(name).IsUndefined() &&
+      !options.Get(name).IsNull()) {
+    if constexpr (std::is_same<T, std::string>::value) {
+      return options.Get(name).ToString().operator T();
+    } else if constexpr (std::is_same<T, int32_t>::value ||
+                         std::is_same<T, uint32_t>::value ||
+                         std::is_same<T, float>::value ||
+                         std::is_same<T, double>::value) {
+      return options.Get(name).ToNumber().operator T();
+    } else if constexpr (std::is_same<T, bool>::value) {
+      return options.Get(name).ToBoolean().operator T();
+    } else {
+      static_assert(std::is_same<T, std::string>::value ||
+                        std::is_same<T, int32_t>::value ||
+                        std::is_same<T, uint32_t>::value ||
+                        std::is_same<T, float>::value ||
+                        std::is_same<T, double>::value ||
+                        std::is_same<T, bool>::value,
+                    "Unsupported type");
+    }
+  } else {
+    return default_value;
+  }
+}
+
+class LlamaSession {
+public:
+  LlamaSession(llama_context *ctx, gpt_params params)
+      : ctx_(LlamaCppContext(ctx, llama_free)), params_(params) {
+    tokens_.reserve(params.n_ctx);
+  }
+
+  ~LlamaSession() { dispose(); }
+
+  llama_context *context() { return ctx_.get(); }
+
+  std::vector<llama_token> *tokens_ptr() { return &tokens_; }
+
+  void set_tokens(std::vector<llama_token> tokens) {
+    tokens_ = std::move(tokens);
+  }
+
+  const gpt_params &params() const { return params_; }
+
+  std::mutex &get_mutex() { return mutex; }
+
+  void dispose() {
+    std::lock_guard<std::mutex> lock(mutex);
+    tokens_.clear();
+    ctx_.reset();
+  }
+
+private:
+  LlamaCppContext ctx_;
+  const gpt_params params_;
+  std::vector<llama_token> tokens_{};
+  std::mutex mutex;
+};
+
+typedef std::shared_ptr<LlamaSession> LlamaSessionPtr;
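
Usage note (not part of the patch): a minimal sketch of how the refactored binding is expected to be driven from JavaScript. The require() path, the file name of the built addon, and the TypeScript annotations are assumptions, since the JS-side wrapper is not touched by this diff; the method names, option keys, callback payload, and result fields mirror LlamaContext.cpp and LlamaCompletionWorker.cpp above.

// usage-sketch.ts — hypothetical; point the path at the actual build output of this project
const { LlamaContext } = require("./build/Release/addons.node");

async function main() {
  // constructor options map to the gpt_params fields read in LlamaContext's constructor
  const ctx = new LlamaContext({ model: "./model.gguf", n_ctx: 2048, n_gpu_layers: -1 });
  console.log(ctx.getSystemInfo());

  // completion() resolves with { text, tokens_predicted, tokens_evaluated, truncated };
  // the optional second argument receives { token } for each generated piece
  const result = await ctx.completion(
    { prompt: "Hello,", n_predict: 64, temperature: 0.8, stop: ["\n\n"] },
    (data: { token: string }) => process.stdout.write(data.token)
  );
  console.log(result.text, result.tokens_predicted, result.truncated);

  await ctx.saveSession("./session.bin"); // runs SaveSessionWorker off the main thread
  await ctx.loadSession("./session.bin"); // runs LoadSessionWorker and restores the token cache
  await ctx.release();                    // queues DisposeWorker, which frees the llama_context
}

main().catch(console.error);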