feat: expose n_ubatch and dynamically adjust ntokens for bench #104

Merged (2 commits) on Jan 9, 2025
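
What this enables on the JavaScript side: a minimal usage sketch, assuming llama.rn's initLlama entry point (the model path and parameter values are illustrative):

    import { initLlama } from 'llama.rn'

    const context = await initLlama({
      model: '/path/to/model.gguf', // illustrative path
      n_ctx: 2048,
      n_batch: 512,
      n_ubatch: 512, // newly exposed; the Android bridge falls back to 512 when omitted
    })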
android/src/main/java/com/rnllama/LlamaContext.java (3 additions, 0 deletions)

@@ -50,6 +50,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
       // int n_batch,
       params.hasKey("n_batch") ? params.getInt("n_batch") : 512,
+      // int n_ubatch,
+      params.hasKey("n_ubatch") ? params.getInt("n_ubatch") : 512,
       // int n_threads,
       params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
       // int n_gpu_layers, // TODO: Support this
@@ -412,6 +414,7 @@ protected static native long initContext(
     int embd_normalize,
     int n_ctx,
     int n_batch,
+    int n_ubatch,
     int n_threads,
     int n_gpu_layers, // TODO: Support this
     boolean flash_attn,
android/src/main/jni.cpp (2 additions, 0 deletions)

@@ -226,6 +226,7 @@ Java_com_rnllama_LlamaContext_initContext(
     jint embd_normalize,
     jint n_ctx,
     jint n_batch,
+    jint n_ubatch,
     jint n_threads,
     jint n_gpu_layers, // TODO: Support this
     jboolean flash_attn,
@@ -256,6 +257,7 @@ Java_com_rnllama_LlamaContext_initContext(

     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
+    defaultParams.n_ubatch = n_ubatch;

     if (pooling_type != -1) {
         defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);
cpp/rn-llama.hpp (5 additions, 1 deletion)

@@ -660,7 +660,11 @@ struct llama_rn_context
     double tg_std = 0;

     // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
-    llama_batch batch = llama_batch_init(512, 0, 1);
+    llama_batch batch = llama_batch_init(
+        std::min(pp, params.n_ubatch), // max n_tokens is limited by n_ubatch
+        0, // No embeddings
+        1  // Single sequence
+    );

     for (int i = 0; i < nr; i++)
     {
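
Previously the bench batch was a fixed llama_batch_init(512, 0, 1) regardless of the requested prompt-processing size pp; it is now sized to std::min(pp, params.n_ubatch), since llama.cpp processes at most n_ubatch tokens per micro-batch. A minimal sketch of driving this from JavaScript, assuming llama.rn's context.bench(pp, tg, pl, nr) method (the numbers are illustrative):

    // Benchmark: pp = 256 prompt tokens, tg = 64 generated tokens,
    // pl = 1 parallel sequence, nr = 3 repetitions. With n_ubatch = 512,
    // the bench batch is allocated for min(256, 512) = 256 tokens
    // instead of a fixed 512.
    const result = await context.bench(256, 64, 1, 3)
    console.log(result) // prompt-processing and token-generation stats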
ios/RNLlamaContext.mm (1 addition, 0 deletions)

@@ -94,6 +94,7 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
 #endif
     }
     if (params[@"n_batch"]) defaultParams.n_batch = [params[@"n_batch"] intValue];
+    if (params[@"n_ubatch"]) defaultParams.n_ubatch = [params[@"n_ubatch"] intValue];
     if (params[@"use_mmap"]) defaultParams.use_mmap = [params[@"use_mmap"] boolValue];

     if (params[@"pooling_type"] && [params[@"pooling_type"] isKindOfClass:[NSNumber class]]) {
src/NativeRNLlama.ts (1 addition, 0 deletions)

@@ -12,6 +12,7 @@ export type NativeContextParams = {

   n_ctx?: number
   n_batch?: number
+  n_ubatch?: number

   n_threads?: number
   n_gpu_layers?: number
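
For completeness, a sketch of a typed params object using the new field (assuming the NativeContextParams shape above; the model field's value and the import path are illustrative). One caveat, hedged because it lives upstream rather than in this diff: llama.cpp caps the effective micro-batch at n_batch, so an n_ubatch larger than n_batch gains nothing.

    import type { NativeContextParams } from './NativeRNLlama'

    const params: NativeContextParams = {
      model: '/path/to/model.gguf', // illustrative path
      n_ctx: 4096,
      n_batch: 512,
      n_ubatch: 256, // values above n_batch are capped by llama.cpp upstream
    }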