feat(android): add progress callback in initLlama
jhen0409 committed Nov 4, 2024
1 parent 48cea0a commit a72317a
Showing 6 changed files with 94 additions and 13 deletions.
android/src/main/java/com/rnllama/LlamaContext.java (33 changes: 30 additions & 3 deletions)

@@ -37,6 +37,7 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
     }
     Log.d(NAME, "Setting log callback");
     logToAndroid();
+    eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
     this.id = id;
     this.context = initContext(
       // String model,
@@ -64,11 +65,16 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       // float rope_freq_base,
       params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
       // float rope_freq_scale
-      params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f
+      params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
+      // LoadProgressCallback load_progress_callback
+      params.hasKey("use_progress_callback") ? new LoadProgressCallback(this) : null
     );
     this.modelDetails = loadModelDetails(this.context);
     this.reactContext = reactContext;
-    eventEmitter = reactContext.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter.class);
   }

+  public void interruptLoad() {
+    interruptLoad(this.context);
+  }
+
   public long getContext() {
@@ -87,6 +93,25 @@ public String getFormattedChat(ReadableArray messages, String chatTemplate) {
     return getFormattedChat(this.context, msgs, chatTemplate == null ? "" : chatTemplate);
   }

+  private void emitLoadProgress(int progress) {
+    WritableMap event = Arguments.createMap();
+    event.putInt("contextId", LlamaContext.this.id);
+    event.putInt("progress", progress);
+    eventEmitter.emit("@RNLlama_onInitContextProgress", event);
+  }
+
+  private static class LoadProgressCallback {
+    LlamaContext context;
+
+    public LoadProgressCallback(LlamaContext context) {
+      this.context = context;
+    }
+
+    void onLoadProgress(int progress) {
+      context.emitLoadProgress(progress);
+    }
+  }
+
   private void emitPartialCompletion(WritableMap tokenResult) {
     WritableMap event = Arguments.createMap();
     event.putInt("contextId", LlamaContext.this.id);
@@ -346,8 +371,10 @@ protected static native long initContext(
     String lora,
     float lora_scaled,
     float rope_freq_base,
-    float rope_freq_scale
+    float rope_freq_scale,
+    LoadProgressCallback load_progress_callback
   );
+  protected static native void interruptLoad(long contextPtr);
   protected static native WritableMap loadModelDetails(
     long contextPtr
   );
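
A minimal caller-side sketch (not part of this commit) of how the Java layer opts into these events. Only the "use_progress_callback" key is read by the constructor above; "model" as the path key is an assumption taken from the "// String model" parameter comment in the diff:

import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.WritableMap;

// Hypothetical helper, for illustration only: builds an init-params map
// that opts into load-progress events.
final class InitParamsSketch {
  static WritableMap withProgress(String modelPath) {
    WritableMap params = Arguments.createMap();
    params.putString("model", modelPath);          // assumed key name
    params.putBoolean("use_progress_callback", true);
    return params;
  }
}
// Each native progress tick then reaches JS as a DeviceEventEmitter event
// named "@RNLlama_onInitContextProgress" with payload
// { contextId: <int>, progress: <int, 0..100> }.
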
android/src/main/java/com/rnllama/RNLlama.java (14 changes: 9 additions & 5 deletions)

@@ -42,21 +42,24 @@ public void setContextLimit(double limit, Promise promise) {
     promise.resolve(null);
   }

-  public void initContext(final ReadableMap params, final Promise promise) {
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, WritableMap>() {
       private Exception exception;

       @Override
       protected WritableMap doInBackground(Void... voids) {
         try {
-          int id = Math.abs(new Random().nextInt());
-          LlamaContext llamaContext = new LlamaContext(id, reactContext, params);
+          LlamaContext context = contexts.get(contextId);
+          if (context != null) {
+            throw new Exception("Context already exists");
+          }
+          LlamaContext llamaContext = new LlamaContext(contextId, reactContext, params);
           if (llamaContext.getContext() == 0) {
             throw new Exception("Failed to initialize context");
           }
-          contexts.put(id, llamaContext);
+          contexts.put(contextId, llamaContext);
           WritableMap result = Arguments.createMap();
-          result.putInt("contextId", id);
           result.putBoolean("gpu", false);
           result.putString("reasonNoGPU", "Currently not supported");
           result.putMap("model", llamaContext.getModelDetails());
@@ -393,6 +396,7 @@ protected Void doInBackground(Void... voids) {
         if (context == null) {
           throw new Exception("Context " + id + " not found");
         }
+        context.interruptLoad();
        context.stopCompletion();
        AsyncTask completionTask = null;
        for (AsyncTask task : tasks.keySet()) {
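
The signature change shifts id generation from native code (a random int) to the JS caller, which is what lets initContext reject duplicates. A hedged sketch of the new calling contract, with illustrative names:

import com.facebook.react.bridge.Promise;
import com.facebook.react.bridge.ReadableMap;

// Hypothetical caller sketch: a duplicate id is now rejected with
// "Context already exists" instead of a fresh random id being minted
// natively. The promise resolves with { gpu, reasonNoGPU, model }; the
// contextId field was dropped from the result map because the caller
// already knows the id it passed in.
final class InitContextSketch {
  static void init(RNLlama rnllama, ReadableMap params, Promise promise) {
    rnllama.initContext(1, params, promise);
  }
}
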
android/src/main/jni.cpp (48 changes: 47 additions & 1 deletion)

@@ -132,6 +132,11 @@ static inline void putArray(JNIEnv *env, jobject map, const char *key, jobject v
     env->CallVoidMethod(map, putArrayMethod, jKey, value);
 }

+struct callback_context {
+    JNIEnv *env;
+    rnllama::llama_rn_context *llama;
+    jobject callback;
+};

 std::unordered_map<long, rnllama::llama_rn_context *> context_map;

@@ -151,7 +156,8 @@ Java_com_rnllama_LlamaContext_initContext(
     jstring lora_str,
     jfloat lora_scaled,
     jfloat rope_freq_base,
-    jfloat rope_freq_scale
+    jfloat rope_freq_scale,
+    jobject load_progress_callback
 ) {
     UNUSED(thiz);

@@ -190,6 +196,32 @@
     defaultParams.rope_freq_scale = rope_freq_scale;

     auto llama = new rnllama::llama_rn_context();
+    llama->is_load_interrupted = false;
+    llama->loading_progress = 0;
+
+    if (load_progress_callback != nullptr) {
+        defaultParams.progress_callback = [](float progress, void * user_data) {
+            callback_context *cb_ctx = (callback_context *)user_data;
+            JNIEnv *env = cb_ctx->env;
+            auto llama = cb_ctx->llama;
+            jobject callback = cb_ctx->callback;
+            int percentage = (int) (100 * progress);
+            if (percentage > llama->loading_progress) {
+                llama->loading_progress = percentage;
+                jclass callback_class = env->GetObjectClass(callback);
+                jmethodID onLoadProgress = env->GetMethodID(callback_class, "onLoadProgress", "(I)V");
+                env->CallVoidMethod(callback, onLoadProgress, percentage);
+            }
+            return !llama->is_load_interrupted;
+        };
+
+        callback_context *cb_ctx = new callback_context;
+        cb_ctx->env = env;
+        cb_ctx->llama = llama;
+        cb_ctx->callback = env->NewGlobalRef(load_progress_callback);
+        defaultParams.progress_callback_user_data = cb_ctx;
+    }
+
     bool is_model_loaded = llama->loadModel(defaultParams);

     LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
@@ -205,6 +237,20 @@
     return reinterpret_cast<jlong>(llama->ctx);
 }

+
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_interruptLoad(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    if (llama) {
+        llama->is_load_interrupted = true;
+    }
+}
+
 JNIEXPORT jobject JNICALL
 Java_com_rnllama_LlamaContext_loadModelDetails(
     JNIEnv *env,
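
Two details of the lambda above are easy to miss: it only crosses the JNI boundary when the integer percentage actually increases, and its boolean return value is what makes interruptLoad work, since llama.cpp aborts the load when progress_callback returns false. A standalone Java restatement of that gating logic, for illustration:

// Illustrative Java restatement of the native gating above: scale the
// 0.0..1.0 float to a whole percentage and only report when it grows,
// so at most 100 callback crossings happen per load.
final class ProgressGateSketch {
  private int lastPercentage = 0;

  boolean onTick(float progress, boolean interrupted) {
    int percentage = (int) (100 * progress);
    if (percentage > lastPercentage) {
      lastPercentage = percentage;
      System.out.println("progress " + percentage + "%");
    }
    // Returning false is how cancellation propagates: the native lambda
    // returns !is_load_interrupted, and the model load aborts.
    return !interrupted;
  }
}
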
android/src/newarch/java/com/rnllama/RNLlamaModule.java (4 changes: 2 additions & 2 deletions)

@@ -38,8 +38,8 @@ public void setContextLimit(double limit, Promise promise) {
   }

   @ReactMethod
-  public void initContext(final ReadableMap params, final Promise promise) {
-    rnllama.initContext(params, promise);
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    rnllama.initContext(id, params, promise);
   }

   @ReactMethod
android/src/oldarch/java/com/rnllama/RNLlamaModule.java (4 changes: 2 additions & 2 deletions)

@@ -39,8 +39,8 @@ public void setContextLimit(double limit, Promise promise) {
   }

   @ReactMethod
-  public void initContext(final ReadableMap params, final Promise promise) {
-    rnllama.initContext(params, promise);
+  public void initContext(double id, final ReadableMap params, final Promise promise) {
+    rnllama.initContext(id, params, promise);
   }

   @ReactMethod
ios/RNLlamaContext.mm (4 changes: 4 additions & 0 deletions)

@@ -103,6 +103,10 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
   return context;
 }

+- (void)interruptLoad {
+  llama->is_load_interrupted = true;
+}
+
 - (bool)isMetalEnabled {
   return is_metal_enabled;
 }
