diff --git a/docs/docs/computer-vision/useClassification.mdx b/docs/docs/computer-vision/useClassification.mdx
index cb39d07..1043088 100644
--- a/docs/docs/computer-vision/useClassification.mdx
+++ b/docs/docs/computer-vision/useClassification.mdx
@@ -31,20 +31,6 @@ try {
}
```
-
-Type definitions
-
-```typescript
-interface ClassificationModule {
- error: string | null;
- isReady: boolean;
- isGenerating: boolean;
- forward: (input: string) => Promise<{ [category: string]: number }>;
-}
-```
-
-
-
### Arguments
**`modelSource`**
diff --git a/docs/docs/computer-vision/useObjectDetection.mdx b/docs/docs/computer-vision/useObjectDetection.mdx
index 0e40076..5de3da4 100644
--- a/docs/docs/computer-vision/useObjectDetection.mdx
+++ b/docs/docs/computer-vision/useObjectDetection.mdx
@@ -11,6 +11,7 @@ It is recommended to use models provided by us, which are available at our [Hugg
:::
## Reference
+
```jsx
import { useObjectDetection, SSDLITE_320_MOBILENET_V3_LARGE } from 'react-native-executorch';
@@ -45,42 +46,36 @@ interface Detection {
label: keyof typeof CocoLabel;
score: number;
}
-
-interface ObjectDetectionModule {
- error: string | null;
- isReady: boolean;
- isGenerating: boolean;
- forward: (input: string) => Promise;
-}
```
+
### Arguments
`modelSource`
-A string that specifies the path to the model file. You can download the model from our [HuggingFace repository](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/tree/main).
+A string that specifies the path to the model file. You can download the model from our [HuggingFace repository](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/tree/main).
For more information on that topic, you can check out the [Loading models](https://docs.swmansion.com/react-native-executorch/fundamentals/loading-models) page.
### Returns
The hook returns an object with the following properties:
-
-| **Field** | **Type** | **Description** |
-|-----------------------|---------------------------------------|------------------------------------------------------------------------------------------------------------------|
-| `forward` | `(input: string) => Promise` | A function that accepts an image (url, b64) and returns an array of `Detection` objects. |
-| `error` | string | null
| Contains the error message if the model loading failed. |
-| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
-| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
-
+| Field | Type | Description |
+| -------------- | ----------------------------------------- | ---------------------------------------------------------------------------------------- |
+| `forward` | `(input: string) => Promise` | A function that accepts an image (url, b64) and returns an array of `Detection` objects. |
+| `error` | string | null
| Contains the error message if the model loading failed. |
+| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
+| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |
## Running the model
To run the model, you can use the `forward` method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image. The function returns an array of `Detection` objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. For more information, please refer to the reference or type definitions.
## Detection object
+
The detection object is specified as follows:
+
```typescript
interface Bbox {
x1: number;
@@ -95,14 +90,17 @@ interface Detection {
score: number;
}
```
+
The `bbox` property contains information about the bounding box of detected objects. It is represented as two points: one at the bottom-left corner of the bounding box (`x1`, `y1`) and the other at the top-right corner (`x2`, `y2`).
The `label` property contains the name of the detected object, which corresponds to one of the `CocoLabels`. The `score` represents the confidence score of the detected object.
-
-
## Example
+
```tsx
-import { useObjectDetection, SSDLITE_320_MOBILENET_V3_LARGE } from 'react-native-executorch';
+import {
+ useObjectDetection,
+ SSDLITE_320_MOBILENET_V3_LARGE,
+} from 'react-native-executorch';
function App() {
const ssdlite = useObjectDetection({
@@ -110,18 +108,19 @@ function App() {
});
const runModel = async () => {
- const detections = await ssdlite.forward("https://url-to-image.jpg");
+ const detections = await ssdlite.forward('https://url-to-image.jpg');
+
for (const detection of detections) {
- console.log("Bounding box: ", detection.bbox);
- console.log("Bounding label: ", detection.label);
- console.log("Bounding score: ", detection.score);
+ console.log('Bounding box: ', detection.bbox);
+ console.log('Bounding label: ', detection.label);
+ console.log('Bounding score: ', detection.score);
}
- }
+ };
}
```
## Supported models
-| Model | Number of classes | Class list |
-| --------------------------------------------------------------------------------------------------------------- | ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| [SSDLite320 MobileNetV3 Large](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.ssdlite320_mobilenet_v3_large.html#torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights) | 91 | [COCO](https://github.com/software-mansion/react-native-executorch/blob/69802ee1ca161d9df00def1dabe014d36341cfa9/src/types/object_detection.ts#L14) |
\ No newline at end of file
+| Model | Number of classes | Class list |
+| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [SSDLite320 MobileNetV3 Large](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.ssdlite320_mobilenet_v3_large.html#torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights) | 91 | [COCO](https://github.com/software-mansion/react-native-executorch/blob/69802ee1ca161d9df00def1dabe014d36341cfa9/src/types/object_detection.ts#L14) |
diff --git a/docs/docs/computer-vision/useStyleTransfer.mdx b/docs/docs/computer-vision/useStyleTransfer.mdx
index b6f4f4b..c5a5e3e 100644
--- a/docs/docs/computer-vision/useStyleTransfer.mdx
+++ b/docs/docs/computer-vision/useStyleTransfer.mdx
@@ -30,20 +30,6 @@ try {
}
```
-
-Type definitions
-
-```typescript
-interface StyleTransferModule {
- error: string | null;
- isReady: boolean;
- isGenerating: boolean;
- forward: (input: string) => Promise;
-}
-```
-
-
-
### Arguments
**`modelSource`**
diff --git a/examples/computer-vision/components/ImageWithBboxes.tsx b/examples/computer-vision/components/ImageWithBboxes.tsx
index 7d08e33..65d345d 100644
--- a/examples/computer-vision/components/ImageWithBboxes.tsx
+++ b/examples/computer-vision/components/ImageWithBboxes.tsx
@@ -68,13 +68,7 @@ export default function ImageWithBboxes({
const height = (y2 - y1) * scaleY;
return (
-
+
{detection.label} ({(detection.score * 100).toFixed(1)}%)
@@ -98,6 +92,7 @@ const styles = StyleSheet.create({
bbox: {
position: 'absolute',
borderWidth: 2,
+ borderColor: 'red',
},
label: {
position: 'absolute',
diff --git a/examples/computer-vision/screens/ObjectDetectionScreen.tsx b/examples/computer-vision/screens/ObjectDetectionScreen.tsx
index 280e3c5..ea82dfd 100644
--- a/examples/computer-vision/screens/ObjectDetectionScreen.tsx
+++ b/examples/computer-vision/screens/ObjectDetectionScreen.tsx
@@ -76,7 +76,7 @@ export const ObjectDetectionScreen = ({
/>
) : (
@@ -127,4 +127,8 @@ const styles = StyleSheet.create({
flex: 1,
marginRight: 4,
},
+ fullSizeImage: {
+ width: '100%',
+ height: '100%',
+ },
});
diff --git a/examples/llama/screens/ChatScreen.tsx b/examples/llama/screens/ChatScreen.tsx
index 4d2f707..40d9e70 100644
--- a/examples/llama/screens/ChatScreen.tsx
+++ b/examples/llama/screens/ChatScreen.tsx
@@ -63,7 +63,7 @@ export default function ChatScreen() {
@@ -133,6 +133,9 @@ const styles = StyleSheet.create({
container: {
flex: 1,
},
+ keyboardAvoidingView: {
+ flex: 1,
+ },
topContainer: {
height: 68,
width: '100%',
diff --git a/src/models/Classification.ts b/src/hooks/computer_vision/useClassification.ts
similarity index 74%
rename from src/models/Classification.ts
rename to src/hooks/computer_vision/useClassification.ts
index 6bd8dfb..836a0fc 100644
--- a/src/models/Classification.ts
+++ b/src/hooks/computer_vision/useClassification.ts
@@ -1,21 +1,19 @@
import { useState } from 'react';
-import { _ClassificationModule } from '../native/RnExecutorchModules';
-import { useModule } from '../useModule';
+import { _ClassificationModule } from '../../native/RnExecutorchModules';
+import { useModule } from '../../useModule';
interface Props {
modelSource: string | number;
}
-interface ClassificationModule {
+export const useClassification = ({
+ modelSource,
+}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise<{ [category: string]: number }>;
-}
-
-export const useClassification = ({
- modelSource,
-}: Props): ClassificationModule => {
+} => {
const [module, _] = useState(() => new _ClassificationModule());
const {
error,
diff --git a/src/models/ObjectDetection.ts b/src/hooks/computer_vision/useObjectDetection.ts
similarity index 67%
rename from src/models/ObjectDetection.ts
rename to src/hooks/computer_vision/useObjectDetection.ts
index fda2fd0..8456ee3 100644
--- a/src/models/ObjectDetection.ts
+++ b/src/hooks/computer_vision/useObjectDetection.ts
@@ -1,22 +1,20 @@
import { useState } from 'react';
-import { _ObjectDetectionModule } from '../native/RnExecutorchModules';
-import { useModule } from '../useModule';
-import { Detection } from '../types/object_detection';
+import { _ObjectDetectionModule } from '../../native/RnExecutorchModules';
+import { useModule } from '../../useModule';
+import { Detection } from '../../types/object_detection';
interface Props {
modelSource: string | number;
}
-interface ObjectDetectionModule {
+export const useObjectDetection = ({
+ modelSource,
+}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise;
-}
-
-export const useObjectDetection = ({
- modelSource,
-}: Props): ObjectDetectionModule => {
+} => {
const [module, _] = useState(() => new _ObjectDetectionModule());
const {
error,
diff --git a/src/models/StyleTransfer.ts b/src/hooks/computer_vision/useStyleTransfer.ts
similarity index 73%
rename from src/models/StyleTransfer.ts
rename to src/hooks/computer_vision/useStyleTransfer.ts
index 215f5ae..20c400b 100644
--- a/src/models/StyleTransfer.ts
+++ b/src/hooks/computer_vision/useStyleTransfer.ts
@@ -1,21 +1,19 @@
import { useState } from 'react';
-import { _StyleTransferModule } from '../native/RnExecutorchModules';
-import { useModule } from '../useModule';
+import { _StyleTransferModule } from '../../native/RnExecutorchModules';
+import { useModule } from '../../useModule';
interface Props {
modelSource: string | number;
}
-interface StyleTransferModule {
+export const useStyleTransfer = ({
+ modelSource,
+}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise;
-}
-
-export const useStyleTransfer = ({
- modelSource,
-}: Props): StyleTransferModule => {
+} => {
const [module, _] = useState(() => new _StyleTransferModule());
const {
error,
diff --git a/src/ETModule.ts b/src/hooks/general/useExecutorchModule.ts
similarity index 60%
rename from src/ETModule.ts
rename to src/hooks/general/useExecutorchModule.ts
index 416c1f4..5a180fd 100644
--- a/src/ETModule.ts
+++ b/src/hooks/general/useExecutorchModule.ts
@@ -1,8 +1,8 @@
import { useState } from 'react';
-import { _ETModule } from './native/RnExecutorchModules';
-import { getError } from './Error';
-import { ExecutorchModule } from './types/common';
-import { useModule } from './useModule';
+import { _ETModule } from '../../native/RnExecutorchModules';
+import { useModule } from '../../useModule';
+import { ETInput } from '../../types/common';
+import { getError } from '../../Error';
interface Props {
modelSource: string | number;
@@ -10,7 +10,14 @@ interface Props {
export const useExecutorchModule = ({
modelSource,
-}: Props): ExecutorchModule => {
+}: Props): {
+ error: string | null;
+ isReady: boolean;
+ isGenerating: boolean;
+ forward: (input: ETInput, shape: number[]) => Promise;
+ loadMethod: (methodName: string) => Promise;
+ loadForward: () => Promise;
+} => {
const [module] = useState(() => new _ETModule());
const {
error,
diff --git a/src/LLM.ts b/src/hooks/natural_language_processing/useLLM.ts
similarity index 95%
rename from src/LLM.ts
rename to src/hooks/natural_language_processing/useLLM.ts
index 4219fdc..386750c 100644
--- a/src/LLM.ts
+++ b/src/hooks/natural_language_processing/useLLM.ts
@@ -1,12 +1,12 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import { EventSubscription, Image } from 'react-native';
-import { ResourceSource, Model } from './types/common';
+import { ResourceSource, Model } from '../../types/common';
import {
DEFAULT_CONTEXT_WINDOW_LENGTH,
DEFAULT_SYSTEM_PROMPT,
EOT_TOKEN,
-} from './constants/llamaDefaults';
-import { LLM } from './native/RnExecutorchModules';
+} from '../../constants/llamaDefaults';
+import { LLM } from '../../native/RnExecutorchModules';
const interrupt = () => {
LLM.interrupt();
diff --git a/src/index.tsx b/src/index.tsx
index 74cfd13..1c21b25 100644
--- a/src/index.tsx
+++ b/src/index.tsx
@@ -1,7 +1,23 @@
-export * from './ETModule';
-export * from './LLM';
-export * from './constants/modelUrls';
-export * from './models/Classification';
-export * from './models/ObjectDetection';
-export * from './models/StyleTransfer';
+// hooks
+export * from './hooks/computer_vision/useClassification';
+export * from './hooks/computer_vision/useObjectDetection';
+export * from './hooks/computer_vision/useStyleTransfer';
+
+export * from './hooks/natural_language_processing/useLLM';
+
+export * from './hooks/general/useExecutorchModule';
+
+// modules
+export * from './modules/computer_vision/ClassificationModule';
+export * from './modules/computer_vision/ObjectDetectionModule';
+export * from './modules/computer_vision/StyleTransferModule';
+
+export * from './modules/natural_language_processing/LLMModule';
+
+export * from './modules/general/ExecutorchModule';
+
+// types
export * from './types/object_detection';
+
+// constants
+export * from './constants/modelUrls';
diff --git a/src/modules/BaseModule.ts b/src/modules/BaseModule.ts
new file mode 100644
index 0000000..a93591c
--- /dev/null
+++ b/src/modules/BaseModule.ts
@@ -0,0 +1,36 @@
+import { Image } from 'react-native';
+import {
+ _StyleTransferModule,
+ _ObjectDetectionModule,
+ _ClassificationModule,
+ _ETModule,
+} from '../native/RnExecutorchModules';
+import { ResourceSource } from '../types/common';
+import { getError } from '../Error';
+
+export class BaseModule {
+ static module:
+ | _StyleTransferModule
+ | _ObjectDetectionModule
+ | _ClassificationModule
+ | _ETModule;
+
+ static async load(modelSource: ResourceSource) {
+ if (!modelSource) return;
+
+ let path =
+ typeof modelSource === 'number'
+ ? Image.resolveAssetSource(modelSource).uri
+ : modelSource;
+
+ try {
+ await this.module.loadModule(path);
+ } catch (e) {
+ throw new Error(getError(e));
+ }
+ }
+
+ static async forward(..._: any[]): Promise {
+ throw new Error('The forward method is not implemented.');
+ }
+}
diff --git a/src/modules/computer_vision/BaseCVModule.ts b/src/modules/computer_vision/BaseCVModule.ts
new file mode 100644
index 0000000..c61987d
--- /dev/null
+++ b/src/modules/computer_vision/BaseCVModule.ts
@@ -0,0 +1,22 @@
+import { BaseModule } from '../BaseModule';
+import {
+ _StyleTransferModule,
+ _ObjectDetectionModule,
+ _ClassificationModule,
+} from '../../native/RnExecutorchModules';
+import { getError } from '../../Error';
+
+export class BaseCVModule extends BaseModule {
+ static module:
+ | _StyleTransferModule
+ | _ObjectDetectionModule
+ | _ClassificationModule;
+
+ static async forward(input: string) {
+ try {
+ return await this.module.forward(input);
+ } catch (e) {
+ throw new Error(getError(e));
+ }
+ }
+}
diff --git a/src/modules/computer_vision/ClassificationModule.ts b/src/modules/computer_vision/ClassificationModule.ts
new file mode 100644
index 0000000..2c6392c
--- /dev/null
+++ b/src/modules/computer_vision/ClassificationModule.ts
@@ -0,0 +1,12 @@
+import { BaseCVModule } from './BaseCVModule';
+import { _ClassificationModule } from '../../native/RnExecutorchModules';
+
+export class ClassificationModule extends BaseCVModule {
+ static module = new _ClassificationModule();
+
+ static async forward(input: string) {
+ return await (super.forward(input) as ReturnType<
+ _ClassificationModule['forward']
+ >);
+ }
+}
diff --git a/src/modules/computer_vision/ObjectDetectionModule.ts b/src/modules/computer_vision/ObjectDetectionModule.ts
new file mode 100644
index 0000000..c50ce02
--- /dev/null
+++ b/src/modules/computer_vision/ObjectDetectionModule.ts
@@ -0,0 +1,12 @@
+import { BaseCVModule } from './BaseCVModule';
+import { _ObjectDetectionModule } from '../../native/RnExecutorchModules';
+
+export class ObjectDetectionModule extends BaseCVModule {
+ static module = new _ObjectDetectionModule();
+
+ static async forward(input: string) {
+ return await (super.forward(input) as ReturnType<
+ _ObjectDetectionModule['forward']
+ >);
+ }
+}
diff --git a/src/modules/computer_vision/StyleTransferModule.ts b/src/modules/computer_vision/StyleTransferModule.ts
new file mode 100644
index 0000000..830a8c5
--- /dev/null
+++ b/src/modules/computer_vision/StyleTransferModule.ts
@@ -0,0 +1,12 @@
+import { BaseCVModule } from './BaseCVModule';
+import { _StyleTransferModule } from '../../native/RnExecutorchModules';
+
+export class StyleTransferModule extends BaseCVModule {
+ static module = new _StyleTransferModule();
+
+ static async forward(input: string) {
+ return await (super.forward(input) as ReturnType<
+ _StyleTransferModule['forward']
+ >);
+ }
+}
diff --git a/src/modules/general/ExecutorchModule.ts b/src/modules/general/ExecutorchModule.ts
new file mode 100644
index 0000000..5d5990c
--- /dev/null
+++ b/src/modules/general/ExecutorchModule.ts
@@ -0,0 +1,34 @@
+import { BaseModule } from '../BaseModule';
+import { ETError, getError } from '../../Error';
+import { _ETModule } from '../../native/RnExecutorchModules';
+import { ETInput, getTypeIdentifier } from '../../types/common';
+
+export class ExecutorchModule extends BaseModule {
+ static module = new _ETModule();
+
+ static async forward(input: ETInput, shape: number[]) {
+ const inputType = getTypeIdentifier(input);
+ if (inputType === -1) {
+ throw new Error(getError(ETError.InvalidArgument));
+ }
+
+ try {
+ const numberArray = [...input] as number[];
+ return await this.module.forward(numberArray, shape, inputType);
+ } catch (e) {
+ throw new Error(getError(e));
+ }
+ }
+
+ static async loadMethod(methodName: string) {
+ try {
+ await this.module.loadMethod(methodName);
+ } catch (e) {
+ throw new Error(getError(e));
+ }
+ }
+
+ static async loadForward() {
+ await this.loadMethod('forward');
+ }
+}
diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts
new file mode 100644
index 0000000..20914ed
--- /dev/null
+++ b/src/modules/natural_language_processing/LLMModule.ts
@@ -0,0 +1,61 @@
+import { LLM } from '../../native/RnExecutorchModules';
+import { Image } from 'react-native';
+import {
+ DEFAULT_CONTEXT_WINDOW_LENGTH,
+ DEFAULT_SYSTEM_PROMPT,
+} from '../../constants/llamaDefaults';
+import { ResourceSource } from '../../types/common';
+
+export class LLMModule {
+ static async load(
+ modelSource: ResourceSource,
+ tokenizerSource: ResourceSource,
+ systemPrompt = DEFAULT_SYSTEM_PROMPT,
+ contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH
+ ) {
+ try {
+ let modelUrl =
+ typeof modelSource === 'number'
+ ? Image.resolveAssetSource(modelSource).uri
+ : modelSource;
+
+ let tokenizerUrl =
+ typeof tokenizerSource === 'number'
+ ? Image.resolveAssetSource(tokenizerSource).uri
+ : tokenizerSource;
+
+ await LLM.loadLLM(
+ modelUrl,
+ tokenizerUrl,
+ systemPrompt,
+ contextWindowLength
+ );
+ } catch (err) {
+ throw new Error((err as Error).message);
+ }
+ }
+
+ static async generate(input: string) {
+ try {
+ await LLM.runInference(input);
+ } catch (err) {
+ throw new Error((err as Error).message);
+ }
+ }
+
+ static onDownloadProgress(callback: (data: number) => void) {
+ return LLM.onDownloadProgress(callback);
+ }
+
+ static onToken(callback: (data: string | undefined) => void) {
+ return LLM.onToken(callback);
+ }
+
+ static interrupt() {
+ LLM.interrupt();
+ }
+
+ static delete() {
+ LLM.deleteModule();
+ }
+}
diff --git a/src/native/NativeETModule.ts b/src/native/NativeETModule.ts
index 6d4bfd0..d04da1a 100644
--- a/src/native/NativeETModule.ts
+++ b/src/native/NativeETModule.ts
@@ -9,6 +9,7 @@ export interface Spec extends TurboModule {
shape: number[],
inputType: number
): Promise;
+
loadMethod(methodName: string): Promise;
}
diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts
index 8a80b59..e898216 100644
--- a/src/native/RnExecutorchModules.ts
+++ b/src/native/RnExecutorchModules.ts
@@ -1,4 +1,8 @@
import { Platform } from 'react-native';
+import { Spec as ClassificationInterface } from './NativeClassification';
+import { Spec as ObjectDetectionInterface } from './NativeObjectDetection';
+import { Spec as StyleTransferInterface } from './NativeStyleTransfer';
+import { Spec as ETModuleInterface } from './NativeETModule';
const LINKING_ERROR =
`The package 'react-native-executorch' doesn't seem to be linked. Make sure: \n\n` +
@@ -72,28 +76,36 @@ const StyleTransfer = StyleTransferSpec
);
class _ObjectDetectionModule {
- async forward(input: string) {
+ async forward(
+ input: string
+ ): ReturnType {
return await ObjectDetection.forward(input);
}
- async loadModule(modelSource: string | number) {
+ async loadModule(
+ modelSource: string | number
+ ): ReturnType {
return await ObjectDetection.loadModule(modelSource);
}
}
class _StyleTransferModule {
- async forward(input: string) {
+ async forward(input: string): ReturnType {
return await StyleTransfer.forward(input);
}
- async loadModule(modelSource: string | number) {
+ async loadModule(
+ modelSource: string | number
+ ): ReturnType {
return await StyleTransfer.loadModule(modelSource);
}
}
class _ClassificationModule {
- async forward(input: string) {
+ async forward(input: string): ReturnType {
return await Classification.forward(input);
}
- async loadModule(modelSource: string | number) {
+ async loadModule(
+ modelSource: string | number
+ ): ReturnType {
return await Classification.loadModule(modelSource);
}
}
@@ -103,13 +115,17 @@ class _ETModule {
input: number[],
shape: number[],
inputType: number
- ): Promise {
+ ): ReturnType {
return await ETModule.forward(input, shape, inputType);
}
- async loadModule(modelSource: string) {
+ async loadModule(
+ modelSource: string
+ ): ReturnType {
return await ETModule.loadModule(modelSource);
}
- async loadMethod(methodName: string): Promise {
+ async loadMethod(
+ methodName: string
+ ): ReturnType {
return await ETModule.loadMethod(methodName);
}
}
diff --git a/src/types/common.ts b/src/types/common.ts
index f12643d..ec0daa2 100644
--- a/src/types/common.ts
+++ b/src/types/common.ts
@@ -26,16 +26,17 @@ export type ETInput =
| Float32Array
| Float64Array;
-export interface ExecutorchModule {
- error: string | null;
- isReady: boolean;
- isGenerating: boolean;
- forward: (input: ETInput, shape: number[]) => Promise;
- loadMethod: (methodName: string) => Promise;
- loadForward: () => Promise;
-}
+export const getTypeIdentifier = (arr: ETInput): number => {
+ if (arr instanceof Int8Array) return 0;
+ if (arr instanceof Int32Array) return 1;
+ if (arr instanceof BigInt64Array) return 2;
+ if (arr instanceof Float32Array) return 3;
+ if (arr instanceof Float64Array) return 4;
+
+ return -1;
+};
-export type module =
+export type Module =
| _ClassificationModule
| _StyleTransferModule
| _ObjectDetectionModule
diff --git a/src/useModule.ts b/src/useModule.ts
index 66c2fd4..e4080a5 100644
--- a/src/useModule.ts
+++ b/src/useModule.ts
@@ -1,21 +1,11 @@
import { useEffect, useState } from 'react';
import { Image } from 'react-native';
import { ETError, getError } from './Error';
-import { ETInput, module } from './types/common';
-
-const getTypeIdentifier = (arr: ETInput): number => {
- if (arr instanceof Int8Array) return 0;
- if (arr instanceof Int32Array) return 1;
- if (arr instanceof BigInt64Array) return 2;
- if (arr instanceof Float32Array) return 3;
- if (arr instanceof Float64Array) return 4;
-
- return -1;
-};
+import { ETInput, Module, getTypeIdentifier } from './types/common';
interface Props {
modelSource: string | number;
- module: module;
+ module: Module;
}
interface _Module {