diff --git a/lib/binding.ts b/lib/binding.ts
index 274658a..b8a4459 100644
--- a/lib/binding.ts
+++ b/lib/binding.ts
@@ -15,6 +15,27 @@ export type LlamaModelOptions = {
   n_ubatch?: number
   n_threads?: number
   n_gpu_layers?: number
+  flash_attn?: boolean
+  cache_type_k?:
+    | 'f16'
+    | 'f32'
+    | 'bf16'
+    | 'q8_0'
+    | 'q4_0'
+    | 'q4_1'
+    | 'iq4_nl'
+    | 'q5_0'
+    | 'q5_1'
+  cache_type_v?:
+    | 'f16'
+    | 'f32'
+    | 'bf16'
+    | 'q8_0'
+    | 'q4_0'
+    | 'q4_1'
+    | 'iq4_nl'
+    | 'q5_0'
+    | 'q5_1'
   use_mlock?: boolean
   use_mmap?: boolean
   vocab_only?: boolean
diff --git a/src/LlamaContext.cpp b/src/LlamaContext.cpp
index a86efaa..0bf511a 100644
--- a/src/LlamaContext.cpp
+++ b/src/LlamaContext.cpp
@@ -1,3 +1,4 @@
+#include "ggml.h"
 #include "LlamaContext.h"
 #include "DetokenizeWorker.h"
 #include "DisposeWorker.h"
@@ -60,6 +61,31 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
   exports.Set("LlamaContext", func);
 }
 
+// KV-cache element types accepted for the `cache_type_k` / `cache_type_v` options.
+const std::vector<ggml_type> kv_cache_types = {
+    GGML_TYPE_F32,
+    GGML_TYPE_F16,
+    GGML_TYPE_BF16,
+    GGML_TYPE_Q8_0,
+    GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1,
+    GGML_TYPE_IQ4_NL,
+    GGML_TYPE_Q5_0,
+    GGML_TYPE_Q5_1,
+};
+
+// Maps a type name (compared against ggml_type_name) to its ggml_type.
+// Only the types listed in kv_cache_types are accepted; any other name
+// throws std::runtime_error.
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+    for (const auto & type : kv_cache_types) {
+        if (ggml_type_name(type) == s) {
+            return type;
+        }
+    }
+    throw std::runtime_error("Unsupported cache type: " + s);
+}
+
 // construct({ model, embedding, n_ctx, n_batch, n_threads, n_gpu_layers,
 // use_mlock, use_mmap }): LlamaContext throws error
 LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
@@ -96,6 +122,10 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
   params.cpuparams.n_threads =
       get_option(options, "n_threads", cpu_get_num_math() / 2);
   params.n_gpu_layers = get_option(options, "n_gpu_layers", -1);
+  params.flash_attn = get_option(options, "flash_attn", false);
+  params.cache_type_k = kv_cache_type_from_str(get_option(options, "cache_type_k", "f16").c_str());
+  params.cache_type_v = kv_cache_type_from_str(get_option(options, "cache_type_v", "f16").c_str());
+
   params.use_mlock = get_option(options, "use_mlock", false);
   params.use_mmap = get_option(options, "use_mmap", true);
   params.numa =