Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Make third-party/ExecuTorchLib's forward() accept multiple inputs #83

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.Fetcher
import com.swmansion.rnexecutorch.utils.TensorUtils
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Module

class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) {
private lateinit var module: Module

private var reactApplicationContext = reactContext;
override fun getName(): String {
return NAME
}
Expand Down Expand Up @@ -44,26 +45,40 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react
}

override fun forward(
input: ReadableArray,
shape: ReadableArray,
inputType: Double,
inputs: ReadableArray,
shapes: ReadableArray,
inputTypes: ReadableArray,
promise: Promise
) {
val inputEValues = ArrayList<EValue>()
try {
val executorchInput =
TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt())
for (i in 0 until inputs.size()) {
val currentInput = inputs.getArray(i)
?: throw Exception(ETError.InvalidArgument.code.toString())
val currentShape = shapes.getArray(i)
?: throw Exception(ETError.InvalidArgument.code.toString())
val currentInputType = inputTypes.getInt(i)

val result = module.forward(executorchInput)
val resultArray = Arguments.createArray()
val currentEValue = TensorUtils.getExecutorchInput(
currentInput,
ArrayUtils.createLongArray(currentShape),
currentInputType
)

for (evalue in result) {
resultArray.pushArray(ArrayUtils.createReadableArray(evalue.toTensor()))
inputEValues.add(currentEValue)
}

promise.resolve(resultArray)
return
val forwardOutputs = module.forward(*inputEValues.toTypedArray());
val outputArray = Arguments.createArray()

for (output in forwardOutputs) {
val arr = ArrayUtils.createReadableArrayFromTensor(output.toTensor())
outputArray.pushArray(arr)
}
promise.resolve(outputArray)

} catch (e: IllegalArgumentException) {
//The error is thrown when transformation to Tensor fails
// The error is thrown when transformation to Tensor fails
promise.reject("Forward Failed Execution", ETError.InvalidArgument.code.toString())
return
} catch (e: Exception) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,82 +7,52 @@ import org.pytorch.executorch.Tensor

class ArrayUtils {
companion object {
fun createByteArray(input: ReadableArray): ByteArray {
val byteArray = ByteArray(input.size())
for (i in 0 until input.size()) {
byteArray[i] = input.getInt(i).toByte()
}
return byteArray
private inline fun <reified T> createTypedArrayFromReadableArray(input: ReadableArray, transform: (ReadableArray, Int) -> T): Array<T> {
return Array(input.size()) { index -> transform(input, index) }
}

fun createByteArray(input: ReadableArray): ByteArray {
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray()
}
fun createIntArray(input: ReadableArray): IntArray {
val intArray = IntArray(input.size())
for (i in 0 until input.size()) {
intArray[i] = input.getInt(i)
}
return intArray
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index) }.toIntArray()
}

fun createFloatArray(input: ReadableArray): FloatArray {
val floatArray = FloatArray(input.size())
for (i in 0 until input.size()) {
floatArray[i] = input.getDouble(i).toFloat()
}
return floatArray
return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index).toFloat() }.toFloatArray()
}

fun createLongArray(input: ReadableArray): LongArray {
val longArray = LongArray(input.size())
for (i in 0 until input.size()) {
longArray[i] = input.getInt(i).toLong()
}
return longArray
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toLong() }.toLongArray()
}

fun createDoubleArray(input: ReadableArray): DoubleArray {
val doubleArray = DoubleArray(input.size())
for (i in 0 until input.size()) {
doubleArray[i] = input.getDouble(i)
}
return doubleArray
return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index) }.toDoubleArray()
}

fun createReadableArray(result: Tensor): ReadableArray {
fun createReadableArrayFromTensor(result: Tensor): ReadableArray {
val resultArray = Arguments.createArray()

when (result.dtype()) {
DType.UINT8 -> {
val byteArray = result.dataAsByteArray
for (i in byteArray) {
resultArray.pushInt(i.toInt())
}
result.dataAsByteArray.forEach { resultArray.pushInt(it.toInt()) }
}

DType.INT32 -> {
val intArray = result.dataAsIntArray
for (i in intArray) {
resultArray.pushInt(i)
}
result.dataAsIntArray.forEach { resultArray.pushInt(it) }
}

DType.FLOAT -> {
val longArray = result.dataAsFloatArray
for (i in longArray) {
resultArray.pushDouble(i.toDouble())
}
result.dataAsFloatArray.forEach { resultArray.pushDouble(it.toDouble()) }
}

DType.DOUBLE -> {
val floatArray = result.dataAsDoubleArray
for (i in floatArray) {
resultArray.pushDouble(i)
}
result.dataAsDoubleArray.forEach { resultArray.pushDouble(it) }
}

DType.INT64 -> {
val doubleArray = result.dataAsLongArray
for (i in doubleArray) {
resultArray.pushLong(i)
}
// TODO: Do something to handle or deprecate long dtype
// https://github.com/facebook/react-native/issues/12506
result.dataAsLongArray.forEach { resultArray.pushInt(it.toInt()) }
}

else -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,23 @@ class TensorUtils {
fun getExecutorchInput(input: ReadableArray, shape: LongArray, type: Int): EValue {
try {
when (type) {
0 -> {
1 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape)
return EValue.from(inputTensor)
}

1 -> {
3 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape)
return EValue.from(inputTensor)
}

2 -> {
4 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape)
return EValue.from(inputTensor)
}

3 -> {
6 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape)
return EValue.from(inputTensor)
}

4 -> {
7 -> {
val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape)
return EValue.from(inputTensor)
}
Expand Down
64 changes: 35 additions & 29 deletions ios/RnExecutorch/ETModule.mm
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#import "ETModule.h"
#import "utils/Fetcher.h"
#import <ExecutorchLib/ETModel.h>
#include <Foundation/Foundation.h>
#import <React/RCTBridgeModule.h>
#include <string>

Expand All @@ -16,44 +17,49 @@ - (void)loadModule:(NSString *)modelSource
if (!module) {
module = [[ETModel alloc] init];
}

[Fetcher fetchResource:[NSURL URLWithString:modelSource]
resourceType:ResourceType::MODEL
completionHandler:^(NSString *filePath, NSError *error) {
if (error) {
reject(@"init_module_error", @"-1", nil);
return;
}
NSNumber *result = [self->module loadModel:filePath];
if ([result isEqualToNumber:@(0)]) {
resolve(result);
} else {
NSError *error = [NSError
errorWithDomain:@"ETModuleErrorDomain"
code:[result intValue]
userInfo:@{
NSLocalizedDescriptionKey : [NSString
stringWithFormat:@"%ld", (long)[result longValue]]
}];
reject(@"init_module_error", error.localizedDescription, error);
}
}];
if (error) {
reject(@"init_module_error", @"-1", nil);
return;
}

NSNumber *result = [self->module loadModel:filePath];
if ([result isEqualToNumber:@(0)]) {
resolve(result);
} else {
NSError *error = [NSError
errorWithDomain:@"ETModuleErrorDomain"
code:[result intValue]
userInfo:@{
NSLocalizedDescriptionKey : [NSString
stringWithFormat:@"%ld", (long)[result longValue]]
}];

reject(@"init_module_error", error.localizedDescription, error);
}
}];
}

- (void)forward:(NSArray *)input
shape:(NSArray *)shape
inputType:(double)inputType
- (void)forward:(NSArray *)inputs
shapes:(NSArray *)shapes
inputTypes:(NSArray *)inputTypes
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
@try {
NSArray *result = [module forward:input shape:shape inputType:[NSNumber numberWithInt:inputType]];
NSArray *result = [module forward:inputs
shapes:shapes
inputTypes:inputTypes];
resolve(result);
} @catch (NSException *exception) {
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
reject(@"result_error", [NSString stringWithFormat:@"%@", exception.reason],
nil);
NSLog(@"An exception occurred in forward: %@, %@", exception.name,
exception.reason);
reject(
@"forward_error",
[NSString stringWithFormat:@"An error occurred: %@", exception.reason],
nil);
}
}

Expand All @@ -69,7 +75,7 @@ - (void)loadMethod:(NSString *)methodName
}

- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
(const facebook::react::ObjCTurboModule::InitParams &)params {
(const facebook::react::ObjCTurboModule::InitParams &)params {
return std::make_shared<facebook::react::NativeETModuleSpecJSI>(params);
}

Expand Down
13 changes: 9 additions & 4 deletions ios/RnExecutorch/models/BaseModel.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
#import "ExecutorchLib/ETModel.h"
#import <Foundation/Foundation.h>
#import <UIKit/UIKit.h>
#import "ExecutorchLib/ETModel.h"

@interface BaseModel : NSObject
{
@interface BaseModel : NSObject {
@protected
ETModel *module;
}

- (NSArray *)forward:(NSArray *)input;
- (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber *code))completion;

- (NSArray *)forward:(NSArray *)inputs
shapes:(NSArray *)shapes
inputTypes:(NSArray *)inputTypes;

- (void)loadModel:(NSURL *)modelURL
completion:(void (^)(BOOL success, NSNumber *code))completion;

@end
51 changes: 35 additions & 16 deletions ios/RnExecutorch/models/BaseModel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,45 @@
@implementation BaseModel

- (NSArray *)forward:(NSArray *)input {
NSArray *result = [module forward:input shape:[module getInputShape:@0] inputType:[module getInputType: @0]];
NSMutableArray *shapes = [NSMutableArray new];
NSMutableArray *inputTypes = [NSMutableArray new];
NSNumber *numberOfInputs = [module getNumberOfInputs];

for (NSUInteger i = 0; i < [numberOfInputs intValue]; i++) {
[shapes addObject:[module getInputShape:[NSNumber numberWithInt:i]]];
[inputTypes addObject:[module getInputType:[NSNumber numberWithInt:i]]];
}

NSArray *result = [module forward:input shapes:shapes inputTypes:inputTypes];
return result;
}

- (NSArray *)forward:(NSArray *)inputs
shapes:(NSArray *)shapes
inputTypes:(NSArray *)inputTypes {
NSArray *result = [module forward:inputs shapes:shapes inputTypes:inputTypes];
return result;
}

- (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber* code))completion {
- (void)loadModel:(NSURL *)modelURL
completion:(void (^)(BOOL success, NSNumber *code))completion {
module = [[ETModel alloc] init];
[Fetcher fetchResource:modelURL resourceType:ResourceType::MODEL completionHandler:^(NSString *filePath, NSError *error) {
if (error) {
completion(NO, @(InvalidModelSource));
return;
}
NSNumber *result = [self->module loadModel: filePath];
if([result intValue] != 0){
completion(NO, result);
return;
}

completion(YES, result);
return;
}];
[Fetcher fetchResource:modelURL
resourceType:ResourceType::MODEL
completionHandler:^(NSString *filePath, NSError *error) {
if (error) {
completion(NO, @(InvalidModelSource));
return;
}
NSNumber *result = [self->module loadModel:filePath];
if ([result intValue] != 0) {
completion(NO, result);
return;
}

completion(YES, result);
return;
}];
}

@end
6 changes: 3 additions & 3 deletions src/native/NativeETModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ export interface Spec extends TurboModule {
loadModule(modelSource: string): Promise<number>;

forward(
input: number[],
shape: number[],
inputType: number
inputs: number[],
shapes: number[],
inputTypes: number[]
): Promise<number[]>;
loadMethod(methodName: string): Promise<number>;
}
Expand Down
Loading