Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: style transfer with openCV(ios) #48

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/computer-vision/screens/StyleTransferScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ export const StyleTransferScreen = ({
}
};

if (model.isModelLoading) {
if (!model.isModelReady) {
return (
<Spinner
visible={model.isModelLoading}
visible={!model.isModelReady}
textContent={`Loading the model...`}
/>
);
Expand Down
Binary file not shown.
Binary file not shown.
19 changes: 19 additions & 0 deletions ios/RnExecutorch.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
objectVersion = 77;
objects = {

/* Begin PBXBuildFile section */
55D6EA8C2D0987D2009BA408 /* ExecutorchLib.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 55D6EA8B2D0987D2009BA408 /* ExecutorchLib.xcframework */; };
55D6EA8E2D0987DF009BA408 /* opencv2.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 55D6EA8D2D0987DF009BA408 /* opencv2.xcframework */; };
/* End PBXBuildFile section */

/* Begin PBXCopyFilesBuildPhase section */
550986872CEF541900FECBB8 /* CopyFiles */ = {
isa = PBXCopyFilesBuildPhase;
Expand All @@ -20,6 +25,8 @@

/* Begin PBXFileReference section */
550986892CEF541900FECBB8 /* libRnExecutorch.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = libRnExecutorch.a; sourceTree = BUILT_PRODUCTS_DIR; };
55D6EA8B2D0987D2009BA408 /* ExecutorchLib.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; path = ExecutorchLib.xcframework; sourceTree = "<group>"; };
55D6EA8D2D0987DF009BA408 /* opencv2.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = opencv2.xcframework; path = ../../../opencv2.xcframework; sourceTree = "<group>"; };
/* End PBXFileReference section */

/* Begin PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet section */
Expand Down Expand Up @@ -48,6 +55,8 @@
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
55D6EA8E2D0987DF009BA408 /* opencv2.xcframework in Frameworks */,
55D6EA8C2D0987D2009BA408 /* ExecutorchLib.xcframework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
Expand All @@ -58,6 +67,7 @@
isa = PBXGroup;
children = (
5509868B2CEF541900FECBB8 /* RnExecutorch */,
55D6EA8A2D0987D2009BA408 /* Frameworks */,
5509868A2CEF541900FECBB8 /* Products */,
);
sourceTree = "<group>";
Expand All @@ -70,6 +80,15 @@
name = Products;
sourceTree = "<group>";
};
55D6EA8A2D0987D2009BA408 /* Frameworks */ = {
isa = PBXGroup;
children = (
55D6EA8D2D0987DF009BA408 /* opencv2.xcframework */,
55D6EA8B2D0987D2009BA408 /* ExecutorchLib.xcframework */,
);
name = Frameworks;
sourceTree = "<group>";
};
/* End PBXGroup section */

/* Begin PBXNativeTarget section */
Expand Down
46 changes: 23 additions & 23 deletions ios/RnExecutorch/ETModule.mm
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,30 @@ - (void)loadModule:(NSString *)modelSource
if (!module) {
module = [[ETModel alloc] init];
}

[Fetcher fetchResource:[NSURL URLWithString:modelSource]
resourceType:ResourceType::MODEL
completionHandler:^(NSString *filePath, NSError *error) {
if (error) {
reject(@"init_module_error", @"-1", nil);
return;
}

NSNumber *result = [self->module loadModel:filePath];
if ([result isEqualToNumber:@(0)]) {
resolve(result);
} else {
NSError *error = [NSError
errorWithDomain:@"ETModuleErrorDomain"
code:[result intValue]
userInfo:@{
NSLocalizedDescriptionKey : [NSString
stringWithFormat:@"%ld", (long)[result longValue]]
}];

reject(@"init_module_error", error.localizedDescription, error);
}
}];
if (error) {
reject(@"init_module_error", @"-1", nil);
return;
}
NSNumber *result = [self->module loadModel:filePath];
if ([result isEqualToNumber:@(0)]) {
resolve(result);
} else {
NSError *error = [NSError
errorWithDomain:@"ETModuleErrorDomain"
code:[result intValue]
userInfo:@{
NSLocalizedDescriptionKey : [NSString
stringWithFormat:@"%ld", (long)[result longValue]]
}];
reject(@"init_module_error", error.localizedDescription, error);
}
}];
}

- (void)forward:(NSArray *)input
Expand All @@ -48,7 +48,7 @@ - (void)forward:(NSArray *)input
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
@try {
NSArray *result = [module forward:input shape:shape inputType:[NSNumber numberWithInt:inputType]];
NSArray *result = [module forward:input shape:shape inputType:[NSNumber numberWithInt:inputType]];
resolve(result);
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
Expand All @@ -69,7 +69,7 @@ - (void)loadMethod:(NSString *)methodName
}

- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeETModuleSpecJSI>(params);
}

Expand Down
122 changes: 61 additions & 61 deletions ios/RnExecutorch/LLM.mm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#import "LLM.h"
#import <ExecutorchLib/LLaMARunner.h>
#import "utils/ConversationManager.h"
#import "utils/Constants.h"
#import "utils/llms/ConversationManager.h"
#import "utils/llms/Constants.h"
#import "utils/Fetcher.h"
#import "utils/LargeFileFetcher.h"
#import <UIKit/UIKit.h>
Expand Down Expand Up @@ -47,77 +47,77 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt {

- (void)updateDownloadProgress:(NSNumber *)progress {
dispatch_async(dispatch_get_main_queue(), ^{
[self emitOnDownloadProgress:progress];
[self emitOnDownloadProgress:progress];
});
}

- (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
NSURL *modelURL = [NSURL URLWithString:modelSource];
NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];

if(self->runner || isFetching){
reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);
NSURL *modelURL = [NSURL URLWithString:modelSource];
NSURL *tokenizerURL = [NSURL URLWithString:tokenizerSource];

if(self->runner || isFetching){
reject(@"model_already_loaded", @"Model and tokenizer already loaded", nil);
return;
}

isFetching = YES;
[Fetcher fetchResource:tokenizerURL resourceType:ResourceType::TOKENIZER completionHandler:^(NSString *tokenizerFilePath, NSError *error) {
if(error){
reject(@"download_error", error.localizedDescription, nil);
return;
}
LargeFileFetcher *modelFetcher = [[LargeFileFetcher alloc] init];
modelFetcher.onProgress = ^(NSNumber *progress) {
[self updateDownloadProgress:progress];
};

isFetching = YES;
[Fetcher fetchResource:tokenizerURL resourceType:ResourceType::TOKENIZER completionHandler:^(NSString *tokenizerFilePath, NSError *error) {
if(error){
reject(@"download_error", error.localizedDescription, nil);
return;
}
LargeFileFetcher *modelFetcher = [[LargeFileFetcher alloc] init];
modelFetcher.onProgress = ^(NSNumber *progress) {
[self updateDownloadProgress:progress];
};

modelFetcher.onFailure = ^(NSError *error){
reject(@"download_error", error.localizedDescription, nil);
return;
};

modelFetcher.onFinish = ^(NSString *modelFilePath) {
self->runner = [[LLaMARunner alloc] initWithModelPath:modelFilePath tokenizerPath:tokenizerFilePath];
NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);

self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow: contextWindowLengthUInt systemPrompt: systemPrompt];
self->isFetching = NO;
resolve(@"Model and tokenizer loaded successfully");
return;
};
modelFetcher.onFailure = ^(NSError *error){
reject(@"download_error", error.localizedDescription, nil);
return;
};

modelFetcher.onFinish = ^(NSString *modelFilePath) {
self->runner = [[LLaMARunner alloc] initWithModelPath:modelFilePath tokenizerPath:tokenizerFilePath];
NSUInteger contextWindowLengthUInt = (NSUInteger)round(contextWindowLength);

[modelFetcher startDownloadingFileFromURL:modelURL];
}];
self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow: contextWindowLengthUInt systemPrompt: systemPrompt];
self->isFetching = NO;
resolve(@"Model and tokenizer loaded successfully");
return;
};

[modelFetcher startDownloadingFileFromURL:modelURL];
}];
}


- (void) runInference:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject {
[conversationManager addResponse:input senderRole:ChatRole::USER];
NSString *prompt = [conversationManager getConversation];

dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
NSError *error = nil;
[self->runner generate:prompt withTokenCallback:^(NSString *token) {
[self onResult:token prompt:prompt];
} error:&error];

// make sure to add eot token once generation is done
if (![self->tempLlamaResponse hasSuffix:END_OF_TEXT_TOKEN_NS]) {
[self onResult:END_OF_TEXT_TOKEN_NS prompt:prompt];
}

if (self->tempLlamaResponse) {
[self->conversationManager addResponse:self->tempLlamaResponse senderRole:ChatRole::ASSISTANT];
self->tempLlamaResponse = [NSMutableString string];
}

if (error) {
reject(@"error_in_generation", error.localizedDescription, nil);
return;
}
resolve(@"Inference completed successfully");
return;
});
[conversationManager addResponse:input senderRole:ChatRole::USER];
NSString *prompt = [conversationManager getConversation];
dispatch_async(dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT, 0), ^{
NSError *error = nil;
[self->runner generate:prompt withTokenCallback:^(NSString *token) {
[self onResult:token prompt:prompt];
} error:&error];
// make sure to add eot token once generation is done
if (![self->tempLlamaResponse hasSuffix:END_OF_TEXT_TOKEN_NS]) {
[self onResult:END_OF_TEXT_TOKEN_NS prompt:prompt];
}
if (self->tempLlamaResponse) {
[self->conversationManager addResponse:self->tempLlamaResponse senderRole:ChatRole::ASSISTANT];
self->tempLlamaResponse = [NSMutableString string];
}
if (error) {
reject(@"error_in_generation", error.localizedDescription, nil);
return;
}
resolve(@"Inference completed successfully");
return;
});
}


Expand Down
52 changes: 18 additions & 34 deletions ios/RnExecutorch/StyleTransfer.mm
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#import "StyleTransfer.h"
#import "utils/Fetcher.h"
#import "models/BaseModel.h"
#import "utils/ETError.h"
#import "ImageProcessor.h"
#import <ExecutorchLib/ETModel.h>
#import <React/RCTBridgeModule.h>
#import "models/StyleTransferModel.h"
#include <string>
#import <opencv2/opencv.hpp>

@implementation StyleTransfer {
StyleTransferModel* model;
Expand All @@ -22,52 +24,34 @@ - (void)loadModule:(NSString *)modelSource
return;
}

NSError *error = [NSError
errorWithDomain:@"StyleTransferErrorDomain"
code:[errorCode intValue]
userInfo:@{
NSLocalizedDescriptionKey : [NSString
stringWithFormat:@"%ld", (long)[errorCode longValue]]
}];

reject(@"init_module_error", error.localizedDescription, error);
reject(@"init_module_error", [NSString
stringWithFormat:@"%ld", (long)[errorCode longValue]], nil);
return;
}];
}

- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeStyleTransferSpecJSI>(params);
}

- (void)forward:(NSString *)input
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
@try {
NSURL *url = [NSURL URLWithString:input];
NSData *data = [NSData dataWithContentsOfURL:url];
if (!data) {
reject(@"img_loading_error", @"Unable to load image data", nil);
return;
}
UIImage *inputImage = [UIImage imageWithData:data];

UIImage* result = [model runModel:inputImage];

// save img to tmp dir, return URI
NSString *outputPath = [NSTemporaryDirectory() stringByAppendingPathComponent:[@"test" stringByAppendingString:@".png"]];
if ([UIImagePNGRepresentation(result) writeToFile:outputPath atomically:YES]) {
NSURL *fileURL = [NSURL fileURLWithPath:outputPath];
resolve([fileURL absoluteString]);
} else {
reject(@"img_write_error", @"Failed to write processed image to file", nil);
}
cv::Mat image = [ImageProcessor readImage:input];
cv::Mat resultImage = [model runModel:image];

NSString* tempFilePath = [ImageProcessor saveToTempFile:resultImage];
resolve(tempFilePath);
return;
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
reject(@"result_error", [NSString stringWithFormat:@"%@", exception.reason],
reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason],
nil);
return;
}
}


- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeStyleTransferSpecJSI>(params);
}

@end
6 changes: 3 additions & 3 deletions ios/RnExecutorch/models/BaseModel.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#import <Foundation/Foundation.h>
#import <UIKit/UIKit.h>
#import "ETModel.h"
#import "ExecutorchLib/ETModel.h"

@interface BaseModel : NSObject
{
@protected
ETModel *module;
ETModel *module;
}

- (NSArray *)forward:(NSArray *)input shape:(NSArray *)shape inputType:(NSNumber *)inputType error:(NSError **)error;
- (NSArray *)forward:(NSArray *)input;
- (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber *code))completion;

@end
Loading
Loading