Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add object detection (android) #52

Merged
merged 7 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions android/build.gradle
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
buildscript {
ext {
agp_version = '8.4.2'
}

// Buildscript is evaluated before everything else so we can't use getExtOrDefault
def kotlin_version = rootProject.ext.has("kotlinVersion") ? rootProject.ext.get("kotlinVersion") : project.properties["RnExecutorch_kotlinVersion"]

Expand All @@ -9,7 +13,7 @@ buildscript {
}

dependencies {
classpath "com.android.tools.build:gradle:7.2.1"
classpath "com.android.tools.build:gradle:$agp_version"
// noinspection DifferentKotlinGradleVersion
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
Expand Down Expand Up @@ -95,7 +99,8 @@ dependencies {
// For < 0.71, this will be from the local maven repo
// For > 0.71, this will be replaced by `com.facebook.react:react-android:$version` by react gradle plugin
//noinspection GradleDynamicVersion
implementation "com.facebook.react:react-native:+"
implementation "com.facebook.react:react-android:+"
implementation 'org.opencv:opencv:4.10.0'
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
implementation 'com.github.software-mansion:react-native-executorch:main-SNAPSHOT'
implementation 'org.opencv:opencv:4.10.0'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
import com.swmansion.rnexecutorch.models.objectdetection.SSDLiteLargeModel
import org.opencv.core.Mat

/**
 * React Native module that exposes [SSDLiteLargeModel] object detection to JavaScript.
 *
 * OpenCV is initialised once when the module is constructed; the outcome is only logged.
 * A model must be loaded with [loadModule] before [forward] can produce detections.
 */
class ObjectDetection(reactContext: ReactApplicationContext) :
  NativeObjectDetectionSpec(reactContext) {

  private lateinit var ssdLiteLarge: SSDLiteLargeModel

  companion object {
    const val NAME = "ObjectDetection"
  }

  init {
    if (!OpenCVLoader.initLocal()) {
      Log.d("rn_executorch", "OpenCV not loaded")
    } else {
      Log.d("rn_executorch", "OpenCV loaded")
    }
  }

  /**
   * Creates an [SSDLiteLargeModel] and loads the model found at [modelSource].
   *
   * Resolves [promise] with `0` on success. On failure the promise is rejected with the
   * exception text as the code and [ETError.InvalidModelPath] as the message.
   */
  override fun loadModule(modelSource: String, promise: Promise) {
    try {
      // Publish the model only after it loaded, so a failed load does not replace a
      // previously working instance with a half-initialised one.
      val model = SSDLiteLargeModel(reactApplicationContext)
      model.loadModel(modelSource)
      ssdLiteLarge = model
      promise.resolve(0)
    } catch (e: Exception) {
      // Exception.message may be null; `!!` here would throw inside the catch block
      // and leave the promise unsettled.
      promise.reject(e.message ?: e.toString(), ETError.InvalidModelPath.toString())
    }
  }

  /**
   * Reads the image referenced by [input], runs the loaded model on it and resolves
   * [promise] with an array holding one map per detection (see `Detection.toWritableMap`).
   *
   * Any exception, including calling this before [loadModule] succeeded, rejects the promise.
   */
  override fun forward(input: String, promise: Promise) {
    try {
      val inputImage = ImageProcessor.readImage(input)
      val detections = ssdLiteLarge.runModel(inputImage)
      val outputWritableArray: WritableArray = Arguments.createArray()
      for (detection in detections) {
        outputWritableArray.pushMap(detection.toWritableMap())
      }
      promise.resolve(outputWritableArray)
    } catch (e: Exception) {
      // Fall back to the exception's string form when it carries no message.
      val reason = e.message ?: e.toString()
      promise.reject(reason, reason)
    }
  }

  override fun getName(): String {
    return NAME
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class RnExecutorchPackage : TurboReactPackage() {
StyleTransfer(reactContext)
} else if (name == Classification.NAME) {
Classification(reactContext)
}
} else if (name == ObjectDetection.NAME) {
ObjectDetection(reactContext)
}
else {
null
}
Expand Down Expand Up @@ -63,6 +65,15 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)

moduleInfos[ObjectDetection.NAME] = ReactModuleInfo(
ObjectDetection.NAME,
ObjectDetection.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package com.swmansion.rnexecutorch.models.objectdetection

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.Bbox
import com.swmansion.rnexecutorch.utils.CocoLabel
import com.swmansion.rnexecutorch.utils.Detection
import com.swmansion.rnexecutorch.utils.nms
import org.pytorch.executorch.EValue

// Detections whose score is below this value are dropped in SSDLiteLargeModel.postprocess.
const val detectionScoreThreshold = .7f
// IoU threshold handed to nms() in SSDLiteLargeModel.postprocess; same-label boxes
// overlapping a higher-scored box by more than this are suppressed.
const val iouThreshold = .55f

/**
 * SSDLite object-detection model wrapper.
 *
 * Takes an OpenCV [Mat], resizes it to the model's input size, runs the module and
 * converts the raw output tensors into [Detection]s whose boxes are scaled back to the
 * coordinate space of the original image.
 */
class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, Array<Detection>>(reactApplicationContext) {
  // Original-image / model-input size ratios, refreshed by every preprocess() call and
  // used by postprocess() to scale boxes back to the original image.
  private var heightRatio: Float = 1.0f
  private var widthRatio: Float = 1.0f

  /**
   * Returns the model's input size, taking the last two entries of the first input's
   * shape as (height, width).
   */
  private fun getModelImageSize(): Size {
    val inputShape = module.getInputShape(0)
    val width = inputShape[inputShape.lastIndex]
    val height = inputShape[inputShape.lastIndex - 1]

    // OpenCV's Size constructor is (width, height); passing them the other way round
    // only happens to work for square model inputs.
    return Size(width.toDouble(), height.toDouble())
  }

  /**
   * Records the scale ratios for [input], resizes it in place to the model input size
   * and converts it to an [EValue].
   *
   * Note that [input] itself is overwritten by the resize.
   */
  override fun preprocess(input: Mat): EValue {
    val modelSize = getModelImageSize()
    // Must be read before the in-place resize below changes the Mat's dimensions.
    val originalSize = input.size()
    this.widthRatio = (originalSize.width / modelSize.width).toFloat()
    this.heightRatio = (originalSize.height / modelSize.height).toFloat()
    Imgproc.resize(input, input, modelSize)
    return ImageProcessor.matToEValue(input, module.getInputShape(0))
  }

  /** Runs preprocessing, the forward pass and postprocessing for [input]. */
  override fun runModel(input: Mat): Array<Detection> {
    val modelInput = preprocess(input)
    val modelOutput = forward(modelInput)
    return postprocess(modelOutput)
  }

  /**
   * Builds detections from the model output: `output[0]` is read as boxes (four floats
   * per detection, x1, y1, x2, y2), `output[1]` as scores and `output[2]` as label ids.
   *
   * Entries scoring below [detectionScoreThreshold] are dropped, as are entries whose
   * label id has no [CocoLabel]; the rest go through [nms] with [iouThreshold].
   */
  override fun postprocess(output: Array<EValue>): Array<Detection> {
    val scoresTensor = output[1].toTensor()
    val numel = scoresTensor.numel()
    val bboxes = output[0].toTensor().dataAsFloatArray
    val scores = scoresTensor.dataAsFloatArray
    val labels = output[2].toTensor().dataAsFloatArray

    val detections: MutableList<Detection> = mutableListOf()
    for (idx in 0 until numel.toInt()) {
      val score = scores[idx]
      if (score < detectionScoreThreshold) {
        continue
      }
      // CocoLabel has gaps (e.g. no entry for id 40); skip such detections instead of
      // failing the whole inference with a NullPointerException.
      val label = CocoLabel.fromId(labels[idx].toInt()) ?: continue
      val bbox = Bbox(
        bboxes[idx * 4 + 0] * this.widthRatio,
        bboxes[idx * 4 + 1] * this.heightRatio,
        bboxes[idx * 4 + 2] * this.widthRatio,
        bboxes[idx * 4 + 3] * this.heightRatio
      )
      detections.add(Detection(bbox, score, label))
    }

    return nms(detections, iouThreshold).toTypedArray()
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package com.swmansion.rnexecutorch.utils

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.WritableMap

/**
 * Per-label greedy non-maximum suppression.
 *
 * Detections are ordered by label and then by descending score. Within each label a
 * detection is kept only if its IoU with every already-kept detection of that label
 * is not above [iouThreshold].
 *
 * @param detections candidates to filter; the list itself is not modified
 * @param iouThreshold overlap above which the lower-ranked detection is discarded
 * @return surviving detections, grouped by label, highest score first within a label
 */
fun nms(
  detections: MutableList<Detection>,
  iouThreshold: Float
): List<Detection> {
  if (detections.isEmpty()) return emptyList()

  // Label ascending, then score descending (via the negated score).
  val ranked = detections.sortedWith(
    compareBy<Detection> { it.label }.thenBy { -it.score }
  )

  val survivors = mutableListOf<Detection>()
  // groupBy keeps first-seen key order and in-group order, so the ranking carries over.
  for (sameLabel in ranked.groupBy { it.label }.values) {
    val kept = mutableListOf<Detection>()
    for (candidate in sameLabel) {
      val suppressed = kept.any { winner ->
        calculateIoU(winner.bbox, candidate.bbox) > iouThreshold
      }
      if (!suppressed) {
        kept.add(candidate)
      }
    }
    survivors.addAll(kept)
  }

  return survivors
}

/**
 * Intersection-over-union of two boxes given by their (x1, y1) and (x2, y2) corners.
 *
 * @return intersection area divided by union area, or `0f` when the union area is zero
 */
fun calculateIoU(bbox1: Bbox, bbox2: Bbox): Float {
  // Extent of the overlap along each axis, clamped to zero for disjoint boxes.
  val overlapWidth = maxOf(0f, minOf(bbox1.x2, bbox2.x2) - maxOf(bbox1.x1, bbox2.x1))
  val overlapHeight = maxOf(0f, minOf(bbox1.y2, bbox2.y2) - maxOf(bbox1.y1, bbox2.y1))
  val overlap = overlapWidth * overlapHeight

  val firstArea = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1)
  val secondArea = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1)
  val union = firstArea + secondArea - overlap

  if (union == 0f) {
    return 0f
  }
  return overlap / union
}


/** Axis-aligned box described by two corners, (x1, y1) and (x2, y2). */
data class Bbox(
  val x1: Float,
  val y1: Float,
  val x2: Float,
  val y2: Float
) {
  /** Converts this box to a React Native map with double-valued keys x1, x2, y1, y2. */
  fun toWritableMap(): WritableMap =
    Arguments.createMap().apply {
      putDouble("x1", x1.toDouble())
      putDouble("x2", x2.toDouble())
      putDouble("y1", y1.toDouble())
      putDouble("y2", y2.toDouble())
    }
}


/** One detected object: its bounding box, score and label. */
data class Detection(
  val bbox: Bbox,
  val score: Float,
  val label: CocoLabel,
) {
  /**
   * Converts this detection to a React Native map holding the box under "bbox", the
   * score as a double under "score" and the label's enum name under "label".
   */
  fun toWritableMap(): WritableMap =
    Arguments.createMap().apply {
      putMap("bbox", bbox.toWritableMap())
      putDouble("score", score.toDouble())
      putString("label", label.name)
    }
}

/**
 * Object labels keyed by numeric id.
 *
 * Ids run from 1 to 91, but id 40 has no entry, so [fromId] returns `null` for it as
 * for any other unlisted id.
 */
enum class CocoLabel(val id: Int) {
  PERSON(1),
  BICYCLE(2),
  CAR(3),
  MOTORCYCLE(4),
  AIRPLANE(5),
  BUS(6),
  TRAIN(7),
  TRUCK(8),
  BOAT(9),
  TRAFFIC_LIGHT(10),
  FIRE_HYDRANT(11),
  STREET_SIGN(12),
  STOP_SIGN(13),
  PARKING(14),
  BENCH(15),
  BIRD(16),
  CAT(17),
  DOG(18),
  HORSE(19),
  SHEEP(20),
  COW(21),
  ELEPHANT(22),
  BEAR(23),
  ZEBRA(24),
  GIRAFFE(25),
  HAT(26),
  BACKPACK(27),
  UMBRELLA(28),
  SHOE(29),
  EYE(30),
  HANDBAG(31),
  TIE(32),
  SUITCASE(33),
  FRISBEE(34),
  SKIS(35),
  SNOWBOARD(36),
  SPORTS(37),
  KITE(38),
  BASEBALL(39),
  SKATEBOARD(41),
  SURFBOARD(42),
  TENNIS_RACKET(43),
  BOTTLE(44),
  PLATE(45),
  WINE_GLASS(46),
  CUP(47),
  FORK(48),
  KNIFE(49),
  SPOON(50),
  BOWL(51),
  BANANA(52),
  APPLE(53),
  SANDWICH(54),
  ORANGE(55),
  BROCCOLI(56),
  CARROT(57),
  HOT_DOG(58),
  PIZZA(59),
  DONUT(60),
  CAKE(61),
  CHAIR(62),
  COUCH(63),
  POTTED_PLANT(64),
  BED(65),
  MIRROR(66),
  DINING_TABLE(67),
  WINDOW(68),
  DESK(69),
  TOILET(70),
  DOOR(71),
  TV(72),
  LAPTOP(73),
  MOUSE(74),
  REMOTE(75),
  KEYBOARD(76),
  CELL_PHONE(77),
  MICROWAVE(78),
  OVEN(79),
  TOASTER(80),
  SINK(81),
  REFRIGERATOR(82),
  BLENDER(83),
  BOOK(84),
  CLOCK(85),
  VASE(86),
  SCISSORS(87),
  TEDDY_BEAR(88),
  HAIR_DRIER(89),
  TOOTHBRUSH(90),
  HAIR_BRUSH(91);

  companion object {
    // Lookup table built once from the enum constants.
    private val labelsById: Map<Int, CocoLabel> = values().associateBy { it.id }

    /** Returns the label whose id equals [id], or `null` when no such label exists. */
    fun fromId(id: Int): CocoLabel? {
      return labelsById[id]
    }
  }
}
Loading
Loading