diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index 2b36d52..91b2397 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -50,6 +50,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
       // int n_batch,
       params.hasKey("n_batch") ? params.getInt("n_batch") : 512,
+      // int n_ubatch,
+      params.hasKey("n_ubatch") ? params.getInt("n_ubatch") : 512,
       // int n_threads,
       params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
       // int n_gpu_layers, // TODO: Support this
@@ -412,6 +414,7 @@ protected static native long initContext(
     int embd_normalize,
     int n_ctx,
     int n_batch,
+    int n_ubatch,
     int n_threads,
     int n_gpu_layers, // TODO: Support this
     boolean flash_attn,
diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index 774b44d..903e2b2 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -226,6 +226,7 @@ Java_com_rnllama_LlamaContext_initContext(
     jint embd_normalize,
     jint n_ctx,
     jint n_batch,
+    jint n_ubatch,
     jint n_threads,
     jint n_gpu_layers, // TODO: Support this
     jboolean flash_attn,
@@ -256,6 +257,7 @@ Java_com_rnllama_LlamaContext_initContext(

     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
+    defaultParams.n_ubatch = n_ubatch;

     if (pooling_type != -1) {
         defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp
index 9a89a08..aad1d5c 100644
--- a/cpp/rn-llama.hpp
+++ b/cpp/rn-llama.hpp
@@ -683,7 +683,11 @@ struct llama_rn_context
         double tg_std = 0;

         // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
-        llama_batch batch = llama_batch_init(512, 0, 1);
+        llama_batch batch = llama_batch_init(
+            std::min(pp, params.n_ubatch), // max n_tokens is limited by n_ubatch
+            0, // No embeddings
+            1 // Single sequence
+        );

         for (int i = 0; i < nr; i++)
         {
diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index 9d08af4..1ba44ff 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -94,6 +94,7 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
 #endif
     }
     if (params[@"n_batch"]) defaultParams.n_batch = [params[@"n_batch"] intValue];
+    if (params[@"n_ubatch"]) defaultParams.n_ubatch = [params[@"n_ubatch"] intValue];
     if (params[@"use_mmap"]) defaultParams.use_mmap = [params[@"use_mmap"] boolValue];

     if (params[@"pooling_type"] && [params[@"pooling_type"] isKindOfClass:[NSNumber class]]) {
diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts
index 92d8458..a652a41 100644
--- a/src/NativeRNLlama.ts
+++ b/src/NativeRNLlama.ts
@@ -12,6 +12,7 @@ export type NativeContextParams = {

   n_ctx?: number
   n_batch?: number
+  n_ubatch?: number
   n_threads?: number
   n_gpu_layers?: number
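
Taken together, these changes plumb a new `n_ubatch` (physical micro-batch size) context parameter from the JS API through the Android (Java/JNI) and iOS bindings into `llama_context_params`, and size the benchmark batch as `min(pp, n_ubatch)` instead of a hard-coded 512. Below is a minimal usage sketch from the JS side; it assumes the public `initLlama` wrapper and `context.bench` exposed by llama.rn, that the wrapper's `ContextParams` forwards `n_ubatch` to `NativeContextParams` unchanged, and a placeholder model path:

```ts
import { initLlama } from 'llama.rn'

async function run() {
  const context = await initLlama({
    model: 'file:///path/to/model.gguf', // placeholder path
    n_ctx: 2048,   // context window size
    n_batch: 512,  // logical batch: max tokens submitted per decode call
    n_ubatch: 256, // physical micro-batch; llama.cpp clamps this to n_batch
  })

  // bench(pp, tg, pl, nr): with this patch the benchmark batch is allocated
  // with min(pp, n_ubatch) token slots instead of a fixed 512.
  console.log(await context.bench(512, 128, 1, 3))
}
```

A smaller `n_ubatch` bounds how many tokens are physically processed per step (and thus peak compute-buffer memory), which is why the benchmark's batch allocation can be capped at `n_ubatch` rather than always reserving 512 slots.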