feat: add static model info method (#76)
* feat(cpp): add static model info method

* fix(cpp): include llama-impl.h

* feat: impl loadLlamaModelInfo
jhen0409 authored Jan 20, 2025
1 parent 2cbf1aa commit 6396083
Showing 6 changed files with 104 additions and 2 deletions.
2 changes: 2 additions & 0 deletions lib/binding.ts
@@ -111,6 +111,8 @@ export interface LlamaContext {
saveSession(path: string): Promise<void>
loadSession(path: string): Promise<void>
release(): Promise<void>
// static
loadModelInfo(path: string, skip: string[]): Promise<Object>
}

export interface Module {
16 changes: 16 additions & 0 deletions lib/index.ts
@@ -14,3 +14,19 @@ export const loadModel = async (options: LlamaModelOptionsExtended): Promise<Lla
mods[variant] ??= await loadModule(options.lib_variant)
return new mods[variant].LlamaContext(options)
}

export const initLlama = loadModule

const modelInfoSkip = [
// Large fields
'tokenizer.ggml.tokens',
'tokenizer.ggml.token_type',
'tokenizer.ggml.merges',
'tokenizer.ggml.scores',
]

export const loadLlamaModelInfo = async (path: string): Promise<Object> => {
const variant = 'default'
mods[variant] ??= await loadModule(variant)
return mods[variant].LlamaContext.loadModelInfo(path, modelInfoSkip)
}
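
For reference, a minimal usage sketch of the new export. The relative import mirrors the test file further below; the model path and function name are placeholders, and the Record<string, any> cast is only there because the binding is typed as Promise<Object>:

import { loadLlamaModelInfo } from '../lib'

// Reads GGUF metadata from disk without creating a full llama context.
const printModelInfo = async (modelPath: string) => {
  const info = (await loadLlamaModelInfo(modelPath)) as Record<string, any>
  console.log(info['general.architecture']) // e.g. "llama"
  console.log(info['llama.context_length']) // e.g. "4096"
}

// Placeholder path; point it at any local .gguf file.
printModelInfo('./model.gguf')
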
57 changes: 56 additions & 1 deletion src/LlamaContext.cpp
@@ -1,4 +1,6 @@
#include "ggml.h"
#include "gguf.h"
#include "llama-impl.h"
#include "LlamaContext.h"
#include "DetokenizeWorker.h"
#include "DisposeWorker.h"
@@ -8,6 +10,56 @@
#include "SaveSessionWorker.h"
#include "TokenizeWorker.h"

// loadModelInfo(path: string): object
Napi::Value LlamaContext::ModelInfo(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ NULL,
};
std::string path = info[0].ToString().Utf8Value();

// Convert Napi::Array to vector<string>
std::vector<std::string> skip;
if (info.Length() > 1 && info[1].IsArray()) {
Napi::Array skipArray = info[1].As<Napi::Array>();
for (uint32_t i = 0; i < skipArray.Length(); i++) {
skip.push_back(skipArray.Get(i).ToString().Utf8Value());
}
}

struct gguf_context * ctx = gguf_init_from_file(path.c_str(), params);

Napi::Object metadata = Napi::Object::New(env);
if (std::find(skip.begin(), skip.end(), "version") == skip.end()) {
metadata.Set("version", Napi::Number::New(env, gguf_get_version(ctx)));
}
if (std::find(skip.begin(), skip.end(), "alignment") == skip.end()) {
metadata.Set("alignment", Napi::Number::New(env, gguf_get_alignment(ctx)));
}
if (std::find(skip.begin(), skip.end(), "data_offset") == skip.end()) {
metadata.Set("data_offset", Napi::Number::New(env, gguf_get_data_offset(ctx)));
}

// kv
{
const int n_kv = gguf_get_n_kv(ctx);

for (int i = 0; i < n_kv; ++i) {
const char * key = gguf_get_key(ctx, i);
if (std::find(skip.begin(), skip.end(), key) != skip.end()) {
continue;
}
const std::string value = gguf_kv_to_str(ctx, i);
metadata.Set(key, Napi::String::New(env, value.c_str()));
}
}

gguf_free(ctx);

return metadata;
}

std::vector<common_chat_msg> get_messages(Napi::Array messages) {
std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.Length(); i++) {
@@ -52,7 +104,10 @@ void LlamaContext::Init(Napi::Env env, Napi::Object &exports) {
"loadSession",
static_cast<napi_property_attributes>(napi_enumerable)),
InstanceMethod<&LlamaContext::Release>(
"release", static_cast<napi_property_attributes>(napi_enumerable))});
"release", static_cast<napi_property_attributes>(napi_enumerable)),
StaticMethod<&LlamaContext::ModelInfo>(
"loadModelInfo",
static_cast<napi_property_attributes>(napi_enumerable))});
Napi::FunctionReference *constructor = new Napi::FunctionReference();
*constructor = Napi::Persistent(func);
#if NAPI_VERSION > 5
1 change: 1 addition & 0 deletions src/LlamaContext.h
@@ -5,6 +5,7 @@ class LlamaCompletionWorker;
class LlamaContext : public Napi::ObjectWrap<LlamaContext> {
public:
LlamaContext(const Napi::CallbackInfo &info);
static Napi::Value ModelInfo(const Napi::CallbackInfo& info);
static void Init(Napi::Env env, Napi::Object &exports);

private:
23 changes: 23 additions & 0 deletions test/__snapshots__/index.test.ts.snap
@@ -391,6 +391,29 @@ exports[`embedding 1`] = `
}
`;

exports[`loadModelInfo 1`] = `
{
"alignment": 32,
"data_offset": 724416,
"general.architecture": "llama",
"general.file_type": "1",
"general.name": "LLaMA v2",
"llama.attention.head_count": "2",
"llama.attention.head_count_kv": "2",
"llama.attention.layer_norm_rms_epsilon": "0.000010",
"llama.block_count": "1",
"llama.context_length": "4096",
"llama.embedding_length": "8",
"llama.feed_forward_length": "32",
"llama.rope.dimension_count": "4",
"tokenizer.ggml.bos_token_id": "1",
"tokenizer.ggml.eos_token_id": "2",
"tokenizer.ggml.model": "llama",
"tokenizer.ggml.unknown_token_id": "0",
"version": 3,
}
`;

exports[`tokeneize 1`] = `
{
"tokens": Int32Array [
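
One detail about the returned object, visible in the loadModelInfo snapshot above: version, alignment and data_offset are set as JS numbers, while every GGUF key/value pair is stringified through gguf_kv_to_str, so numeric model fields come back as strings. A small consumer-side sketch; the coercion helper is illustrative and not part of this commit:

import { loadLlamaModelInfo } from '../lib'

// Illustrative only: coerce stringified GGUF metadata back to numbers.
const asNumber = (value: unknown): number | undefined => {
  const n = Number(value)
  return Number.isFinite(n) ? n : undefined
}

const inspectModel = async (modelPath: string) => {
  const info = (await loadLlamaModelInfo(modelPath)) as Record<string, any>
  console.log(typeof info.version)                    // "number"
  console.log(typeof info['llama.context_length'])    // "string"
  console.log(asNumber(info['llama.context_length'])) // e.g. 4096
}
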
7 changes: 6 additions & 1 deletion test/index.test.ts
@@ -1,6 +1,6 @@
import path from 'path'
import waitForExpect from 'wait-for-expect'
import { loadModel } from '../lib'
import { loadModel, loadLlamaModelInfo } from '../lib'

it('works fine', async () => {
let tokens = ''
@@ -67,3 +67,8 @@ it('embedding', async () => {
expect(result).toMatchSnapshot()
await model.release()
})

it('loadModelInfo', async () => {
const result = await loadLlamaModelInfo(path.resolve(__dirname, './tiny-random-llama.gguf'))
expect(result).toMatchSnapshot()
})
