diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index 337ed04..36b1a1c 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -358,6 +358,10 @@ private static String getCpuFeatures() {
     }
   }
 
+  protected static native WritableMap modelInfo(
+    String model,
+    String[] skip
+  );
   protected static native long initContext(
     String model,
     boolean embedding,
diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java
index eb02755..1f02f2d 100644
--- a/android/src/main/java/com/rnllama/RNLlama.java
+++ b/android/src/main/java/com/rnllama/RNLlama.java
@@ -42,6 +42,35 @@ public void setContextLimit(double limit, Promise promise) {
     promise.resolve(null);
   }
 
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    new AsyncTask<Void, Void, WritableMap>() {
+      private Exception exception;
+
+      @Override
+      protected WritableMap doInBackground(Void... voids) {
+        try {
+          String[] skipArray = new String[skip.size()];
+          for (int i = 0; i < skip.size(); i++) {
+            skipArray[i] = skip.getString(i);
+          }
+          return LlamaContext.modelInfo(model, skipArray);
+        } catch (Exception e) {
+          exception = e;
+        }
+        return null;
+      }
+
+      @Override
+      protected void onPostExecute(WritableMap result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+        promise.resolve(result);
+      }
+    }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
+  }
+
   public void initContext(double id, final ReadableMap params, final Promise promise) {
     final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index 128aa3f..3beb203 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -9,8 +9,9 @@
 #include <thread>
 #include <unordered_map>
 #include "llama.h"
-#include "rn-llama.hpp"
+#include "llama-impl.h"
 #include "ggml.h"
+#include "rn-llama.hpp"
 
 #define UNUSED(x) (void)(x)
 #define TAG "RNLLAMA_ANDROID_JNI"
@@ -132,6 +133,72 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject value) {
     env->CallVoidMethod(map, putArrayMethod, jKey, value);
 }
 
+JNIEXPORT jobject JNICALL
+Java_com_rnllama_LlamaContext_modelInfo(
+  JNIEnv *env,
+  jobject thiz,
+  jstring model_path_str,
+  jobjectArray skip
+) {
+  UNUSED(thiz);
+
+  const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
+
+  std::vector<std::string> skip_vec;
+  int skip_len = env->GetArrayLength(skip);
+  for (int i = 0; i < skip_len; i++) {
+    jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
+    const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
+    skip_vec.push_back(skip_chars);
+    env->ReleaseStringUTFChars(skip_str, skip_chars);
+  }
+
+  struct lm_gguf_init_params params = {
+    /*.no_alloc = */ false,
+    /*.ctx      = */ NULL,
+  };
+  struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);
+
+  if (!ctx) {
+    LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
+    return nullptr;
+  }
+
+  auto info = createWriteableMap(env);
+  putInt(env, info, "version", lm_gguf_get_version(ctx));
+  putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
+  putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
+  {
+    const int n_kv = lm_gguf_get_n_kv(ctx);
+
+    for (int i = 0; i < n_kv; ++i) {
+      const char * key = lm_gguf_get_key(ctx, i);
+
+      bool skipped = false;
+      if (skip_len > 0) {
+        for (int j = 0; j < skip_len; j++) {
+          if (skip_vec[j] == key) {
+            skipped = true;
+            break;
+          }
+        }
+      }
+
+      if (skipped) {
+        continue;
+      }
+
+      const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
+      putString(env, info, key, value.c_str());
+    }
+  }
+
+  env->ReleaseStringUTFChars(model_path_str, model_path_chars);
+  lm_gguf_free(ctx);
+
+  return reinterpret_cast<jobject>(info);
+}
+
 struct callback_context {
     JNIEnv *env;
     rnllama::llama_rn_context *llama;
diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java b/android/src/newarch/java/com/rnllama/RNLlamaModule.java
index 5bab9b1..a41aa05 100644
--- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java
+++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -37,6 +37,11 @@ public void setContextLimit(double limit, Promise promise) {
     rnllama.setContextLimit(limit, promise);
   }
 
+  @ReactMethod
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    rnllama.modelInfo(model, skip, promise);
+  }
+
   @ReactMethod
   public void initContext(double id, final ReadableMap params, final Promise promise) {
     rnllama.initContext(id, params, promise);
diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
index 2719515..4f01542 100644
--- a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
+++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -38,6 +38,11 @@ public void setContextLimit(double limit, Promise promise) {
     rnllama.setContextLimit(limit, promise);
   }
 
+  @ReactMethod
+  public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
+    rnllama.modelInfo(model, skip, promise);
+  }
+
   @ReactMethod
   public void initContext(double id, final ReadableMap params, final Promise promise) {
     rnllama.initContext(id, params, promise);