diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index 2b36d52..91b2397 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -50,6 +50,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
       // int n_batch,
       params.hasKey("n_batch") ? params.getInt("n_batch") : 512,
+      // int n_ubatch,
+      params.hasKey("n_ubatch") ? params.getInt("n_ubatch") : 512,
       // int n_threads,
       params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
       // int n_gpu_layers, // TODO: Support this
@@ -412,6 +414,7 @@ protected static native long initContext(
     int embd_normalize,
     int n_ctx,
     int n_batch,
+    int n_ubatch,
     int n_threads,
     int n_gpu_layers, // TODO: Support this
     boolean flash_attn,
diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index 8640aa7..2171fd2 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -226,6 +226,7 @@ Java_com_rnllama_LlamaContext_initContext(
     jint embd_normalize,
     jint n_ctx,
     jint n_batch,
+    jint n_ubatch,
     jint n_threads,
     jint n_gpu_layers, // TODO: Support this
     jboolean flash_attn,
@@ -256,6 +257,7 @@ Java_com_rnllama_LlamaContext_initContext(
 
     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
+    defaultParams.n_ubatch = n_ubatch;
 
     if (pooling_type != -1) {
         defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp
index 9e7b4b4..a85d514 100644
--- a/cpp/rn-llama.hpp
+++ b/cpp/rn-llama.hpp
@@ -9,6 +9,9 @@
 #include "llama.h"
 #include "llama-impl.h"
 #include "sampling.h"
+#if defined(__ANDROID__)
+#include <android/log.h>
+#endif
 
 namespace rnllama {
 
@@ -53,16 +56,32 @@ static void llama_batch_add(llama_batch *batch, llama_token id, llama_pos pos, s
 
 static void log(const char *level, const char *function, int line, const char *format, ...)
 {
-    printf("[%s] %s:%d ", level, function, line);
-
     va_list args;
-    va_start(args, format);
-    vprintf(format, args);
-    va_end(args);
-
-    printf("\n");
+    #if defined(__ANDROID__)
+    char prefix[256];
+    snprintf(prefix, sizeof(prefix), "%s:%d %s", function, line, format);
+
+    va_start(args, format);
+    android_LogPriority priority;
+    if (strcmp(level, "ERROR") == 0) {
+        priority = ANDROID_LOG_ERROR;
+    } else if (strcmp(level, "WARNING") == 0) {
+        priority = ANDROID_LOG_WARN;
+    } else if (strcmp(level, "INFO") == 0) {
+        priority = ANDROID_LOG_INFO;
+    } else {
+        priority = ANDROID_LOG_DEBUG;
+    }
+    __android_log_vprint(priority, "RNLlama", prefix, args);
+    va_end(args);
+    #else
+    printf("[%s] %s:%d ", level, function, line);
+    va_start(args, format);
+    vprintf(format, args);
+    va_end(args);
+    printf("\n");
+    #endif
 }
-
 static bool rnllama_verbose = false;
 
 #if RNLLAMA_VERBOSE != 1
@@ -250,6 +269,10 @@ struct llama_rn_context
             return false;
         }
         n_ctx = llama_n_ctx(ctx);
+
+        // We can uncomment for debugging or after this fix: https://github.com/ggerganov/llama.cpp/pull/11101
+        // LOG_INFO("%s\n", common_params_get_system_info(params).c_str());
+
         return true;
     }
 
@@ -592,7 +615,11 @@ struct llama_rn_context
         double tg_std = 0;
 
         // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
-        llama_batch batch = llama_batch_init(512, 0, 1);
+        llama_batch batch = llama_batch_init(
+            std::min(pp, params.n_ubatch), // max n_tokens is limited by n_ubatch
+            0, // No embeddings
+            1  // Single sequence
+        );
 
         for (int i = 0; i < nr; i++)
         {
diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local
index 40c3d36..2e6e88e 100644
--- a/example/ios/.xcode.env.local
+++ b/example/ios/.xcode.env.local
@@ -1 +1 @@
-export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1736391556824-0.01101787861122494/node
+export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1736393474904-0.18689251702594478/node
diff --git a/example/ios/RNLlamaExample.xcodeproj/project.pbxproj b/example/ios/RNLlamaExample.xcodeproj/project.pbxproj
index 5618edc..288db84 100644
--- a/example/ios/RNLlamaExample.xcodeproj/project.pbxproj
+++ b/example/ios/RNLlamaExample.xcodeproj/project.pbxproj
@@ -596,7 +596,7 @@
 				GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
 				GCC_WARN_UNUSED_FUNCTION = YES;
 				GCC_WARN_UNUSED_VARIABLE = YES;
-				IPHONEOS_DEPLOYMENT_TARGET = 12.4;
+				IPHONEOS_DEPLOYMENT_TARGET = 13.0;
 				LD_RUNPATH_SEARCH_PATHS = (
 					/usr/lib/swift,
 					"$(inherited)",
@@ -666,7 +666,7 @@
 				GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
 				GCC_WARN_UNUSED_FUNCTION = YES;
 				GCC_WARN_UNUSED_VARIABLE = YES;
-				IPHONEOS_DEPLOYMENT_TARGET = 12.4;
+				IPHONEOS_DEPLOYMENT_TARGET = 13.0;
 				LD_RUNPATH_SEARCH_PATHS = (
 					/usr/lib/swift,
 					"$(inherited)",
diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index 7b89ea8..4891233 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -94,6 +94,7 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
 #endif
     }
     if (params[@"n_batch"]) defaultParams.n_batch = [params[@"n_batch"] intValue];
+    if (params[@"n_ubatch"]) defaultParams.n_ubatch = [params[@"n_ubatch"] intValue];
     if (params[@"use_mmap"]) defaultParams.use_mmap = [params[@"use_mmap"] boolValue];
 
     if (params[@"pooling_type"] && [params[@"pooling_type"] isKindOfClass:[NSNumber class]]) {
diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts
index 92d8458..a652a41 100644
--- a/src/NativeRNLlama.ts
+++ b/src/NativeRNLlama.ts
@@ -12,6 +12,7 @@ export type NativeContextParams = {
   n_ctx?: number
   n_batch?: number
+  n_ubatch?: number
   n_threads?: number
   n_gpu_layers?: number
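
Note: below is a minimal sketch of how the new `n_ubatch` option could be passed from JavaScript once this change lands. It assumes the public `initLlama` entry point forwards `NativeContextParams` unchanged; the model path and the parameter values are placeholders for illustration, not part of this diff.

```ts
import { initLlama } from 'llama.rn'

// Hypothetical values for illustration; on Android, n_ubatch falls back to 512
// when omitted (see the LlamaContext.java default above).
async function createContext() {
  return await initLlama({
    model: '/path/to/model.gguf', // placeholder path
    n_ctx: 2048,
    n_batch: 512,
    n_ubatch: 256, // new: micro-batch size forwarded to llama.cpp's n_ubatch
    n_threads: 4,
  })
}
```

In llama.cpp the micro-batch size is effectively capped at `n_batch`, so values with `n_ubatch <= n_batch` are the useful range; the benchmark path in rn-llama.hpp above now sizes its batch from `std::min(pp, params.n_ubatch)` instead of a hard-coded 512 for the same reason.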