feat(android): implement modelInfo method
jhen0409 committed Nov 16, 2024
1 parent b9309b3 · commit 0c1746e
Showing 5 changed files with 111 additions and 1 deletion.
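
This commit wires up an Android-native modelInfo(model, skip) method: the module classes in android/src/newarch and android/src/oldarch forward the call to RNLlama.modelInfo, which runs the native LlamaContext.modelInfo on a background AsyncTask, and the JNI implementation opens the GGUF file and resolves the promise with its metadata (version, alignment, data offset, and key/value pairs), omitting any keys listed in skip. The JavaScript wrapper is not part of this diff; the following is only a rough sketch of how the bridge method might be reached from JS, assuming the native module is registered under the name "RNLlama" (an assumption, not confirmed by this commit):

// Sketch only: the JS binding is not included in this commit.
// Assumes the native module is exposed as "RNLlama" and that
// modelInfo(model, skip) resolves with a plain object of GGUF metadata.
import { NativeModules } from 'react-native';

const { RNLlama } = NativeModules;

async function readModelInfo(modelPath: string): Promise<Record<string, unknown>> {
  // Keys listed in `skip` are omitted from the result; large entries
  // such as the tokenizer vocabulary are typical candidates to skip.
  const skip = ['tokenizer.ggml.tokens', 'tokenizer.ggml.merges'];
  return RNLlama.modelInfo(modelPath, skip);
}

Passing an empty skip array returns every metadata key present in the GGUF header.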
4 changes: 4 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -358,6 +358,10 @@ private static String getCpuFeatures() {
}
}

protected static native WritableMap modelInfo(
String model,
String[] skip
);
protected static native long initContext(
String model,
boolean embedding,
29 changes: 29 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
@@ -42,6 +42,35 @@ public void setContextLimit(double limit, Promise promise) {
promise.resolve(null);
}

public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
protected WritableMap doInBackground(Void... voids) {
try {
String[] skipArray = new String[skip.size()];
for (int i = 0; i < skip.size(); i++) {
skipArray[i] = skip.getString(i);
}
return LlamaContext.modelInfo(model, skipArray);
} catch (Exception e) {
exception = e;
}
return null;
}

@Override
protected void onPostExecute(WritableMap result) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
}

public void initContext(double id, final ReadableMap params, final Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
69 changes: 68 additions & 1 deletion android/src/main/jni.cpp
@@ -9,8 +9,9 @@
#include <thread>
#include <unordered_map>
#include "llama.h"
#include "rn-llama.hpp"
#include "llama-impl.h"
#include "ggml.h"
#include "rn-llama.hpp"

#define UNUSED(x) (void)(x)
#define TAG "RNLLAMA_ANDROID_JNI"
@@ -132,6 +132,72 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
env->CallVoidMethod(map, putArrayMethod, jKey, value);
}

JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_modelInfo(
JNIEnv *env,
jobject thiz,
jstring model_path_str,
jobjectArray skip
) {
UNUSED(thiz);

const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);

std::vector<std::string> skip_vec;
int skip_len = env->GetArrayLength(skip);
for (int i = 0; i < skip_len; i++) {
jstring skip_str = (jstring) env->GetObjectArrayElement(skip, i);
const char *skip_chars = env->GetStringUTFChars(skip_str, nullptr);
skip_vec.push_back(skip_chars);
env->ReleaseStringUTFChars(skip_str, skip_chars);
}

struct lm_gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ NULL,
};
struct lm_gguf_context * ctx = lm_gguf_init_from_file(model_path_chars, params);

if (!ctx) {
LOGI("%s: failed to load '%s'\n", __func__, model_path_chars);
return nullptr;
}

auto info = createWriteableMap(env);
putInt(env, info, "version", lm_gguf_get_version(ctx));
putInt(env, info, "alignment", lm_gguf_get_alignment(ctx));
putInt(env, info, "data_offset", lm_gguf_get_data_offset(ctx));
{
const int n_kv = lm_gguf_get_n_kv(ctx);

for (int i = 0; i < n_kv; ++i) {
const char * key = lm_gguf_get_key(ctx, i);

bool skipped = false;
if (skip_len > 0) {
for (int j = 0; j < skip_len; j++) {
if (skip_vec[j] == key) {
skipped = true;
break;
}
}
}

if (skipped) {
continue;
}

const std::string value = rnllama::lm_gguf_kv_to_str(ctx, i);
putString(env, info, key, value.c_str());
}
}

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
lm_gguf_free(ctx);

return reinterpret_cast<jobject>(info);
}

struct callback_context {
JNIEnv *env;
rnllama::llama_rn_context *llama;
5 changes: 5 additions & 0 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -37,6 +37,11 @@ public void setContextLimit(double limit, Promise promise) {
rnllama.setContextLimit(limit, promise);
}

@ReactMethod
public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
rnllama.modelInfo(model, skip, promise);
}

@ReactMethod
public void initContext(double id, final ReadableMap params, final Promise promise) {
rnllama.initContext(id, params, promise);
5 changes: 5 additions & 0 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -38,6 +38,11 @@ public void setContextLimit(double limit, Promise promise) {
rnllama.setContextLimit(limit, promise);
}

@ReactMethod
public void modelInfo(final String model, final ReadableArray skip, final Promise promise) {
rnllama.modelInfo(model, skip, promise);
}

@ReactMethod
public void initContext(double id, final ReadableMap params, final Promise promise) {
rnllama.initContext(id, params, promise);
