From 759451908c85c99dd922367aab3ac9b3f72f12e3 Mon Sep 17 00:00:00 2001 From: chmjkb Date: Wed, 11 Dec 2024 17:30:03 +0100 Subject: [PATCH] feat: update TS object detection api to match the native code --- .../object_detection/ObjectDetection.ts | 23 ++--- src/models/object_detection/types.ts | 96 ++++++++++++++++--- 2 files changed, 91 insertions(+), 28 deletions(-) diff --git a/src/models/object_detection/ObjectDetection.ts b/src/models/object_detection/ObjectDetection.ts index 645ce19..dc05a6f 100644 --- a/src/models/object_detection/ObjectDetection.ts +++ b/src/models/object_detection/ObjectDetection.ts @@ -3,14 +3,11 @@ import { Image } from 'react-native'; import { ETError, getError } from '../../Error'; import { ObjectDetection } from '../../native/RnExecutorchModules'; import { - ObjectDetectionModel, - ObjectDetectionOutputType, ObjectDetectionResult, } from './types'; interface Props { - model: keyof typeof ObjectDetectionModel; - path: string | number; + modelSource: string | number; } interface ObjectDetectionModule { @@ -19,29 +16,25 @@ interface ObjectDetectionModule { isModelGenerating: boolean; forward: ( input: string, - outputType: keyof typeof ObjectDetectionOutputType, - topk: number // TODO: find an alternative way ) => Promise; } export const useObjectDetection = ({ - model, - path, + modelSource, }: Props): ObjectDetectionModule => { const [error, setError] = useState(null); const [isModelLoading, setIsModelLoading] = useState(true); const [isModelGenerating, setIsModelGenerating] = useState(false); useEffect(() => { - // TODO: handle the case where kind is wrong const loadModel = async () => { - if (typeof path === 'number') { - path = Image.resolveAssetSource(path).uri; + if (typeof modelSource === 'number') { + modelSource = Image.resolveAssetSource(modelSource).uri; } try { setIsModelLoading(true); - await ObjectDetection.loadModule(path, model); + await ObjectDetection.loadModule(modelSource); } catch (e) { setError(getError(e)); } finally { @@ -50,12 +43,10 @@ export const useObjectDetection = ({ }; loadModel(); - }, [path]); + }, [modelSource]); const forward = async ( input: string, - outptutType: keyof typeof ObjectDetectionOutputType, - topk: number ) => { if (isModelLoading) { throw new Error(getError(ETError.ModuleNotLoaded)); @@ -63,7 +54,7 @@ export const useObjectDetection = ({ try { setIsModelGenerating(true); - const output = await ObjectDetection.forward(input, outptutType, topk); + const output = await ObjectDetection.forward(input); setIsModelGenerating(false); return output; } catch (e) { diff --git a/src/models/object_detection/types.ts b/src/models/object_detection/types.ts index 46623ee..b95d604 100644 --- a/src/models/object_detection/types.ts +++ b/src/models/object_detection/types.ts @@ -1,13 +1,3 @@ -export enum ObjectDetectionModel { - SSDLITE_LARGE = 'SSDLITE_LARGE', -} - -export enum ObjectDetectionOutputType { - IMAGE = 1, - DETECTIONS = 2, - ALL = 3, -} - export interface Bbox { x1: number; y1: number; @@ -17,11 +7,93 @@ export interface Bbox { export interface Detection { bbox: Bbox; - label: string; // TODO + label: keyof typeof CocoLabel; score: number; } export interface ObjectDetectionResult { - outputImageUri?: String[]; detections: Detection[]; } + +enum CocoLabel { + PERSON = 1, + BICYCLE = 2, + CAR = 3, + MOTORCYCLE = 4, + AIRPLANE = 5, + BUS = 6, + TRAIN = 7, + TRUCK = 8, + BOAT = 9, + TRAFFIC_LIGHT = 10, + FIRE_HYDRANT = 11, + STOP_SIGN = 12, + PARKING_METER = 13, + BENCH = 14, + BIRD = 15, + CAT = 16, + DOG = 17, + HORSE = 18, + SHEEP = 19, + COW = 20, + ELEPHANT = 21, + BEAR = 22, + ZEBRA = 23, + GIRAFFE = 24, + BACKPACK = 25, + UMBRELLA = 26, + HANDBAG = 27, + TIE = 28, + SUITCASE = 29, + FRISBEE = 30, + SKIS = 31, + SNOWBOARD = 32, + SPORTS_BALL = 33, + KITE = 34, + BASEBALL_BAT = 35, + BASEBALL_GLOVE = 36, + SKATEBOARD = 37, + SURFBOARD = 38, + TENNIS_RACKET = 39, + BOTTLE = 40, + WINE_GLASS = 41, + CUP = 42, + FORK = 43, + KNIFE = 44, + SPOON = 45, + BOWL = 46, + BANANA = 47, + APPLE = 48, + SANDWICH = 49, + ORANGE = 50, + BROCCOLI = 51, + CARROT = 52, + HOT_DOG = 53, + PIZZA = 54, + DONUT = 55, + CAKE = 56, + CHAIR = 57, + COUCH = 58, + POTTED_PLANT = 59, + BED = 60, + DINING_TABLE = 61, + TOILET = 62, + TV = 63, + LAPTOP = 64, + MOUSE = 65, + REMOTE = 66, + KEYBOARD = 67, + CELL_PHONE = 68, + MICROWAVE = 69, + OVEN = 70, + TOASTER = 71, + SINK = 72, + REFRIGERATOR = 73, + BOOK = 74, + CLOCK = 75, + VASE = 76, + SCISSORS = 77, + TEDDY_BEAR = 78, + HAIR_DRIER = 79, + TOOTHBRUSH = 80, +}