Skip to content

Commit

Permalink
feat: Hookless API (#82)
Browse files Browse the repository at this point in the history
## Description
This PR introduces hookless API and restructures the `./src` directory.

### Type of change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing
documentation)

### Tested on
- [x] iOS
- [x] Android

### Checklist
- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
  • Loading branch information
jakmro authored Jan 27, 2025
1 parent 7d391c8 commit cbe2b55
Show file tree
Hide file tree
Showing 23 changed files with 320 additions and 133 deletions.
14 changes: 0 additions & 14 deletions docs/docs/computer-vision/useClassification.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,6 @@ try {
}
```

<details>
<summary>Type definitions</summary>

```typescript
interface ClassificationModule {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise<{ [category: string]: number }>;
}
```

</details>

### Arguments

**`modelSource`**
Expand Down
53 changes: 26 additions & 27 deletions docs/docs/computer-vision/useObjectDetection.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ It is recommended to use models provided by us, which are available at our [Hugg
:::

## Reference

```jsx
import { useObjectDetection, SSDLITE_320_MOBILENET_V3_LARGE } from 'react-native-executorch';

Expand Down Expand Up @@ -45,42 +46,36 @@ interface Detection {
label: keyof typeof CocoLabel;
score: number;
}

interface ObjectDetectionModule {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise<Detection[]>;
}
```

</details>

### Arguments

`modelSource`

A string that specifies the path to the model file. You can download the model from our [HuggingFace repository](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/tree/main).
A string that specifies the path to the model file. You can download the model from our [HuggingFace repository](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/tree/main).
For more information on that topic, you can check out the [Loading models](https://docs.swmansion.com/react-native-executorch/fundamentals/loading-models) page.

### Returns

The hook returns an object with the following properties:


| **Field** | **Type** | **Description** |
|-----------------------|---------------------------------------|------------------------------------------------------------------------------------------------------------------|
| `forward` | `(input: string) => Promise<Detection[]>` | A function that accepts an image (url, b64) and returns an array of `Detection` objects. |
| `error` | <code>string &#124; null</code> | Contains the error message if the model loading failed. |
| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |

| Field | Type | Description |
| -------------- | ----------------------------------------- | ---------------------------------------------------------------------------------------- |
| `forward` | `(input: string) => Promise<Detection[]>` | A function that accepts an image (url, b64) and returns an array of `Detection` objects. |
| `error` | <code>string &#124; null</code> | Contains the error message if the model loading failed. |
| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. |
| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. |

## Running the model

To run the model, you can use the `forward` method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image. The function returns an array of `Detection` objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. For more information, please refer to the reference or type definitions.

## Detection object

The detection object is specified as follows:

```typescript
interface Bbox {
x1: number;
Expand All @@ -95,33 +90,37 @@ interface Detection {
score: number;
}
```

The `bbox` property contains information about the bounding box of detected objects. It is represented as two points: one at the bottom-left corner of the bounding box (`x1`, `y1`) and the other at the top-right corner (`x2`, `y2`).
The `label` property contains the name of the detected object, which corresponds to one of the `CocoLabels`. The `score` represents the confidence score of the detected object.



## Example

```tsx
import { useObjectDetection, SSDLITE_320_MOBILENET_V3_LARGE } from 'react-native-executorch';
import {
useObjectDetection,
SSDLITE_320_MOBILENET_V3_LARGE,
} from 'react-native-executorch';

function App() {
const ssdlite = useObjectDetection({
modelSource: SSDLITE_320_MOBILENET_V3_LARGE,
});

const runModel = async () => {
const detections = await ssdlite.forward("https://url-to-image.jpg");
const detections = await ssdlite.forward('https://url-to-image.jpg');

for (const detection of detections) {
console.log("Bounding box: ", detection.bbox);
console.log("Bounding label: ", detection.label);
console.log("Bounding score: ", detection.score);
console.log('Bounding box: ', detection.bbox);
console.log('Bounding label: ', detection.label);
console.log('Bounding score: ', detection.score);
}
}
};
}
```

## Supported models

| Model | Number of classes | Class list |
| --------------------------------------------------------------------------------------------------------------- | ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [SSDLite320 MobileNetV3 Large](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.ssdlite320_mobilenet_v3_large.html#torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights) | 91 | [COCO](https://github.com/software-mansion/react-native-executorch/blob/69802ee1ca161d9df00def1dabe014d36341cfa9/src/types/object_detection.ts#L14) |
| Model | Number of classes | Class list |
| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- |
| [SSDLite320 MobileNetV3 Large](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.ssdlite320_mobilenet_v3_large.html#torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights) | 91 | [COCO](https://github.com/software-mansion/react-native-executorch/blob/69802ee1ca161d9df00def1dabe014d36341cfa9/src/types/object_detection.ts#L14) |
14 changes: 0 additions & 14 deletions docs/docs/computer-vision/useStyleTransfer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,6 @@ try {
}
```

<details>
<summary>Type definitions</summary>

```typescript
interface StyleTransferModule {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise<string>;
}
```

</details>

### Arguments

**`modelSource`**
Expand Down
9 changes: 2 additions & 7 deletions examples/computer-vision/components/ImageWithBboxes.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,7 @@ export default function ImageWithBboxes({
const height = (y2 - y1) * scaleY;

return (
<View
key={index}
style={[
styles.bbox,
{ left, top, width, height, borderColor: 'red' },
]}
>
<View key={index} style={[styles.bbox, { left, top, width, height }]}>
<Text style={styles.label}>
{detection.label} ({(detection.score * 100).toFixed(1)}%)
</Text>
Expand All @@ -98,6 +92,7 @@ const styles = StyleSheet.create({
bbox: {
position: 'absolute',
borderWidth: 2,
borderColor: 'red',
},
label: {
position: 'absolute',
Expand Down
6 changes: 5 additions & 1 deletion examples/computer-vision/screens/ObjectDetectionScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ export const ObjectDetectionScreen = ({
/>
) : (
<Image
style={{ width: '100%', height: '100%' }}
style={styles.fullSizeImage}
resizeMode="contain"
source={require('../assets/icons/executorch_logo.png')}
/>
Expand Down Expand Up @@ -127,4 +127,8 @@ const styles = StyleSheet.create({
flex: 1,
marginRight: 4,
},
fullSizeImage: {
width: '100%',
height: '100%',
},
});
5 changes: 4 additions & 1 deletion examples/llama/screens/ChatScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export default function ChatScreen() {
<SafeAreaView style={styles.container}>
<TouchableWithoutFeedback onPress={Keyboard.dismiss}>
<KeyboardAvoidingView
style={{ flex: 1 }}
style={styles.keyboardAvoidingView}
behavior={Platform.OS === 'ios' ? 'padding' : 'height'}
keyboardVerticalOffset={Platform.OS === 'android' ? 30 : 0}
>
Expand Down Expand Up @@ -133,6 +133,9 @@ const styles = StyleSheet.create({
container: {
flex: 1,
},
keyboardAvoidingView: {
flex: 1,
},
topContainer: {
height: 68,
width: '100%',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import { useState } from 'react';
import { _ClassificationModule } from '../native/RnExecutorchModules';
import { useModule } from '../useModule';
import { _ClassificationModule } from '../../native/RnExecutorchModules';
import { useModule } from '../../useModule';

interface Props {
modelSource: string | number;
}

interface ClassificationModule {
export const useClassification = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise<{ [category: string]: number }>;
}

export const useClassification = ({
modelSource,
}: Props): ClassificationModule => {
} => {
const [module, _] = useState(() => new _ClassificationModule());
const {
error,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
import { useState } from 'react';
import { _ObjectDetectionModule } from '../native/RnExecutorchModules';
import { useModule } from '../useModule';
import { Detection } from '../types/object_detection';
import { _ObjectDetectionModule } from '../../native/RnExecutorchModules';
import { useModule } from '../../useModule';
import { Detection } from '../../types/object_detection';

interface Props {
modelSource: string | number;
}

interface ObjectDetectionModule {
export const useObjectDetection = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise<Detection[]>;
}

export const useObjectDetection = ({
modelSource,
}: Props): ObjectDetectionModule => {
} => {
const [module, _] = useState(() => new _ObjectDetectionModule());
const {
error,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
import { useState } from 'react';
import { _StyleTransferModule } from '../native/RnExecutorchModules';
import { useModule } from '../useModule';
import { _StyleTransferModule } from '../../native/RnExecutorchModules';
import { useModule } from '../../useModule';

interface Props {
modelSource: string | number;
}

interface StyleTransferModule {
export const useStyleTransfer = ({
modelSource,
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: string) => Promise<string>;
}

export const useStyleTransfer = ({
modelSource,
}: Props): StyleTransferModule => {
} => {
const [module, _] = useState(() => new _StyleTransferModule());
const {
error,
Expand Down
17 changes: 12 additions & 5 deletions src/ETModule.ts → src/hooks/general/useExecutorchModule.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import { useState } from 'react';
import { _ETModule } from './native/RnExecutorchModules';
import { getError } from './Error';
import { ExecutorchModule } from './types/common';
import { useModule } from './useModule';
import { _ETModule } from '../../native/RnExecutorchModules';
import { useModule } from '../../useModule';
import { ETInput } from '../../types/common';
import { getError } from '../../Error';

interface Props {
modelSource: string | number;
}

export const useExecutorchModule = ({
modelSource,
}: Props): ExecutorchModule => {
}: Props): {
error: string | null;
isReady: boolean;
isGenerating: boolean;
forward: (input: ETInput, shape: number[]) => Promise<number[][]>;
loadMethod: (methodName: string) => Promise<void>;
loadForward: () => Promise<void>;
} => {
const [module] = useState(() => new _ETModule());
const {
error,
Expand Down
6 changes: 3 additions & 3 deletions src/LLM.ts → ...oks/natural_language_processing/useLLM.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { useCallback, useEffect, useRef, useState } from 'react';
import { EventSubscription, Image } from 'react-native';
import { ResourceSource, Model } from './types/common';
import { ResourceSource, Model } from '../../types/common';
import {
DEFAULT_CONTEXT_WINDOW_LENGTH,
DEFAULT_SYSTEM_PROMPT,
EOT_TOKEN,
} from './constants/llamaDefaults';
import { LLM } from './native/RnExecutorchModules';
} from '../../constants/llamaDefaults';
import { LLM } from '../../native/RnExecutorchModules';

const interrupt = () => {
LLM.interrupt();
Expand Down
28 changes: 22 additions & 6 deletions src/index.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
export * from './ETModule';
export * from './LLM';
export * from './constants/modelUrls';
export * from './models/Classification';
export * from './models/ObjectDetection';
export * from './models/StyleTransfer';
// hooks
export * from './hooks/computer_vision/useClassification';
export * from './hooks/computer_vision/useObjectDetection';
export * from './hooks/computer_vision/useStyleTransfer';

export * from './hooks/natural_language_processing/useLLM';

export * from './hooks/general/useExecutorchModule';

// modules
export * from './modules/computer_vision/ClassificationModule';
export * from './modules/computer_vision/ObjectDetectionModule';
export * from './modules/computer_vision/StyleTransferModule';

export * from './modules/natural_language_processing/LLMModule';

export * from './modules/general/ExecutorchModule';

// types
export * from './types/object_detection';

// constants
export * from './constants/modelUrls';
Loading

0 comments on commit cbe2b55

Please sign in to comment.