feat(ios): add progress callback in initLlama
jhen0409 committed Nov 4, 2024
1 parent 192c9ae commit 48cea0a
Showing 12 changed files with 128 additions and 37 deletions.
3 changes: 3 additions & 0 deletions cpp/common.cpp
@@ -1001,6 +1001,9 @@ struct llama_model_params common_model_params_to_llama(const common_params & par
mparams.kv_overrides = params.kv_overrides.data();
}

mparams.progress_callback = params.progress_callback;
mparams.progress_callback_user_data = params.progress_callback_user_data;

return mparams;
}

3 changes: 3 additions & 0 deletions cpp/common.h
@@ -283,6 +283,9 @@ struct common_params {
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data

llama_progress_callback progress_callback;
void * progress_callback_user_data;

std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V

19 changes: 11 additions & 8 deletions cpp/rn-llama.hpp
@@ -158,9 +158,12 @@ struct llama_rn_context
common_params params;

llama_model *model = nullptr;
float loading_progress = 0;
bool is_load_interrupted = false;

llama_context *ctx = nullptr;
common_sampler *ctx_sampling = nullptr;

int n_ctx;

bool truncated = false;
@@ -367,7 +370,7 @@ struct llama_rn_context
n_eval = params.n_batch;
}
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval)))
{
{
LOG_ERROR("failed to eval, n_eval: %d, n_past: %d, n_threads: %d, embd: %s",
n_eval,
n_past,
@@ -378,7 +381,7 @@ struct llama_rn_context
return result;
}
n_past += n_eval;

if(is_interrupted) {
LOG_INFO("Decoding Interrupted");
embd.resize(n_past);
@@ -400,19 +403,19 @@ struct llama_rn_context
candidates.reserve(llama_n_vocab(model));

result.tok = common_sampler_sample(ctx_sampling, ctx, -1);

llama_token_data_array cur_p = *common_sampler_get_candidates(ctx_sampling);

const int32_t n_probs = params.sparams.n_probs;

// deprecated
/*if (params.sparams.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sampler_init_softmax();
}*/


for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
{
@@ -542,14 +545,14 @@ struct llama_rn_context
return std::vector<float>(n_embd, 0.0f);
}
float *data;

if(params.pooling_type == 0){
data = llama_get_embeddings(ctx);
}
else {
data = llama_get_embeddings_seq(ctx, 0);
}

if(!data) {
return std::vector<float>(n_embd, 0.0f);
}
2 changes: 1 addition & 1 deletion example/ios/.xcode.env.local
@@ -1 +1 @@
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1730514789911-0.16979892623603998/node
export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1730697817603-0.6786179339916347/node
19 changes: 18 additions & 1 deletion example/src/App.tsx
@@ -64,6 +64,7 @@ export default function App() {
metadata: { system: true, ...metadata },
}
addMessage(textMessage)
return textMessage.id
}

const handleReleaseContext = async () => {
@@ -82,12 +83,28 @@

const handleInitContext = async (file: DocumentPickerResponse) => {
await handleReleaseContext()
addSystemMessage('Initializing context...')
const msgId = addSystemMessage('Initializing context...')
initLlama({
model: file.uri,
use_mlock: true,
n_gpu_layers: Platform.OS === 'ios' ? 0 : 0, // > 0: enable GPU
// embedding: true,
}, (progress) => {
setMessages((msgs) => {
const index = msgs.findIndex((msg) => msg.id === msgId)
if (index >= 0) {
return msgs.map((msg, i) => {
if (msg.type == 'text' && i === index) {
return {
...msg,
text: `Initializing context... ${progress}%`,
}
}
return msg
})
}
return msgs
})
})
.then((ctx) => {
setContext(ctx)
23 changes: 17 additions & 6 deletions ios/RNLlama.mm
@@ -21,10 +21,17 @@ @implementation RNLlama
resolve(nil);
}

RCT_EXPORT_METHOD(initContext:(NSDictionary *)contextParams
RCT_EXPORT_METHOD(initContext:(double)contextId
withContextParams:(NSDictionary *)contextParams
withResolver:(RCTPromiseResolveBlock)resolve
withRejecter:(RCTPromiseRejectBlock)reject)
{
NSNumber *contextIdNumber = [NSNumber numberWithDouble:contextId];
if (llamaContexts[contextIdNumber] != nil) {
reject(@"llama_error", @"Context already exists", nil);
return;
}

if (llamaDQueue == nil) {
llamaDQueue = dispatch_queue_create("com.rnllama", DISPATCH_QUEUE_SERIAL);
}
@@ -38,19 +45,19 @@ @implementation RNLlama
return;
}

RNLlamaContext *context = [RNLlamaContext initWithParams:contextParams];
RNLlamaContext *context = [RNLlamaContext initWithParams:contextParams onProgress:^(unsigned int progress) {
dispatch_async(dispatch_get_main_queue(), ^{
[self sendEventWithName:@"@RNLlama_onInitContextProgress" body:@{ @"contextId": @(contextId), @"progress": @(progress) }];
});
}];
if (![context isModelLoaded]) {
reject(@"llama_cpp_error", @"Failed to load the model", nil);
return;
}

double contextId = (double) arc4random_uniform(1000000);

NSNumber *contextIdNumber = [NSNumber numberWithDouble:contextId];
[llamaContexts setObject:context forKey:contextIdNumber];

resolve(@{
@"contextId": contextIdNumber,
@"gpu": @([context isMetalEnabled]),
@"reasonNoGPU": [context reasonNoMetal],
@"model": [context modelInfo],
@@ -125,6 +132,7 @@ @implementation RNLlama

- (NSArray *)supportedEvents {
return@[
@"@RNLlama_onInitContextProgress",
@"@RNLlama_onToken",
];
}
@@ -260,6 +268,9 @@ - (NSArray *)supportedEvents {
reject(@"llama_error", @"Context not found", nil);
return;
}
if (![context isModelLoaded]) {
[context interruptLoad];
}
[context stopCompletion];
dispatch_barrier_sync(llamaDQueue, ^{});
[context invalidate];
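The `@RNLlama_onInitContextProgress` events emitted above are matched back to the originating call on the JS side via the `contextId`, which is now generated in JavaScript and passed into `initContext` (see `src/index.ts` further down). A minimal TypeScript sketch of that subscription, simplified from the library's actual emitter setup:

import { NativeEventEmitter, NativeModules } from 'react-native'

// Simplified sketch: the real listener lives in src/index.ts.
const emitter = new NativeEventEmitter(NativeModules.RNLlama)
const contextId = Math.floor(Math.random() * 1000000) // generated before calling initContext

const sub = emitter.addListener(
  '@RNLlama_onInitContextProgress',
  (evt: { contextId: number; progress: number }) => {
    if (evt.contextId !== contextId) return // event belongs to another context
    console.log(`model loading: ${evt.progress}%`)
  },
)

// Once initContext(contextId, { ..., use_progress_callback: true }) settles:
sub.remove()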
7 changes: 5 additions & 2 deletions ios/RNLlamaContext.h
@@ -6,13 +6,16 @@

@interface RNLlamaContext : NSObject {
bool is_metal_enabled;
NSString * reason_no_metal;
bool is_model_loaded;
NSString * reason_no_metal;

void (^onProgress)(unsigned int progress);

rnllama::llama_rn_context * llama;
}

+ (instancetype)initWithParams:(NSDictionary *)params;
+ (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsigned int progress))onProgress;
- (void)interruptLoad;
- (bool)isMetalEnabled;
- (NSString *)reasonNoMetal;
- (NSDictionary *)modelInfo;
21 changes: 18 additions & 3 deletions ios/RNLlamaContext.mm
@@ -3,7 +3,7 @@

@implementation RNLlamaContext

+ (instancetype)initWithParams:(NSDictionary *)params {
+ (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsigned int progress))onProgress {
// llama_backend_init(false);
common_params defaultParams;

@@ -78,9 +78,24 @@ + (instancetype)initWithParams:(NSDictionary *)params {
defaultParams.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;

RNLlamaContext *context = [[RNLlamaContext alloc] init];
if (context->llama == nullptr) {
context->llama = new rnllama::llama_rn_context();
context->llama = new rnllama::llama_rn_context();
context->llama->is_load_interrupted = false;
context->llama->loading_progress = 0;
context->onProgress = onProgress;

if (params[@"use_progress_callback"] && [params[@"use_progress_callback"] boolValue]) {
defaultParams.progress_callback = [](float progress, void * user_data) {
RNLlamaContext *context = (__bridge RNLlamaContext *)(user_data);
unsigned percentage = (unsigned) (100 * progress);
if (percentage > context->llama->loading_progress) {
context->llama->loading_progress = percentage;
context->onProgress(percentage);
}
return !context->llama->is_load_interrupted;
};
defaultParams.progress_callback_user_data = context;
}

context->is_model_loaded = context->llama->loadModel(defaultParams);
context->is_metal_enabled = isMetalEnabled;
context->reason_no_metal = reasonNoMetal;
22 changes: 16 additions & 6 deletions scripts/common.cpp.patch
@@ -1,18 +1,18 @@
--- common.cpp.orig 2024-11-02 10:33:10
+++ common.cpp 2024-11-02 10:33:11
@@ -53,6 +53,12 @@
#include <curl/easy.h>
--- common.cpp.orig 2024-11-04 12:59:08
+++ common.cpp 2024-11-04 12:58:17
@@ -54,6 +54,12 @@
#include <future>
#endif
+

+// build info
+int LLAMA_BUILD_NUMBER = 0;
+char const *LLAMA_COMMIT = "unknown";
+char const *LLAMA_COMPILER = "unknown";
+char const *LLAMA_BUILD_TARGET = "unknown";

+
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
@@ -979,6 +985,8 @@
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
@@ -22,3 +22,13 @@
mparams.rpc_servers = params.rpc_servers.c_str();
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
@@ -993,6 +1001,9 @@
mparams.kv_overrides = params.kv_overrides.data();
}

+ mparams.progress_callback = params.progress_callback;
+ mparams.progress_callback_user_data = params.progress_callback_user_data;
+
return mparams;
}

22 changes: 16 additions & 6 deletions scripts/common.h.patch
@@ -1,10 +1,9 @@
--- common.h.orig 2024-11-02 10:33:10
+++ common.h 2024-11-02 10:33:11
@@ -40,6 +40,17 @@
extern char const * LLAMA_BUILD_TARGET;
--- common.h.orig 2024-11-04 12:59:08
+++ common.h 2024-11-04 12:58:24
@@ -41,6 +41,17 @@

struct common_control_vector_load_info;
+

+#define print_build_info() do { \
+ fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \
+ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
@@ -15,9 +14,10 @@
+extern char const *LLAMA_COMMIT;
+extern char const *LLAMA_COMPILER;
+extern char const *LLAMA_BUILD_TARGET;

+
//
// CPU utils
//
@@ -154,6 +165,7 @@
};

@@ -26,3 +26,13 @@
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 0; // context size
int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
@@ -271,6 +283,9 @@
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data

+ llama_progress_callback progress_callback;
+ void * progress_callback_user_data;
+
std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V

3 changes: 2 additions & 1 deletion src/NativeRNLlama.ts
@@ -4,6 +4,7 @@ import { TurboModuleRegistry } from 'react-native'
export type NativeContextParams = {
model: string
is_model_asset?: boolean
use_progress_callback?: boolean

embedding?: boolean

@@ -119,7 +120,7 @@ export type NativeLlamaChatMessage = {

export interface Spec extends TurboModule {
setContextLimit(limit: number): Promise<void>
initContext(params: NativeContextParams): Promise<NativeLlamaContext>
initContext(contextId: number, params: NativeContextParams): Promise<NativeLlamaContext>

getFormattedChat(
contextId: number,
21 changes: 18 additions & 3 deletions src/index.ts
@@ -17,6 +17,7 @@ import { formatChat } from './chat'

export { SchemaGrammarConverter, convertJsonSchemaToGrammar }

const EVENT_ON_INIT_CONTEXT_PROGRESS = '@RNLlama_onInitContextProgress'
const EVENT_ON_TOKEN = '@RNLlama_onToken'

let EventEmitter: NativeEventEmitter | DeviceEventEmitterStatic
@@ -192,19 +193,33 @@ export async function initLlama({
model,
is_model_asset: isModelAsset,
...rest
}: ContextParams): Promise<LlamaContext> {
}: ContextParams, onProgress?: (progress: number) => void): Promise<LlamaContext> {
let path = model
if (path.startsWith('file://')) path = path.slice(7)
const contextId = Math.floor(Math.random() * 1000000)

let removeProgressListener: any = null
if (onProgress) {
removeProgressListener = EventEmitter.addListener(EVENT_ON_INIT_CONTEXT_PROGRESS, (evt: { contextId: number, progress: number }) => {
if (evt.contextId !== contextId) return
onProgress(evt.progress)
})
}

const {
contextId,
gpu,
reasonNoGPU,
model: modelDetails,
} = await RNLlama.initContext({
} = await RNLlama.initContext(contextId, {
model: path,
is_model_asset: !!isModelAsset,
use_progress_callback: !!onProgress,
...rest,
}).catch((err: any) => {
removeProgressListener?.remove()
throw err
})
removeProgressListener?.remove()
return new LlamaContext({ contextId, gpu, reasonNoGPU, model: modelDetails })
}

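From the app's perspective, the optional second argument to `initLlama` is the only visible API change. A minimal usage sketch (the model path is a placeholder, and logging stands in for real UI updates as in the example app):

import { initLlama } from 'llama.rn'

async function loadModel() {
  const context = await initLlama(
    {
      model: 'file:///path/to/model.gguf', // placeholder path
      use_mlock: true,
      n_gpu_layers: 0, // > 0 enables GPU offload where supported
    },
    (progress) => {
      // progress is an integer percentage (0-100) reported while the model loads
      console.log(`Initializing context... ${progress}%`)
    },
  )
  return context
}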
