Skip to content

Commit

Permalink
Make classes static
Browse files Browse the repository at this point in the history
  • Loading branch information
jakmro committed Jan 15, 2025
1 parent 2c531ca commit 41d6b42
Show file tree
Hide file tree
Showing 11 changed files with 52 additions and 47 deletions.
2 changes: 1 addition & 1 deletion src/hooks/computer_vision/useClassification.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ interface ClassificationModule {
export const useClassification = ({
modelSource,
}: Props): ClassificationModule => {
const [module, _] = useState(() => new _ClassificationModule());
const [module, _] = useState(() => _ClassificationModule);
const {
error,
isReady,
Expand Down
2 changes: 1 addition & 1 deletion src/hooks/computer_vision/useObjectDetection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ interface ObjectDetectionModule {
export const useObjectDetection = ({
modelSource,
}: Props): ObjectDetectionModule => {
const [module, _] = useState(() => new _ObjectDetectionModule());
const [module, _] = useState(() => _ObjectDetectionModule);
const {
error,
isReady,
Expand Down
2 changes: 1 addition & 1 deletion src/hooks/computer_vision/useStyleTransfer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ interface StyleTransferModule {
export const useStyleTransfer = ({
modelSource,
}: Props): StyleTransferModule => {
const [module, _] = useState(() => new _StyleTransferModule());
const [module, _] = useState(() => _StyleTransferModule);
const {
error,
isReady,
Expand Down
2 changes: 1 addition & 1 deletion src/hooks/general/useExecutorchModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ interface Props {
export const useExecutorchModule = ({
modelSource,
}: Props): ExecutorchModule => {
const [module] = useState(() => new _ETModule());
const [module] = useState(() => _ETModule);
const {
error,
isReady,
Expand Down
14 changes: 4 additions & 10 deletions src/modules/computer_vision/BaseModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@ import { Image } from 'react-native';
import { getError } from '../../Error';

export class BaseModule {
protected module: any;

constructor(module: any) {
this.module = module;
}

async loadModule(modelSource: string | number) {
static async load(module: any, modelSource: string | number) {
if (!modelSource) return;

let path = modelSource;
Expand All @@ -18,15 +12,15 @@ export class BaseModule {
}

try {
await this.module.loadModule(path);
await module.loadModule(path);
} catch (e) {
throw new Error(getError(e));
}
}

async forward(input: string) {
static async forward(module: any, input: string) {
try {
return await this.module.forward(input);
return await module.forward(input);
} catch (e) {
throw new Error(getError(e));
}
Expand Down
10 changes: 7 additions & 3 deletions src/modules/computer_vision/ClassificationModule.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import { BaseModule } from './BaseModule';
import { _ClassificationModule } from '../../native/RnExecutorchModules';

export class ClassificationModule extends BaseModule {
constructor() {
super(new _ClassificationModule());
export class ClassificationModule {
static async load(modelSource: string | number) {
await BaseModule.load(_ClassificationModule, modelSource);
}

static async forward(input: string): Promise<{ [category: string]: number }> {
return await BaseModule.forward(_ClassificationModule, input);
}
}
11 changes: 8 additions & 3 deletions src/modules/computer_vision/ObjectDetectionModule.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import { BaseModule } from './BaseModule';
import { _ObjectDetectionModule } from '../../native/RnExecutorchModules';
import { Detection } from '../../types/object_detection';

export class ObjectDetectionModule extends BaseModule {
constructor() {
super(new _ObjectDetectionModule());
export class ObjectDetectionModule {
static async load(modelSource: string | number) {
await BaseModule.load(_ObjectDetectionModule, modelSource);
}

static async forward(input: string): Promise<Detection[]> {
return await BaseModule.forward(_ObjectDetectionModule, input);
}
}
10 changes: 7 additions & 3 deletions src/modules/computer_vision/StyleTransferModule.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import { BaseModule } from './BaseModule';
import { _StyleTransferModule } from '../../native/RnExecutorchModules';

export class StyleTransfer extends BaseModule {
constructor() {
super(new _StyleTransferModule());
export class StyleTransferModule {
static async load(modelSource: string | number) {
await BaseModule.load(_StyleTransferModule, modelSource);
}

static async forward(input: string): Promise<string> {
return await BaseModule.forward(_StyleTransferModule, input);
}
}
16 changes: 7 additions & 9 deletions src/modules/general/ExecutorchModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import { _ETModule } from '../../native/RnExecutorchModules';
import { ETInput, getTypeIdentifier } from '../../types/common';

export class ExecutorchModule {
private module = new _ETModule();

async loadModule(modelSource: string) {
static async load(modelSource: string) {
if (!modelSource) return;

let path = modelSource;
Expand All @@ -16,35 +14,35 @@ export class ExecutorchModule {
}

try {
await this.module.loadModule(path);
await _ETModule.loadModule(path);
} catch (e) {
throw new Error(getError(e));
}
}

async forward(input: ETInput, shape: number[]) {
static async forward(input: ETInput, shape: number[]) {
const inputType = getTypeIdentifier(input);
if (inputType === -1) {
throw new Error(getError(ETError.InvalidArgument));
}

try {
const numberArray = [...input] as number[];
return await this.module.forward(numberArray, shape, inputType);
return await _ETModule.forward(numberArray, shape, inputType);
} catch (e) {
throw new Error(getError(e));
}
}

async loadMethod(methodName: string) {
static async loadMethod(methodName: string) {
try {
await this.module.loadMethod(methodName);
await _ETModule.loadMethod(methodName);
} catch (e) {
throw new Error(getError(e));
}
}

async loadForward() {
static async loadForward() {
await this.loadMethod('forward');
}
}
12 changes: 6 additions & 6 deletions src/modules/natural_language_processing/LLMModule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { Image } from 'react-native';
import { ResourceSource } from '../../types/common';

export class LLMModule {
async loadModel(
static async load(
modelSource: ResourceSource,
tokenizerSource: ResourceSource,
systemPrompt?: string,
Expand Down Expand Up @@ -32,27 +32,27 @@ export class LLMModule {
}
}

async generate(input: string): Promise<void> {
static async generate(input: string) {
try {
await LLM.runInference(input);
} catch (err) {
throw new Error((err as Error).message);
}
}

onDownloadProgress(callback: (data: number) => void) {
static onDownloadProgress(callback: (data: number) => void) {
return LLM.onDownloadProgress(callback);
}

onToken(callback: (data: string | undefined) => void) {
static onToken(callback: (data: string | undefined) => void) {
return LLM.onToken(callback);
}

interrupt() {
static interrupt() {
LLM.interrupt();
}

deleteModule() {
static deleteModule() {
LLM.deleteModule();
}
}
18 changes: 9 additions & 9 deletions src/native/RnExecutorchModules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,44 +72,44 @@ const StyleTransfer = StyleTransferSpec
);

class _ObjectDetectionModule {
async forward(input: string) {
static async forward(input: string) {
return await ObjectDetection.forward(input);
}
async loadModule(modelSource: string | number) {
static async loadModule(modelSource: string | number) {
return await ObjectDetection.loadModule(modelSource);
}
}

class _StyleTransferModule {
async forward(input: string) {
static async forward(input: string) {
return await StyleTransfer.forward(input);
}
async loadModule(modelSource: string | number) {
static async loadModule(modelSource: string | number) {
return await StyleTransfer.loadModule(modelSource);
}
}

class _ClassificationModule {
async forward(input: string) {
static async forward(input: string) {
return await Classification.forward(input);
}
async loadModule(modelSource: string | number) {
static async loadModule(modelSource: string | number) {
return await Classification.loadModule(modelSource);
}
}

class _ETModule {
async forward(
static async forward(
input: number[],
shape: number[],
inputType: number
): Promise<number[]> {
return await ETModule.forward(input, shape, inputType);
}
async loadModule(modelSource: string) {
static async loadModule(modelSource: string) {
return await ETModule.loadModule(modelSource);
}
async loadMethod(methodName: string): Promise<number> {
static async loadMethod(methodName: string): Promise<number> {
return await ETModule.loadMethod(methodName);
}
}
Expand Down

0 comments on commit 41d6b42

Please sign in to comment.