diff --git a/src/hooks/computer_vision/useClassification.ts b/src/hooks/computer_vision/useClassification.ts index be109381..5a824556 100644 --- a/src/hooks/computer_vision/useClassification.ts +++ b/src/hooks/computer_vision/useClassification.ts @@ -16,7 +16,7 @@ interface ClassificationModule { export const useClassification = ({ modelSource, }: Props): ClassificationModule => { - const [module, _] = useState(() => new _ClassificationModule()); + const [module, _] = useState(() => _ClassificationModule); const { error, isReady, diff --git a/src/hooks/computer_vision/useObjectDetection.ts b/src/hooks/computer_vision/useObjectDetection.ts index 9a5ccd12..e08d757e 100644 --- a/src/hooks/computer_vision/useObjectDetection.ts +++ b/src/hooks/computer_vision/useObjectDetection.ts @@ -17,7 +17,7 @@ interface ObjectDetectionModule { export const useObjectDetection = ({ modelSource, }: Props): ObjectDetectionModule => { - const [module, _] = useState(() => new _ObjectDetectionModule()); + const [module, _] = useState(() => _ObjectDetectionModule); const { error, isReady, diff --git a/src/hooks/computer_vision/useStyleTransfer.ts b/src/hooks/computer_vision/useStyleTransfer.ts index 12c45ff4..5e7624bf 100644 --- a/src/hooks/computer_vision/useStyleTransfer.ts +++ b/src/hooks/computer_vision/useStyleTransfer.ts @@ -16,7 +16,7 @@ interface StyleTransferModule { export const useStyleTransfer = ({ modelSource, }: Props): StyleTransferModule => { - const [module, _] = useState(() => new _StyleTransferModule()); + const [module, _] = useState(() => _StyleTransferModule); const { error, isReady, diff --git a/src/hooks/general/useExecutorchModule.ts b/src/hooks/general/useExecutorchModule.ts index 09dc1157..eb90383a 100644 --- a/src/hooks/general/useExecutorchModule.ts +++ b/src/hooks/general/useExecutorchModule.ts @@ -11,7 +11,7 @@ interface Props { export const useExecutorchModule = ({ modelSource, }: Props): ExecutorchModule => { - const [module] = useState(() => new _ETModule()); + const [module] = useState(() => _ETModule); const { error, isReady, diff --git a/src/modules/computer_vision/BaseModule.ts b/src/modules/computer_vision/BaseModule.ts index 9e4aee06..a09b0a5c 100644 --- a/src/modules/computer_vision/BaseModule.ts +++ b/src/modules/computer_vision/BaseModule.ts @@ -2,13 +2,7 @@ import { Image } from 'react-native'; import { getError } from '../../Error'; export class BaseModule { - protected module: any; - - constructor(module: any) { - this.module = module; - } - - async loadModule(modelSource: string | number) { + static async load(module: any, modelSource: string | number) { if (!modelSource) return; let path = modelSource; @@ -18,15 +12,15 @@ export class BaseModule { } try { - await this.module.loadModule(path); + await module.loadModule(path); } catch (e) { throw new Error(getError(e)); } } - async forward(input: string) { + static async forward(module: any, input: string) { try { - return await this.module.forward(input); + return await module.forward(input); } catch (e) { throw new Error(getError(e)); } diff --git a/src/modules/computer_vision/ClassificationModule.ts b/src/modules/computer_vision/ClassificationModule.ts index 2b9e989d..c21e4594 100644 --- a/src/modules/computer_vision/ClassificationModule.ts +++ b/src/modules/computer_vision/ClassificationModule.ts @@ -1,8 +1,12 @@ import { BaseModule } from './BaseModule'; import { _ClassificationModule } from '../../native/RnExecutorchModules'; -export class ClassificationModule extends BaseModule { - constructor() { - super(new _ClassificationModule()); +export class ClassificationModule { + static async load(modelSource: string | number) { + await BaseModule.load(_ClassificationModule, modelSource); + } + + static async forward(input: string): Promise<{ [category: string]: number }> { + return await BaseModule.forward(_ClassificationModule, input); } } diff --git a/src/modules/computer_vision/ObjectDetectionModule.ts b/src/modules/computer_vision/ObjectDetectionModule.ts index fa5f7589..3c3dd9dd 100644 --- a/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/src/modules/computer_vision/ObjectDetectionModule.ts @@ -1,8 +1,13 @@ import { BaseModule } from './BaseModule'; import { _ObjectDetectionModule } from '../../native/RnExecutorchModules'; +import { Detection } from '../../types/object_detection'; -export class ObjectDetectionModule extends BaseModule { - constructor() { - super(new _ObjectDetectionModule()); +export class ObjectDetectionModule { + static async load(modelSource: string | number) { + await BaseModule.load(_ObjectDetectionModule, modelSource); + } + + static async forward(input: string): Promise { + return await BaseModule.forward(_ObjectDetectionModule, input); } } diff --git a/src/modules/computer_vision/StyleTransferModule.ts b/src/modules/computer_vision/StyleTransferModule.ts index 830bd057..d7b55e14 100644 --- a/src/modules/computer_vision/StyleTransferModule.ts +++ b/src/modules/computer_vision/StyleTransferModule.ts @@ -1,8 +1,12 @@ import { BaseModule } from './BaseModule'; import { _StyleTransferModule } from '../../native/RnExecutorchModules'; -export class StyleTransfer extends BaseModule { - constructor() { - super(new _StyleTransferModule()); +export class StyleTransferModule { + static async load(modelSource: string | number) { + await BaseModule.load(_StyleTransferModule, modelSource); + } + + static async forward(input: string): Promise { + return await BaseModule.forward(_StyleTransferModule, input); } } diff --git a/src/modules/general/ExecutorchModule.ts b/src/modules/general/ExecutorchModule.ts index 971f36e8..ca4f3389 100644 --- a/src/modules/general/ExecutorchModule.ts +++ b/src/modules/general/ExecutorchModule.ts @@ -4,9 +4,7 @@ import { _ETModule } from '../../native/RnExecutorchModules'; import { ETInput, getTypeIdentifier } from '../../types/common'; export class ExecutorchModule { - private module = new _ETModule(); - - async loadModule(modelSource: string) { + static async load(modelSource: string) { if (!modelSource) return; let path = modelSource; @@ -16,13 +14,13 @@ export class ExecutorchModule { } try { - await this.module.loadModule(path); + await _ETModule.loadModule(path); } catch (e) { throw new Error(getError(e)); } } - async forward(input: ETInput, shape: number[]) { + static async forward(input: ETInput, shape: number[]) { const inputType = getTypeIdentifier(input); if (inputType === -1) { throw new Error(getError(ETError.InvalidArgument)); @@ -30,21 +28,21 @@ export class ExecutorchModule { try { const numberArray = [...input] as number[]; - return await this.module.forward(numberArray, shape, inputType); + return await _ETModule.forward(numberArray, shape, inputType); } catch (e) { throw new Error(getError(e)); } } - async loadMethod(methodName: string) { + static async loadMethod(methodName: string) { try { - await this.module.loadMethod(methodName); + await _ETModule.loadMethod(methodName); } catch (e) { throw new Error(getError(e)); } } - async loadForward() { + static async loadForward() { await this.loadMethod('forward'); } } diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts index af0a616d..c0638582 100644 --- a/src/modules/natural_language_processing/LLMModule.ts +++ b/src/modules/natural_language_processing/LLMModule.ts @@ -3,7 +3,7 @@ import { Image } from 'react-native'; import { ResourceSource } from '../../types/common'; export class LLMModule { - async loadModel( + static async load( modelSource: ResourceSource, tokenizerSource: ResourceSource, systemPrompt?: string, @@ -32,7 +32,7 @@ export class LLMModule { } } - async generate(input: string): Promise { + static async generate(input: string) { try { await LLM.runInference(input); } catch (err) { @@ -40,19 +40,19 @@ export class LLMModule { } } - onDownloadProgress(callback: (data: number) => void) { + static onDownloadProgress(callback: (data: number) => void) { return LLM.onDownloadProgress(callback); } - onToken(callback: (data: string | undefined) => void) { + static onToken(callback: (data: string | undefined) => void) { return LLM.onToken(callback); } - interrupt() { + static interrupt() { LLM.interrupt(); } - deleteModule() { + static deleteModule() { LLM.deleteModule(); } } diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index 8a80b595..ca79367b 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -72,44 +72,44 @@ const StyleTransfer = StyleTransferSpec ); class _ObjectDetectionModule { - async forward(input: string) { + static async forward(input: string) { return await ObjectDetection.forward(input); } - async loadModule(modelSource: string | number) { + static async loadModule(modelSource: string | number) { return await ObjectDetection.loadModule(modelSource); } } class _StyleTransferModule { - async forward(input: string) { + static async forward(input: string) { return await StyleTransfer.forward(input); } - async loadModule(modelSource: string | number) { + static async loadModule(modelSource: string | number) { return await StyleTransfer.loadModule(modelSource); } } class _ClassificationModule { - async forward(input: string) { + static async forward(input: string) { return await Classification.forward(input); } - async loadModule(modelSource: string | number) { + static async loadModule(modelSource: string | number) { return await Classification.loadModule(modelSource); } } class _ETModule { - async forward( + static async forward( input: number[], shape: number[], inputType: number ): Promise { return await ETModule.forward(input, shape, inputType); } - async loadModule(modelSource: string) { + static async loadModule(modelSource: string) { return await ETModule.loadModule(modelSource); } - async loadMethod(methodName: string): Promise { + static async loadMethod(methodName: string): Promise { return await ETModule.loadMethod(methodName); } }