diff --git a/android/build.gradle b/android/build.gradle index ea5425d..ed7e01d 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -1,4 +1,8 @@ buildscript { + ext { + agp_version = '8.4.2' + } + // Buildscript is evaluated before everything else so we can't use getExtOrDefault def kotlin_version = rootProject.ext.has("kotlinVersion") ? rootProject.ext.get("kotlinVersion") : project.properties["RnExecutorch_kotlinVersion"] @@ -9,7 +13,7 @@ buildscript { } dependencies { - classpath "com.android.tools.build:gradle:7.2.1" + classpath "com.android.tools.build:gradle:$agp_version" // noinspection DifferentKotlinGradleVersion classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" } @@ -95,7 +99,7 @@ dependencies { // For < 0.71, this will be from the local maven repo // For > 0.71, this will be replaced by `com.facebook.react:react-android:$version` by react gradle plugin //noinspection GradleDynamicVersion - implementation "com.facebook.react:react-native:+" + implementation "com.facebook.react:react-android:+" implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version" implementation 'com.github.software-mansion:react-native-executorch:main-SNAPSHOT' implementation 'org.opencv:opencv:4.10.0' diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt b/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt new file mode 100644 index 0000000..0411598 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt @@ -0,0 +1,61 @@ +package com.swmansion.rnexecutorch + +import android.util.Log +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.Promise +import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.bridge.WritableArray +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.utils.ETError +import com.swmansion.rnexecutorch.utils.ImageProcessor 
+import org.opencv.android.OpenCVLoader +import com.swmansion.rnexecutorch.models.objectdetection.SSDLiteLargeModel +import org.opencv.core.Mat + +class ObjectDetection(reactContext: ReactApplicationContext) : + NativeObjectDetectionSpec(reactContext) { + + private lateinit var ssdLiteLarge: SSDLiteLargeModel + + companion object { + const val NAME = "ObjectDetection" + } + + init { + if(!OpenCVLoader.initLocal()){ + Log.d("rn_executorch", "OpenCV not loaded") + } else { + Log.d("rn_executorch", "OpenCV loaded") + } + } + + override fun loadModule(modelSource: String, promise: Promise) { + try { + ssdLiteLarge = SSDLiteLargeModel(reactApplicationContext) + ssdLiteLarge.loadModel(modelSource) + promise.resolve(0) + } catch (e: Exception) { + promise.reject(e.message!!, ETError.InvalidModelPath.toString()) + } + } + + override fun forward(input: String, promise: Promise) { + try { + val inputImage = ImageProcessor.readImage(input) + val output = ssdLiteLarge.runModel(inputImage) + val outputWritableArray: WritableArray = Arguments.createArray() + output.map { detection -> + detection.toWritableMap() + }.forEach { writableMap -> + outputWritableArray.pushMap(writableMap) + } + promise.resolve(outputWritableArray) + } catch(e: Exception){ + promise.reject(e.message!!, e.message) + } + } + + override fun getName(): String { + return NAME + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index fc4ba2f..6a4fdc2 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -21,7 +21,9 @@ class RnExecutorchPackage : TurboReactPackage() { StyleTransfer(reactContext) } else if (name == Classification.NAME) { Classification(reactContext) - } + } else if (name == ObjectDetection.NAME) { + ObjectDetection(reactContext) + } else { null } @@ -63,6 +65,15 @@ class 
RnExecutorchPackage : TurboReactPackage() { false, // isCxxModule true ) + + moduleInfos[ObjectDetection.NAME] = ReactModuleInfo( + ObjectDetection.NAME, + ObjectDetection.NAME, + false, // canOverrideExistingModule + false, // needsEagerInit + false, // isCxxModule + true + ) moduleInfos } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/object_detection/SSDLiteLargeModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/object_detection/SSDLiteLargeModel.kt new file mode 100644 index 0000000..f8c5748 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/object_detection/SSDLiteLargeModel.kt @@ -0,0 +1,71 @@ +package com.swmansion.rnexecutorch.models.objectdetection + +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Mat +import org.opencv.core.Size +import org.opencv.imgproc.Imgproc +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.utils.Bbox +import com.swmansion.rnexecutorch.utils.CocoLabel +import com.swmansion.rnexecutorch.utils.Detection +import com.swmansion.rnexecutorch.utils.nms +import org.pytorch.executorch.EValue + +const val detectionScoreThreshold = .7f +const val iouThreshold = .55f + +class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Array<Detection>>(reactApplicationContext) { + private var heightRatio: Float = 1.0f + private var widthRatio: Float = 1.0f + + private fun getModelImageSize(): Size { + val inputShape = module.getInputShape(0) + val width = inputShape[inputShape.lastIndex] + val height = inputShape[inputShape.lastIndex - 1] + + return Size(height.toDouble(), width.toDouble()) + } + + override fun preprocess(input: Mat): EValue { + this.widthRatio = (input.size().width / getModelImageSize().width).toFloat() + this.heightRatio = (input.size().height / getModelImageSize().height).toFloat() + Imgproc.resize(input, input, getModelImageSize()) + 
return ImageProcessor.matToEValue(input, module.getInputShape(0)) + } + + override fun runModel(input: Mat): Array<Detection> { + val modelInput = preprocess(input) + val modelOutput = forward(modelInput) + return postprocess(modelOutput) + } + + override fun postprocess(output: Array<EValue>): Array<Detection> { + val scoresTensor = output[1].toTensor() + val numel = scoresTensor.numel() + val bboxes = output[0].toTensor().dataAsFloatArray + val scores = scoresTensor.dataAsFloatArray + val labels = output[2].toTensor().dataAsFloatArray + + val detections: MutableList<Detection> = mutableListOf(); + for (idx in 0 until numel.toInt()) { + val score = scores[idx] + if (score < detectionScoreThreshold) { + continue + } + val bbox = Bbox( + bboxes[idx * 4 + 0] * this.widthRatio, + bboxes[idx * 4 + 1] * this.heightRatio, + bboxes[idx * 4 + 2] * this.widthRatio, + bboxes[idx * 4 + 3] * this.heightRatio + ) + val label = labels[idx] + detections.add( + Detection(bbox, score, CocoLabel.fromId(label.toInt())!!) + ) + } + + val detectionsPostNms = nms(detections, iouThreshold); + return detectionsPostNms.toTypedArray() + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt new file mode 100644 index 0000000..00fd3fb --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt @@ -0,0 +1,196 @@ +package com.swmansion.rnexecutorch.utils + +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.WritableMap + +fun nms( + detections: MutableList<Detection>, + iouThreshold: Float +): List<Detection> { + if (detections.isEmpty()) { + return emptyList() + } + + // Sort detections first by label, then by score (descending) + val sortedDetections = detections.sortedWith(compareBy({ it.label }, { -it.score })) + + val result = mutableListOf<Detection>() + + // Process NMS for each label group + var i = 0 + while (i < sortedDetections.size) { + val currentLabel = 
sortedDetections[i].label + + // Collect detections for the current label + val labelDetections = mutableListOf<Detection>() + while (i < sortedDetections.size && sortedDetections[i].label == currentLabel) { + labelDetections.add(sortedDetections[i]) + i++ + } + + // Filter out detections with high IoU + val filteredLabelDetections = mutableListOf<Detection>() + while (labelDetections.isNotEmpty()) { + val current = labelDetections.removeAt(0) + filteredLabelDetections.add(current) + + // Remove detections that overlap with the current detection above the IoU threshold + val iterator = labelDetections.iterator() + while (iterator.hasNext()) { + val other = iterator.next() + if (calculateIoU(current.bbox, other.bbox) > iouThreshold) { + iterator.remove() // Remove detection if IoU is above threshold + } + } + } + + // Add the filtered detections to the result + result.addAll(filteredLabelDetections) + } + + return result +} + +fun calculateIoU(bbox1: Bbox, bbox2: Bbox): Float { + val x1 = maxOf(bbox1.x1, bbox2.x1) + val y1 = maxOf(bbox1.y1, bbox2.y1) + val x2 = minOf(bbox1.x2, bbox2.x2) + val y2 = minOf(bbox1.y2, bbox2.y2) + + val intersectionArea = maxOf(0f, x2 - x1) * maxOf(0f, y2 - y1) + val bbox1Area = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1) + val bbox2Area = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1) + + val unionArea = bbox1Area + bbox2Area - intersectionArea + return if (unionArea == 0f) 0f else intersectionArea / unionArea +} + + +data class Bbox( + val x1: Float, + val y1: Float, + val x2: Float, + val y2: Float +) { + fun toWritableMap(): WritableMap { + val map = Arguments.createMap() + map.putDouble("x1", x1.toDouble()) + map.putDouble("x2", x2.toDouble()) + map.putDouble("y1", y1.toDouble()) + map.putDouble("y2", y2.toDouble()) + return map + } +} + + +data class Detection( + val bbox: Bbox, + val score: Float, + val label: CocoLabel, +) { + fun toWritableMap(): WritableMap { + val map = Arguments.createMap() + map.putMap("bbox", bbox.toWritableMap()) + 
map.putDouble("score", score.toDouble()) + map.putString("label", label.name) + return map + } +} + +enum class CocoLabel(val id: Int) { + PERSON(1), + BICYCLE(2), + CAR(3), + MOTORCYCLE(4), + AIRPLANE(5), + BUS(6), + TRAIN(7), + TRUCK(8), + BOAT(9), + TRAFFIC_LIGHT(10), + FIRE_HYDRANT(11), + STREET_SIGN(12), + STOP_SIGN(13), + PARKING(14), + BENCH(15), + BIRD(16), + CAT(17), + DOG(18), + HORSE(19), + SHEEP(20), + COW(21), + ELEPHANT(22), + BEAR(23), + ZEBRA(24), + GIRAFFE(25), + HAT(26), + BACKPACK(27), + UMBRELLA(28), + SHOE(29), + EYE(30), + HANDBAG(31), + TIE(32), + SUITCASE(33), + FRISBEE(34), + SKIS(35), + SNOWBOARD(36), + SPORTS(37), + KITE(38), + BASEBALL(39), + SKATEBOARD(41), + SURFBOARD(42), + TENNIS_RACKET(43), + BOTTLE(44), + PLATE(45), + WINE_GLASS(46), + CUP(47), + FORK(48), + KNIFE(49), + SPOON(50), + BOWL(51), + BANANA(52), + APPLE(53), + SANDWICH(54), + ORANGE(55), + BROCCOLI(56), + CARROT(57), + HOT_DOG(58), + PIZZA(59), + DONUT(60), + CAKE(61), + CHAIR(62), + COUCH(63), + POTTED_PLANT(64), + BED(65), + MIRROR(66), + DINING_TABLE(67), + WINDOW(68), + DESK(69), + TOILET(70), + DOOR(71), + TV(72), + LAPTOP(73), + MOUSE(74), + REMOTE(75), + KEYBOARD(76), + CELL_PHONE(77), + MICROWAVE(78), + OVEN(79), + TOASTER(80), + SINK(81), + REFRIGERATOR(82), + BLENDER(83), + BOOK(84), + CLOCK(85), + VASE(86), + SCISSORS(87), + TEDDY_BEAR(88), + HAIR_DRIER(89), + TOOTHBRUSH(90), + HAIR_BRUSH(91); + + companion object { + private val idToLabelMap = values().associateBy(CocoLabel::id) + fun fromId(id: Int): CocoLabel? 
= idToLabelMap[id] + } +} diff --git a/examples/computer-vision/App.tsx b/examples/computer-vision/App.tsx index 5f6c4ef..7fc4571 100644 --- a/examples/computer-vision/App.tsx +++ b/examples/computer-vision/App.tsx @@ -7,6 +7,7 @@ import { StyleTransferScreen } from './screens/StyleTransferScreen'; import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context'; import { View, StyleSheet } from 'react-native'; import { ClassificationScreen } from './screens/ClassificationScreen'; +import { ObjectDetectionScreen } from './screens/ObjectDetectionScreen'; enum ModelType { STYLE_TRANSFER, @@ -36,7 +37,12 @@ export default function App() { ); case ModelType.OBJECT_DETECTION: - return <>; + return ( + + ); case ModelType.CLASSIFICATION: return ( diff --git a/examples/computer-vision/components/ImageWithBboxes.tsx b/examples/computer-vision/components/ImageWithBboxes.tsx new file mode 100644 index 0000000..7d08e33 --- /dev/null +++ b/examples/computer-vision/components/ImageWithBboxes.tsx @@ -0,0 +1,112 @@ +import React from 'react'; +import { Image, StyleSheet, View, Text } from 'react-native'; +import { Detection } from 'react-native-executorch'; + +interface Props { + imageUri: string; + detections: Detection[]; + imageWidth: number; + imageHeight: number; +} + +export default function ImageWithBboxes({ + imageUri, + detections, + imageWidth, + imageHeight, +}: Props) { + const [layout, setLayout] = React.useState({ width: 0, height: 0 }); + + const calculateAdjustedDimensions = () => { + const imageRatio = imageWidth / imageHeight; + const layoutRatio = layout.width / layout.height; + + let sx, sy; // Scale in x and y directions + if (imageRatio > layoutRatio) { + // image is more "wide" + sx = layout.width / imageWidth; + sy = layout.width / imageRatio / imageHeight; + } else { + // image is more "tall" + sy = layout.height / imageHeight; + sx = (layout.height * imageRatio) / imageWidth; + } + + return { + scaleX: sx, + scaleY: sy, + offsetX: 
(layout.width - imageWidth * sx) / 2, // Centering the image horizontally + offsetY: (layout.height - imageHeight * sy) / 2, // Centering the image vertically + }; + }; + + return ( + { + const { width, height } = event.nativeEvent.layout; + setLayout({ width, height }); + }} + > + + {detections.map((detection, index) => { + const { scaleX, scaleY, offsetX, offsetY } = + calculateAdjustedDimensions(); + const { x1, y1, x2, y2 } = detection.bbox; + + const left = x1 * scaleX + offsetX; + const top = y1 * scaleY + offsetY; + const width = (x2 - x1) * scaleX; + const height = (y2 - y1) * scaleY; + + return ( + + + {detection.label} ({(detection.score * 100).toFixed(1)}%) + + + ); + })} + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + position: 'relative', + }, + image: { + flex: 1, + width: '100%', + height: '100%', + }, + bbox: { + position: 'absolute', + borderWidth: 2, + }, + label: { + position: 'absolute', + top: -20, + left: 0, + backgroundColor: 'rgba(255, 0, 0, 0.7)', + color: 'white', + fontSize: 12, + paddingHorizontal: 4, + borderRadius: 4, + }, +}); diff --git a/examples/computer-vision/screens/ClassificationScreen.tsx b/examples/computer-vision/screens/ClassificationScreen.tsx index bb7c528..587d072 100644 --- a/examples/computer-vision/screens/ClassificationScreen.tsx +++ b/examples/computer-vision/screens/ClassificationScreen.tsx @@ -1,7 +1,7 @@ import { useState } from 'react'; import Spinner from 'react-native-loading-spinner-overlay'; import { BottomBar } from '../components/BottomBar'; -import { getImageUri } from '../utils'; +import { getImage } from '../utils'; import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; import { View, StyleSheet, Image, Text, ScrollView } from 'react-native'; @@ -21,7 +21,8 @@ export const ClassificationScreen = ({ }); const handleCameraPress = async (isCamera: boolean) => { - const uri = await getImageUri(isCamera); + const image = await getImage(isCamera); + const 
uri = image?.uri; if (typeof uri === 'string') { setImageUri(uri as string); setResults([]); diff --git a/examples/computer-vision/screens/ObjectDetectionScreen.tsx b/examples/computer-vision/screens/ObjectDetectionScreen.tsx new file mode 100644 index 0000000..67f32e7 --- /dev/null +++ b/examples/computer-vision/screens/ObjectDetectionScreen.tsx @@ -0,0 +1,130 @@ +import { useState } from 'react'; +import Spinner from 'react-native-loading-spinner-overlay'; +import { BottomBar } from '../components/BottomBar'; +import { getImage } from '../utils'; +import { + Detection, + useObjectDetection, + SSDLITE_320_MOBILENET_V3_LARGE_URL, +} from 'react-native-executorch'; +import { View, StyleSheet, Image } from 'react-native'; +import ImageWithBboxes from '../components/ImageWithBboxes'; + +export const ObjectDetectionScreen = ({ + imageUri, + setImageUri, +}: { + imageUri: string; + setImageUri: (imageUri: string) => void; +}) => { + const [results, setResults] = useState([]); + const [imageDimensions, setImageDimensions] = useState<{ + width: number; + height: number; + }>(); + + const ssdLite = useObjectDetection({ + modelSource: SSDLITE_320_MOBILENET_V3_LARGE_URL, + }); + + const handleCameraPress = async (isCamera: boolean) => { + const image = await getImage(isCamera); + const uri = image?.uri; + const width = image?.width; + const height = image?.height; + + if (uri && width && height) { + setImageUri(image.uri as string); + setImageDimensions({ width: width as number, height: height as number }); + setResults([]); + } + }; + + const runForward = async () => { + if (imageUri) { + try { + const output = await ssdLite.forward(imageUri); + console.log(output); + setResults(output); + } catch (e) { + console.error(e); + } + } + }; + + if (!ssdLite.isModelReady) { + return ( + + ); + } + + return ( + <> + + + {imageUri && imageDimensions?.width && imageDimensions?.height ? 
( + + ) : ( + + )} + + + + + ); +}; + +const styles = StyleSheet.create({ + imageContainer: { + flex: 6, + width: '100%', + padding: 16, + }, + image: { + flex: 2, + borderRadius: 8, + width: '100%', + }, + results: { + flex: 1, + alignItems: 'center', + justifyContent: 'center', + gap: 4, + padding: 4, + }, + resultHeader: { + fontSize: 18, + color: 'navy', + }, + resultsList: { + flex: 1, + }, + resultRecord: { + flexDirection: 'row', + width: '100%', + justifyContent: 'space-between', + padding: 8, + borderBottomWidth: 1, + }, + resultLabel: { + flex: 1, + marginRight: 4, + }, +}); diff --git a/examples/computer-vision/screens/StyleTransferScreen.tsx b/examples/computer-vision/screens/StyleTransferScreen.tsx index 1556b97..a82657b 100644 --- a/examples/computer-vision/screens/StyleTransferScreen.tsx +++ b/examples/computer-vision/screens/StyleTransferScreen.tsx @@ -1,6 +1,6 @@ import Spinner from 'react-native-loading-spinner-overlay'; import { BottomBar } from '../components/BottomBar'; -import { getImageUri } from '../utils'; +import { getImage } from '../utils'; import { useStyleTransfer, STYLE_TRANSFER_CANDY, @@ -19,7 +19,8 @@ export const StyleTransferScreen = ({ }); const handleCameraPress = async (isCamera: boolean) => { - const uri = await getImageUri(isCamera); + const image = await getImage(isCamera); + const uri = image?.uri; if (typeof uri === 'string') { setImageUri(uri as string); } diff --git a/examples/computer-vision/tsconfig.json b/examples/computer-vision/tsconfig.json index b9567f6..a0863d0 100644 --- a/examples/computer-vision/tsconfig.json +++ b/examples/computer-vision/tsconfig.json @@ -1,6 +1,7 @@ { "extends": "expo/tsconfig.base", "compilerOptions": { - "strict": true + "strict": true, + "jsx": "react-jsx" } } diff --git a/examples/computer-vision/utils.ts b/examples/computer-vision/utils.ts index 327681a..3af8d79 100644 --- a/examples/computer-vision/utils.ts +++ b/examples/computer-vision/utils.ts @@ -4,7 +4,7 @@ import { 
launchImageLibrary, } from 'react-native-image-picker'; -export const getImageUri = async (useCamera: boolean) => { +export const getImage = async (useCamera: boolean) => { const options: CameraOptions = { mediaType: 'photo', }; @@ -15,9 +15,7 @@ export const getImageUri = async (useCamera: boolean) => { if (!output.assets || output.assets.length === 0) return; - const imageUri = output.assets[0].uri; - if (!imageUri) return; - return imageUri; + return output.assets[0]; } catch (err) { console.error(err); } diff --git a/src/constants/modelUrls.ts b/src/constants/modelUrls.ts index 795d5e4..dca54e4 100644 --- a/src/constants/modelUrls.ts +++ b/src/constants/modelUrls.ts @@ -43,3 +43,7 @@ export const STYLE_TRANSFER_UDNIE = Platform.OS === 'ios' ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/main/coreml/style_transfer_udnie_coreml.pte' : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/main/xnnpack/style_transfer_udnie_xnnpack.pte'; + +// Object detection +export const SSDLITE_320_MOBILENET_V3_LARGE_URL = + 'https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/resolve/main/ssdlite320-mobilenetv3-large.pte';