diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index 4cdb55a1..61985a5e 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -216,6 +216,10 @@ public WritableMap embedding(String text) {
     return result;
   }
 
+  public String bench(int pp, int tg, int pl, int nr) {
+    return bench(this.context, pp, tg, pl, nr);
+  }
+
   public void release() {
     freeContext(context);
   }
@@ -329,5 +333,6 @@ protected static native WritableMap doCompletion(
   protected static native String detokenize(long contextPtr, int[] tokens);
   protected static native boolean isEmbeddingEnabled(long contextPtr);
   protected static native WritableArray embedding(long contextPtr, String text);
+  protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
   protected static native void freeContext(long contextPtr);
 }
diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java
index eb423dd0..6ed206ec 100644
--- a/android/src/main/java/com/rnllama/RNLlama.java
+++ b/android/src/main/java/com/rnllama/RNLlama.java
@@ -316,6 +316,38 @@ protected void onPostExecute(WritableMap result) {
     tasks.put(task, "embedding-" + contextId);
   }
 
+  public void bench(double id, final double pp, final double tg, final double pl, final double nr, final Promise promise) {
+    final int contextId = (int) id;
+    AsyncTask task = new AsyncTask<Void, Void, String>() {
+      private Exception exception;
+
+      @Override
+      protected String doInBackground(Void... voids) {
+        try {
+          LlamaContext context = contexts.get(contextId);
+          if (context == null) {
+            throw new Exception("Context not found");
+          }
+          return context.bench((int) pp, (int) tg, (int) pl, (int) nr);
+        } catch (Exception e) {
+          exception = e;
+        }
+        return null;
+      }
+
+      @Override
+      protected void onPostExecute(String result) {
+        if (exception != null) {
+          promise.reject(exception);
+          return;
+        }
+        promise.resolve(result);
+        tasks.remove(this);
+      }
+    }.execute();
+    tasks.put(task, "bench-" + contextId);
+  }
+
   public void releaseContext(double id, Promise promise) {
     final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, Void>() {
diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index ce632427..8631b2a5 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -561,6 +561,22 @@ Java_com_rnllama_LlamaContext_embedding(
     return result;
 }
 
+JNIEXPORT jstring JNICALL
+Java_com_rnllama_LlamaContext_bench(
+    JNIEnv *env,
+    jobject thiz,
+    jlong context_ptr,
+    jint pp,
+    jint tg,
+    jint pl,
+    jint nr
+) {
+    UNUSED(thiz);
+    auto llama = context_map[(long) context_ptr];
+    std::string result = llama->bench(pp, tg, pl, nr);
+    return env->NewStringUTF(result.c_str());
+}
+
 JNIEXPORT void JNICALL
 Java_com_rnllama_LlamaContext_freeContext(
     JNIEnv *env, jobject thiz, jlong context_ptr) {
diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java b/android/src/newarch/java/com/rnllama/RNLlamaModule.java
index 38c5a1c6..93d27222 100644
--- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java
+++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -77,6 +77,11 @@ public void embedding(double id, final String text, final Promise promise) {
     rnllama.embedding(id, text, promise);
   }
 
+  @ReactMethod
+  public void bench(double id, final double pp, final double tg, final double pl, final double nr, final Promise promise) {
+    rnllama.bench(id, pp, tg, pl, nr, promise);
+  }
+
   @ReactMethod
   public void releaseContext(double id, Promise promise) {
     rnllama.releaseContext(id, promise);
diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
index 4b34e2a5..814fb17e 100644
--- a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
+++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -78,6 +78,11 @@ public void embedding(double id, final String text, final Promise promise) {
     rnllama.embedding(id, text, promise);
   }
 
+  @ReactMethod
+  public void bench(double id, final double pp, final double tg, final double pl, final double nr, final Promise promise) {
+    rnllama.bench(id, pp, tg, pl, nr, promise);
+  }
+
   @ReactMethod
   public void releaseContext(double id, Promise promise) {
     rnllama.releaseContext(id, promise);
diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp
index ba6c8607..8d9fe4a2 100644
--- a/cpp/rn-llama.hpp
+++ b/cpp/rn-llama.hpp
@@ -8,6 +8,21 @@
 
 namespace rnllama {
 
+static void llama_batch_clear(llama_batch *batch) {
+    batch->n_tokens = 0;
+}
+
+static void llama_batch_add(llama_batch *batch, llama_token id, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
+    batch->token   [batch->n_tokens] = id;
+    batch->pos     [batch->n_tokens] = pos;
+    batch->n_seq_id[batch->n_tokens] = seq_ids.size();
+    for (size_t i = 0; i < seq_ids.size(); i++) {
+        batch->seq_id[batch->n_tokens][i] = seq_ids[i];
+    }
+    batch->logits  [batch->n_tokens] = logits ? 1 : 0;
+    batch->n_tokens += 1;
+}
+
 // NOTE: Edit from https://github.com/ggerganov/llama.cpp/blob/master/examples/server/server.cpp
 static void log(const char *level, const char *function, int line,
@@ -506,6 +521,109 @@ struct llama_rn_context
         std::vector<float> embedding(data, data + n_embd);
         return embedding;
     }
+
+    std::string bench(int pp, int tg, int pl, int nr)
+    {
+        if (is_predicting) {
+            LOG_ERROR("cannot benchmark while predicting", "");
+            return std::string("[]");
+        }
+
+        is_predicting = true;
+
+        double pp_avg = 0;
+        double tg_avg = 0;
+
+        double pp_std = 0;
+        double tg_std = 0;
+
+        // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
+        llama_batch batch = llama_batch_init(512, 0, 1);
+
+        for (int i = 0; i < nr; i++)
+        {
+            llama_batch_clear(&batch);
+
+            const int n_tokens = pp;
+
+            for (int i = 0; i < n_tokens; i++)
+            {
+                llama_batch_add(&batch, 0, i, {0}, false);
+            }
+            batch.logits[batch.n_tokens - 1] = 1; // true
+
+            llama_kv_cache_clear(ctx);
+
+            const int64_t t_pp_start = llama_time_us();
+            if (llama_decode(ctx, batch) != 0)
+            {
+                LOG_ERROR("llama_decode() failed during prompt", "");
+            }
+            const int64_t t_pp_end = llama_time_us();
+            llama_kv_cache_clear(ctx);
+
+            if (is_interrupted) break;
+
+            const int64_t t_tg_start = llama_time_us();
+
+            for (int i = 0; i < tg; i++)
+            {
+                llama_batch_clear(&batch);
+
+                for (int j = 0; j < pl; j++)
+                {
+                    llama_batch_add(&batch, 0, i, {j}, true);
+                }
+
+                if (llama_decode(ctx, batch) != 0)
+                {
+                    LOG_ERROR("llama_decode() failed during text generation", "");
+                }
+                if (is_interrupted) break;
+            }
+
+            const int64_t t_tg_end = llama_time_us();
+
+            llama_kv_cache_clear(ctx);
+
+            const double t_pp = (t_pp_end - t_pp_start) / 1000000.0;
+            const double t_tg = (t_tg_end - t_tg_start) / 1000000.0;
+
+            const double speed_pp = pp / t_pp;
+            const double speed_tg = (pl * tg) / t_tg;
+
+            pp_avg += speed_pp;
+            tg_avg += speed_tg;
+
+            pp_std += speed_pp * speed_pp;
+            tg_std += speed_tg * speed_tg;
+        }
+
+        pp_avg /= nr;
+        tg_avg /= nr;
+
+        if (nr > 1) {
+            pp_std = sqrt(pp_std / (nr - 1) - pp_avg * pp_avg * nr / (nr - 1));
+            tg_std = sqrt(tg_std / (nr - 1) - tg_avg * tg_avg * nr / (nr - 1));
+        } else {
+            pp_std = 0;
+            tg_std = 0;
+        }
+
+        if (is_interrupted) llama_kv_cache_clear(ctx);
+        is_predicting = false;
+
+        char model_desc[128];
+        llama_model_desc(model, model_desc, sizeof(model_desc));
+        return std::string("[\"") + model_desc + std::string("\",") +
+            std::to_string(llama_model_size(model)) + std::string(",") +
+            std::to_string(llama_model_n_params(model)) + std::string(",") +
+            std::to_string(pp_avg) + std::string(",") +
+            std::to_string(pp_std) + std::string(",") +
+            std::to_string(tg_avg) + std::string(",") +
+            std::to_string(tg_std) +
+            std::string("]");
+    }
 };
 
 }
diff --git a/docs/API/README.md b/docs/API/README.md
index d282527f..42060ff8 100644
--- a/docs/API/README.md
+++ b/docs/API/README.md
@@ -11,6 +11,7 @@ llama.rn
 
 ### Type Aliases
 
+- [BenchResult](README.md#benchresult)
 - [CompletionParams](README.md#completionparams)
 - [ContextParams](README.md#contextparams)
 - [TokenData](README.md#tokendata)
@@ -24,13 +25,35 @@ llama.rn
 
 ## Type Aliases
 
+### BenchResult
+
+Ƭ **BenchResult**: `Object`
+
+#### Type declaration
+
+| Name | Type |
+| :------ | :------ |
+| `modelDesc` | `string` |
+| `modelNParams` | `number` |
+| `modelSize` | `number` |
+| `ppAvg` | `number` |
+| `ppStd` | `number` |
+| `tgAvg` | `number` |
+| `tgStd` | `number` |
+
+#### Defined in
+
+[index.ts:43](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L43)
+
+___
+
 ### CompletionParams
 
 Ƭ **CompletionParams**: `Omit`<`NativeCompletionParams`, ``"emit_partial_completion"``\>
 
 #### Defined in
 
-[index.ts:40](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L40)
+[index.ts:41](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L41)
 
 ___
 
@@ -40,7 +63,7 @@ ___
 
 #### Defined in
 
-[index.ts:38](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L38)
+[index.ts:39](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L39)
 
 ___
 
@@ -57,7 +80,7 @@ ___
 
 #### Defined in
 
-[index.ts:28](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L28)
+[index.ts:29](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L29)
 
 ## Functions
 
@@ -79,7 +102,7 @@ ___
 
 #### Defined in
 
-[grammar.ts:134](https://github.com/mybigday/llama.rn/blob/8738c99/src/grammar.ts#L134)
+[grammar.ts:134](https://github.com/mybigday/llama.rn/blob/427a856/src/grammar.ts#L134)
 
 ___
 
@@ -99,7 +122,7 @@ ___
 
 #### Defined in
 
-[index.ts:127](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L127)
+[index.ts:160](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L160)
 
 ___
 
@@ -113,7 +136,7 @@ ___
 
 #### Defined in
 
-[index.ts:143](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L143)
+[index.ts:176](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L176)
 
 ___
 
@@ -133,4 +156,4 @@ ___
 
 #### Defined in
 
-[index.ts:123](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L123)
+[index.ts:156](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L156)
diff --git a/docs/API/classes/LlamaContext.md b/docs/API/classes/LlamaContext.md
index 0666076a..1d199156 100644
--- a/docs/API/classes/LlamaContext.md
+++ b/docs/API/classes/LlamaContext.md
@@ -16,6 +16,7 @@
 
 ### Methods
 
+- [bench](LlamaContext.md#bench)
 - [completion](LlamaContext.md#completion)
 - [detokenize](LlamaContext.md#detokenize)
 - [embedding](LlamaContext.md#embedding)
@@ -39,7 +40,7 @@
 
 #### Defined in
 
-[index.ts:49](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L49)
+[index.ts:60](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L60)
 
 ## Properties
 
@@ -49,7 +50,7 @@
 
 #### Defined in
 
-[index.ts:45](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L45)
+[index.ts:56](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L56)
 
 ___
 
@@ -59,7 +60,7 @@ ___
 
 #### Defined in
 
-[index.ts:43](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L43)
+[index.ts:54](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L54)
 
 ___
 
@@ -69,10 +70,33 @@ ___
 
 #### Defined in
 
-[index.ts:47](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L47)
+[index.ts:58](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L58)
 
 ## Methods
 
+### bench
+
+▸ **bench**(`pp`, `tg`, `pl`, `nr`): `Promise`<[`BenchResult`](../README.md#benchresult)\>
+
+#### Parameters
+
+| Name | Type |
+| :------ | :------ |
+| `pp` | `number` |
+| `tg` | `number` |
+| `pl` | `number` |
+| `nr` | `number` |
+
+#### Returns
+
+`Promise`<[`BenchResult`](../README.md#benchresult)\>
+
+#### Defined in
+
+[index.ts:129](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L129)
+
+___
+
 ### completion
 
 ▸ **completion**(`params`, `callback?`): `Promise`<`NativeCompletionResult`\>
@@ -90,7 +114,7 @@ ___
 
 #### Defined in
 
-[index.ts:73](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L73)
+[index.ts:84](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L84)
 
 ___
 
@@ -110,7 +134,7 @@ ___
 
 #### Defined in
 
-[index.ts:110](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L110)
+[index.ts:121](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L121)
 
 ___
 
@@ -130,13 +154,13 @@ ___
 
 #### Defined in
 
-[index.ts:114](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L114)
+[index.ts:125](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L125)
 
 ___
 
 ### loadSession
 
-▸ **loadSession**(`filepath`): `Promise`<`number`\>
+▸ **loadSession**(`filepath`): `Promise`<`NativeSessionLoadResult`\>
 
 Load cached prompt & completion state from a file.
 
@@ -148,11 +172,11 @@ Load cached prompt & completion state from a file.
 
 #### Returns
 
-`Promise`<`number`\>
+`Promise`<`NativeSessionLoadResult`\>
 
 #### Defined in
 
-[index.ts:62](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L62)
+[index.ts:73](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L73)
 
 ___
 
@@ -166,13 +190,13 @@ ___
 
 #### Defined in
 
-[index.ts:118](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L118)
+[index.ts:151](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L151)
 
 ___
 
 ### saveSession
 
-▸ **saveSession**(`filepath`): `Promise`<`number`\>
+▸ **saveSession**(`filepath`, `options?`): `Promise`<`number`\>
 
 Save current cached prompt & completion state to a file.
 
@@ -181,6 +205,8 @@ Save current cached prompt & completion state to a file.
 | Name | Type |
 | :------ | :------ |
 | `filepath` | `string` |
+| `options?` | `Object` |
+| `options.tokenSize` | `number` |
 
 #### Returns
 
 `Promise`<`number`\>
 
 #### Defined in
 
-[index.ts:69](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L69)
+[index.ts:80](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L80)
 
 ___
 
@@ -202,7 +228,7 @@ ___
 
 #### Defined in
 
-[index.ts:102](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L102)
+[index.ts:113](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L113)
 
 ___
 
@@ -222,4 +248,4 @@ ___
 
 #### Defined in
 
-[index.ts:106](https://github.com/mybigday/llama.rn/blob/8738c99/src/index.ts#L106)
+[index.ts:117](https://github.com/mybigday/llama.rn/blob/427a856/src/index.ts#L117)
diff --git a/docs/API/classes/SchemaGrammarConverter.md b/docs/API/classes/SchemaGrammarConverter.md
index e7062a51..9d559643 100644
--- a/docs/API/classes/SchemaGrammarConverter.md
+++ b/docs/API/classes/SchemaGrammarConverter.md
@@ -33,7 +33,7 @@
 
 #### Defined in
 
-[grammar.ts:39](https://github.com/mybigday/llama.rn/blob/8738c99/src/grammar.ts#L39)
+[grammar.ts:39](https://github.com/mybigday/llama.rn/blob/427a856/src/grammar.ts#L39)
 
 ## Properties
 
@@ -43,7 +43,7 @@
 
 #### Defined in
 
-[grammar.ts:35](https://github.com/mybigday/llama.rn/blob/8738c99/src/grammar.ts#L35)
+[grammar.ts:35](https://github.com/mybigday/llama.rn/blob/427a856/src/grammar.ts#L35)
 
 ___
 
@@ -53,7 +53,7 @@ ___
 
 #### Defined in
 
-[grammar.ts:37](https://github.com/mybigday/llama.rn/blob/8738c99/src/grammar.ts#L37)
+[grammar.ts:37](https://github.com/mybigday/llama.rn/blob/427a856/src/grammar.ts#L37)
 
 ## Methods
 
@@ -74,7 +74,7 @@ ___
 
 #### Defined in
 
-[grammar.ts:45](https://github.com/mybigday/llama.rn/blob/8738c99/src/grammar.ts#L45)
+[grammar.ts:45](https://github.com/mybigday/llama.rn/blob/427a856/src/grammar.ts#L45)
 
 ___
 
@@ -88,7 +88,7 @@ ___
 
 #### Defined in
 
-[grammar.ts:125](https://github.com/mybigday/llama.rn/blob/8738c99/src/grammar.ts#L125)
+[grammar.ts:125](https://github.com/mybigday/llama.rn/blob/427a856/src/grammar.ts#L125)
 
 ___
 
@@ -109,4 +109,4 @@ ___
 
 #### Defined in
 
-[grammar.ts:65](https://github.com/mybigday/llama.rn/blob/8738c99/src/grammar.ts#L65)
+[grammar.ts:65](https://github.com/mybigday/llama.rn/blob/427a856/src/grammar.ts#L65)
diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock
index 3758c715..a795c345 100644
--- a/example/ios/Podfile.lock
+++ b/example/ios/Podfile.lock
@@ -8,7 +8,7 @@ PODS:
     - hermes-engine/Pre-built (= 0.72.3)
   - hermes-engine/Pre-built (0.72.3)
   - libevent (2.1.12)
-  - llama-rn (0.3.0-rc.5):
+  - llama-rn (0.3.0-rc.7):
     - RCT-Folly
     - RCTRequired
     - RCTTypeSafety
@@ -1079,6 +1079,22 @@ PODS:
     - React-jsi (= 0.72.3)
     - React-logger (= 0.72.3)
     - React-perflogger (= 0.72.3)
+  - RNCClipboard (1.13.1):
+    - hermes-engine
+    - RCT-Folly (= 2021.07.22.00)
+    - RCTRequired
+    - RCTTypeSafety
+    - React-Codegen
+    - React-Core
+    - React-debug
+    - React-Fabric
+    - React-graphics
+    - React-NativeModulesApple
+    - React-RCTFabric
+    - React-utils
+    - ReactCommon/turbomodule/bridging
+    - ReactCommon/turbomodule/core
+    - Yoga
   - SocketRocket (0.6.1)
   - Yoga (1.14.0)
 
@@ -1131,6 +1147,7 @@ DEPENDENCIES:
   - React-runtimescheduler (from `../node_modules/react-native/ReactCommon/react/renderer/runtimescheduler`)
   - React-utils (from `../node_modules/react-native/ReactCommon/react/utils`)
   - ReactCommon/turbomodule/core (from `../node_modules/react-native/ReactCommon`)
+  - "RNCClipboard (from `../node_modules/@react-native-clipboard/clipboard`)"
   - Yoga (from `../node_modules/react-native/ReactCommon/yoga`)
 
 SPEC REPOS:
@@ -1231,6 +1248,8 @@ EXTERNAL SOURCES:
"../node_modules/react-native/ReactCommon/react/utils" ReactCommon: :path: "../node_modules/react-native/ReactCommon" + RNCClipboard: + :path: "../node_modules/@react-native-clipboard/clipboard" Yoga: :path: "../node_modules/react-native/ReactCommon/yoga" @@ -1242,7 +1261,7 @@ SPEC CHECKSUMS: glog: 04b94705f318337d7ead9e6d17c019bd9b1f6b1b hermes-engine: 10fbd3f62405c41ea07e71973ea61e1878d07322 libevent: 4049cae6c81cdb3654a443be001fb9bdceff7913 - llama-rn: 0a0f4d56e8c2ca348c77847cd18709330314042a + llama-rn: e2a9023b5e3d836bd5ce11000a89cf2d532fffc8 RCT-Folly: 424b8c9a7a0b9ab2886ffe9c3b041ef628fd4fb1 RCTRequired: a2faf4bad4e438ca37b2040cb8f7799baa065c18 RCTTypeSafety: cb09f3e4747b6d18331a15eb05271de7441ca0b3 @@ -1282,6 +1301,7 @@ SPEC CHECKSUMS: React-runtimescheduler: 837c1bebd2f84572db17698cd702ceaf585b0d9a React-utils: bcb57da67eec2711f8b353f6e3d33bd8e4b2efa3 ReactCommon: 3ccb8fb14e6b3277e38c73b0ff5e4a1b8db017a9 + RNCClipboard: 9e7ee48ef151b3536fb6f06aa83995d7afba5fb8 SocketRocket: f32cd54efbe0f095c4d7594881e52619cfe80b17 Yoga: 8796b55dba14d7004f980b54bcc9833ee45b28ce diff --git a/example/package.json b/example/package.json index ea0ccb10..731d770f 100644 --- a/example/package.json +++ b/example/package.json @@ -11,6 +11,7 @@ }, "dependencies": { "@flyerhq/react-native-chat-ui": "^1.4.3", + "@react-native-clipboard/clipboard": "^1.13.1", "react": "18.2.0", "react-native": "0.72.3", "react-native-blob-util": "^0.19.1", diff --git a/example/src/App.tsx b/example/src/App.tsx index d14eaa54..731de7d7 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -54,12 +54,7 @@ const renderBubble = ({ }: { child: ReactNode message: MessageType.Any -}) => ( - -) +}) => export default function App() { const [context, setContext] = useState(undefined) @@ -78,14 +73,14 @@ export default function App() { } } - const addSystemMessage = (text: string) => { + const addSystemMessage = (text: string, metadata = {} ) => { const textMessage: MessageType.Text = { author: system, createdAt: Date.now(), id: randId(), text, type: 'text', - metadata: { system: true }, + metadata: { system: true, ...metadata }, } addMessage(textMessage) } @@ -120,6 +115,7 @@ export default function App() { ctx.reasonNoGPU })\n\n` + 'You can use the following commands:\n\n' + + '- /bench: to benchmark the model\n' + '- /release: release the context\n' + '- /stop: stop the current completion\n' + '- /reset: reset the conversation', @@ -131,15 +127,20 @@ export default function App() { } const handlePickModel = async () => { - DocumentPicker.pick({type: Platform.OS === 'ios' ? 'public.data' : 'application/octet-stream'}) + DocumentPicker.pick({ + type: Platform.OS === 'ios' ? 
+      type: Platform.OS === 'ios' ? 'public.data' : 'application/octet-stream',
+    })
       .then(async (res) => {
         let [file] = res
         if (file) {
           if (Platform.OS === 'android' && file.uri.startsWith('content://')) {
             const dir = `${ReactNativeBlobUtil.fs.dirs.CacheDir}/models`
-            if (!(await ReactNativeBlobUtil.fs.isDir(dir))) await ReactNativeBlobUtil.fs.mkdir(dir)
+            if (!(await ReactNativeBlobUtil.fs.isDir(dir)))
+              await ReactNativeBlobUtil.fs.mkdir(dir)
 
-            const filepath = `${dir}/${file.uri.split('/').pop() || 'model'}.gguf`
+            const filepath = `${dir}/${
+              file.uri.split('/').pop() || 'model'
+            }.gguf`
             if (await ReactNativeBlobUtil.fs.exists(filepath)) {
               handleInitContext({ uri: filepath } as DocumentPickerResponse)
               return
@@ -163,6 +164,37 @@ export default function App() {
   const handleSendPress = async (message: MessageType.PartialText) => {
     if (context) {
       switch (message.text) {
+        case '/bench':
+          addSystemMessage('Heating up the model...')
+          const t0 = Date.now()
+          await context.bench(8, 4, 1, 1)
+          const tHeat = Date.now() - t0
+          if (tHeat > 1E4) {
+            addSystemMessage('Heat up time is too long, please try again.')
+            return
+          }
+          addSystemMessage(`Heat up time: ${tHeat}ms`)
+
+          addSystemMessage('Benchmarking the model...')
+          const {
+            modelDesc,
+            modelSize,
+            modelNParams,
+            ppAvg,
+            ppStd,
+            tgAvg,
+            tgStd,
+          } = await context.bench(512, 128, 1, 3)
+
+          const size = `${(modelSize / 1024.0 / 1024.0 / 1024.0).toFixed(2)} GiB`
+          const nParams = `${(modelNParams / 1e9).toFixed(2)}B`
+          const md =
+            '| model | size | params | test | t/s |\n' +
+            '| --- | --- | --- | --- | --- |\n' +
+            `| ${modelDesc} | ${size} | ${nParams} | pp 512 | ${ppAvg.toFixed(2)} ± ${ppStd.toFixed(2)} |\n` +
+            `| ${modelDesc} | ${size} | ${nParams} | tg 128 | ${tgAvg.toFixed(2)} ± ${tgStd.toFixed(2)}`
+          addSystemMessage(md, { copyable: true })
+          return
         case '/release':
           await handleReleaseContext()
           return
@@ -171,20 +203,25 @@ export default function App() {
         case '/stop':
           inferencing && context.stopCompletion()
           return
         case '/reset':
           conversationIdRef.current = randId()
-          addMessage({
-            author: system,
-            createdAt: Date.now(),
-            id: randId(),
-            text: 'Conversation reset!',
-            type: 'text',
-            metadata: { system: true },
-          })
+          addSystemMessage('Conversation reset!')
           return
         case '/save-session':
-          await context.saveSession(`${dirs.DocumentDir}/llama-session.bin`)
+          context.saveSession(`${dirs.DocumentDir}/llama-session.bin`).then(tokensSaved => {
+            console.log('Session tokens saved:', tokensSaved)
+            addSystemMessage(`Session saved! ${tokensSaved} tokens saved.`)
+          }).catch(e => {
+            console.log('Session save failed:', e)
+            addSystemMessage(`Session save failed: ${e.message}`)
+          })
           return
         case '/load-session':
-          console.log('Session loaded:', await context.loadSession(`${dirs.DocumentDir}/llama-session.bin`))
+          context.loadSession(`${dirs.DocumentDir}/llama-session.bin`).then(details => {
+            console.log('Session loaded:', details)
+            addSystemMessage(`Session loaded! ${details.tokens_loaded} tokens loaded.`)
+          }).catch(e => {
+            console.log('Session load failed:', e)
+            addSystemMessage(`Session load failed: ${e.message}`)
+          })
           return
       }
     }
diff --git a/example/src/Bubble.tsx b/example/src/Bubble.tsx
index 3f782e92..e4d8812c 100644
--- a/example/src/Bubble.tsx
+++ b/example/src/Bubble.tsx
@@ -1,6 +1,7 @@
 import React, { useContext } from 'react'
 import type { ReactNode } from 'react'
-import { View, Text } from 'react-native'
+import { View, Text, TouchableOpacity } from 'react-native'
+import Clipboard from '@react-native-clipboard/clipboard'
 import { ThemeContext, UserContext } from '@flyerhq/react-native-chat-ui'
 import type { MessageType } from '@flyerhq/react-native-chat-ui'
 
@@ -14,8 +15,11 @@ export const Bubble = ({
   const theme = useContext(ThemeContext)
   const user = useContext(UserContext)
   const currentUserIsAuthor = user?.id === message.author.id
+  const { copyable, timings } = message.metadata || {}
+
+  const Container: React.ComponentClass<any> = copyable ? TouchableOpacity : View
   return (
-    <View
+    <Container
+      onPress={() => {
+        if (message.type !== 'text') return
+        Clipboard.setString(message.text)
+      }}
     >
       {child}
-      {message.metadata?.timings && (
-          {message.metadata.timings}
+      {timings && (
+          {timings}
       )}
-    </View>
+    </Container>
   )
 }
diff --git a/example/yarn.lock b/example/yarn.lock
index 34fa42d1..204708e6 100644
--- a/example/yarn.lock
+++ b/example/yarn.lock
@@ -1333,6 +1333,11 @@
     "@nodelib/fs.scandir" "2.1.5"
     fastq "^1.6.0"
 
+"@react-native-clipboard/clipboard@^1.13.1":
+  version "1.13.1"
+  resolved "https://registry.yarnpkg.com/@react-native-clipboard/clipboard/-/clipboard-1.13.1.tgz#e313110aa487c510acb9f810637a41d1d0511857"
+  integrity sha512-sXWYgjPOK9lDLJQ2AebYY4t19UKh0JLNzX/KTnNulS9XL7Hd4mFZPwPfL4ysF6/SiGJaP6QlEFQbOZA+x4SIPg==
+
 "@react-native-community/cli-clean@11.3.5":
   version "11.3.5"
   resolved "https://registry.yarnpkg.com/@react-native-community/cli-clean/-/cli-clean-11.3.5.tgz#07c8a01e433ea6c6e32eb647908be48952888cdd"
diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm
index 8d98fe70..6b8f32b6 100644
--- a/ios/RNLlama.mm
+++ b/ios/RNLlama.mm
@@ -215,6 +215,27 @@ - (NSArray *)supportedEvents {
     }
 }
 
+RCT_EXPORT_METHOD(bench:(double)contextId
+                  pp:(int)pp
+                  tg:(int)tg
+                  pl:(int)pl
+                  nr:(int)nr
+                  withResolver:(RCTPromiseResolveBlock)resolve
+                  withRejecter:(RCTPromiseRejectBlock)reject)
+{
+    RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]];
+    if (context == nil) {
+        reject(@"llama_error", @"Context not found", nil);
+        return;
+    }
+    @try {
+        NSString *benchResults = [context bench:pp tg:tg pl:pl nr:nr];
+        resolve(benchResults);
+    } @catch (NSException *exception) {
+        reject(@"llama_cpp_error", exception.reason, nil);
+    }
+}
+
 RCT_EXPORT_METHOD(releaseContext:(double)contextId
                  withResolver:(RCTPromiseResolveBlock)resolve
                  withRejecter:(RCTPromiseRejectBlock)reject)
diff --git a/ios/RNLlamaContext.h b/ios/RNLlamaContext.h
index 88a77d34..aa6d85bc 100644
--- a/ios/RNLlamaContext.h
+++ b/ios/RNLlamaContext.h
@@ -24,6 +24,7 @@
 - (NSArray *)embedding:(NSString *)text;
 - (NSDictionary *)loadSession:(NSString *)path;
 - (int)saveSession:(NSString *)path size:(int)size;
+- (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr;
 
 - (void)invalidate;
diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index 55adae2c..37445952 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -373,6 +373,10 @@ - (int)saveSession:(NSString *)path size:(int)size {
     return session_tokens.size();
 }
 
+- (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr {
+    return [NSString stringWithUTF8String:llama->bench(pp, tg, pl, nr).c_str()];
+}
+
 - (void)invalidate {
     delete llama;
     // llama_backend_free();
diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts
index d1b1a75a..22017638 100644
--- a/src/NativeRNLlama.ts
+++ b/src/NativeRNLlama.ts
@@ -121,6 +121,8 @@ export interface Spec extends TurboModule {
   tokenize(contextId: number, text: string): Promise<NativeTokenizeResult>;
   detokenize(contextId: number, tokens: number[]): Promise<string>;
   embedding(contextId: number, text: string): Promise<NativeEmbeddingResult>;
+  bench(contextId: number, pp: number, tg: number, pl: number, nr: number): Promise<string>;
+
   releaseContext(contextId: number): Promise<void>;
 
   releaseAllContexts(): Promise<void>;
diff --git a/src/index.ts b/src/index.ts
index 3150fe1c..cc7843a1 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -40,6 +40,16 @@ export type ContextParams = NativeContextParams
 
 export type CompletionParams = Omit<NativeCompletionParams, 'emit_partial_completion'>
 
+export type BenchResult = {
+  modelDesc: string
+  modelSize: number
+  modelNParams: number
+  ppAvg: number
+  ppStd: number
+  tgAvg: number
+  tgStd: number
+}
+
 export class LlamaContext {
   id: number
@@ -116,6 +126,28 @@ export class LlamaContext {
     return RNLlama.embedding(this.id, text)
   }
 
+  async bench(pp: number, tg: number, pl: number, nr: number): Promise<BenchResult> {
+    const result = await RNLlama.bench(this.id, pp, tg, pl, nr)
+    const [
+      modelDesc,
+      modelSize,
+      modelNParams,
+      ppAvg,
+      ppStd,
+      tgAvg,
+      tgStd,
+    ] = JSON.parse(result)
+    return {
+      modelDesc,
+      modelSize,
+      modelNParams,
+      ppAvg,
+      ppStd,
+      tgAvg,
+      tgStd,
+    }
+  }
+
   async release(): Promise<void> {
     return RNLlama.releaseContext(this.id)
   }
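For reference, the new API can be exercised from application code like this — a minimal sketch, assuming a context created with `initLlama` (the model path is a placeholder). `pp` is the prompt length per pass, `tg` the number of generated tokens, `pl` the number of parallel sequences, and `nr` the number of repetitions; the native side returns the JSON array that `bench()` parses into a `BenchResult`:

```ts
import { initLlama } from 'llama.rn'

async function runBench() {
  // Placeholder path; point this at a real GGUF model on the device.
  const context = await initLlama({ model: '/path/to/model.gguf' })

  // Same parameters the example app's /bench command uses: pp 512, tg 128, pl 1, nr 3.
  const { modelDesc, ppAvg, ppStd, tgAvg, tgStd } = await context.bench(512, 128, 1, 3)

  console.log(
    `${modelDesc}: pp 512 = ${ppAvg.toFixed(2)} ± ${ppStd.toFixed(2)} t/s, ` +
      `tg 128 = ${tgAvg.toFixed(2)} ± ${tgStd.toFixed(2)} t/s`,
  )

  await context.release()
}
```

The `ppStd`/`tgStd` values are sample standard deviations over the `nr` runs, computed in `rn-llama.hpp` from the accumulated sums of squares, so a single run (`nr = 1`) reports 0.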