Skip to content

Commit

Permalink
feat: add object detection (android) (#52)
Browse files Browse the repository at this point in the history
## Description
<!-- Provide a concise and descriptive summary of the changes
implemented in this PR. -->

### Type of change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing
documentation)

### Tested on
- [ ] iOS
- [x] Android

### Testing instructions
<!-- Provide step-by-step instructions on how to test your changes.
Include setup details if necessary. -->

### Screenshots
<!-- Add screenshots here, if applicable -->

### Related issues
<!-- Link related issues here using #issue-number -->

### Checklist
- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [ ] My changes generate no new warnings

### Additional notes
<!-- Include any additional information, assumptions, or context that
reviewers might need to understand this PR. -->
  • Loading branch information
chmjkb authored Dec 17, 2024
1 parent 962b3df commit 1499364
Show file tree
Hide file tree
Showing 13 changed files with 610 additions and 13 deletions.
9 changes: 7 additions & 2 deletions android/build.gradle
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
buildscript {
ext {
agp_version = '8.4.2'
}

// Buildscript is evaluated before everything else so we can't use getExtOrDefault
def kotlin_version = rootProject.ext.has("kotlinVersion") ? rootProject.ext.get("kotlinVersion") : project.properties["RnExecutorch_kotlinVersion"]

Expand All @@ -9,7 +13,7 @@ buildscript {
}

dependencies {
classpath "com.android.tools.build:gradle:7.2.1"
classpath "com.android.tools.build:gradle:$agp_version"
// noinspection DifferentKotlinGradleVersion
classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version"
}
Expand Down Expand Up @@ -95,7 +99,8 @@ dependencies {
// For < 0.71, this will be from the local maven repo
// For > 0.71, this will be replaced by `com.facebook.react:react-android:$version` by react gradle plugin
//noinspection GradleDynamicVersion
implementation "com.facebook.react:react-native:+"
implementation "com.facebook.react:react-android:+"
implementation 'org.opencv:opencv:4.10.0'
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
implementation 'com.github.software-mansion:react-native-executorch:main-SNAPSHOT'
implementation 'org.opencv:opencv:4.10.0'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
import com.swmansion.rnexecutorch.models.objectdetection.SSDLiteLargeModel
import org.opencv.core.Mat

/**
 * Native module that exposes object detection to JavaScript under the name [NAME].
 *
 * The module owns a single [SSDLiteLargeModel], created by [loadModule] and used by [forward].
 * OpenCV is initialised once when the module is constructed; the outcome is only logged.
 */
class ObjectDetection(reactContext: ReactApplicationContext) :
  NativeObjectDetectionSpec(reactContext) {

  // Assigned in loadModule(); forward() fails (and rejects its promise) if called before that.
  private lateinit var ssdLiteLarge: SSDLiteLargeModel

  companion object {
    const val NAME = "ObjectDetection"
  }

  init {
    val status = if (OpenCVLoader.initLocal()) "OpenCV loaded" else "OpenCV not loaded"
    Log.d("rn_executorch", status)
  }

  /**
   * Creates the model and loads its weights from [modelSource].
   *
   * Resolves [promise] with `0` on success. On failure the promise is rejected with the
   * exception text as the first argument and [ETError.InvalidModelPath] as the second.
   */
  override fun loadModule(modelSource: String, promise: Promise) {
    try {
      ssdLiteLarge = SSDLiteLargeModel(reactApplicationContext)
      ssdLiteLarge.loadModel(modelSource)
      promise.resolve(0)
    } catch (e: Exception) {
      // Exceptions may carry no message (e.g. a NullPointerException); fall back to the
      // exception's string form instead of throwing a second time from inside the handler.
      promise.reject(e.message ?: e.toString(), ETError.InvalidModelPath.toString())
    }
  }

  /**
   * Reads the image referenced by [input], runs detection on it and resolves [promise] with
   * an array holding one map per detection (see `Detection.toWritableMap`).
   *
   * Any exception raised along the way rejects the promise with the exception text.
   */
  override fun forward(input: String, promise: Promise) {
    try {
      val inputImage = ImageProcessor.readImage(input)
      val detections = ssdLiteLarge.runModel(inputImage)
      val outputWritableArray: WritableArray = Arguments.createArray()
      detections.forEach { detection ->
        outputWritableArray.pushMap(detection.toWritableMap())
      }
      promise.resolve(outputWritableArray)
    } catch (e: Exception) {
      // Same reasoning as in loadModule: never dereference a possibly-null message with `!!`.
      val reason = e.message ?: e.toString()
      promise.reject(reason, reason)
    }
  }

  override fun getName(): String {
    return NAME
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ class RnExecutorchPackage : TurboReactPackage() {
StyleTransfer(reactContext)
} else if (name == Classification.NAME) {
Classification(reactContext)
}
} else if (name == ObjectDetection.NAME) {
ObjectDetection(reactContext)
}
else {
null
}
Expand Down Expand Up @@ -63,6 +65,15 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)

moduleInfos[ObjectDetection.NAME] = ReactModuleInfo(
ObjectDetection.NAME,
ObjectDetection.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package com.swmansion.rnexecutorch.models.objectdetection

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.Bbox
import com.swmansion.rnexecutorch.utils.CocoLabel
import com.swmansion.rnexecutorch.utils.Detection
import com.swmansion.rnexecutorch.utils.nms
import org.pytorch.executorch.EValue

// Minimum score a raw detection needs in order to be kept by SSDLiteLargeModel.postprocess;
// anything scoring below this is skipped before NMS.
const val detectionScoreThreshold = .7f
// IoU threshold passed to nms(): a same-label box overlapping an already kept box by more
// than this value is discarded.
const val iouThreshold = .55f

/**
 * Object-detection model wrapper: resizes an input image to the model's input size, runs the
 * module and converts the raw output tensors into a filtered list of [Detection]s.
 *
 * [preprocess] records the scale between the source image and the model input so that
 * [postprocess] can map bounding boxes back to source-image coordinates.
 */
class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Array<Detection>>(reactApplicationContext) {
  // Source-image size divided by model-input size, per axis; refreshed on every preprocess().
  private var heightRatio: Float = 1.0f
  private var widthRatio: Float = 1.0f

  /**
   * Returns the spatial size expected by the model, read from the shape of input 0.
   * The last dimension is treated as the width and the one before it as the height.
   */
  private fun getModelImageSize(): Size {
    val inputShape = module.getInputShape(0)
    val width = inputShape[inputShape.lastIndex]
    val height = inputShape[inputShape.lastIndex - 1]

    // OpenCV's Size constructor takes (width, height), in that order.
    return Size(width.toDouble(), height.toDouble())
  }

  /**
   * Stores the source-to-model scale factors, resizes [input] in place to the model's input
   * size and converts it to an [EValue].
   */
  override fun preprocess(input: Mat): EValue {
    val modelSize = getModelImageSize()
    val sourceSize = input.size()
    this.widthRatio = (sourceSize.width / modelSize.width).toFloat()
    this.heightRatio = (sourceSize.height / modelSize.height).toFloat()
    Imgproc.resize(input, input, modelSize)
    return ImageProcessor.matToEValue(input, module.getInputShape(0))
  }

  /** Runs the full pipeline: [preprocess], forward pass, [postprocess]. */
  override fun runModel(input: Mat): Array<Detection> {
    val modelInput = preprocess(input)
    val modelOutput = forward(modelInput)
    return postprocess(modelOutput)
  }

  /**
   * Builds detections from the model output, read as `output[0]` = boxes (four floats per
   * detection), `output[1]` = scores and `output[2]` = label ids.
   *
   * Detections scoring below [detectionScoreThreshold] are dropped, the remaining boxes are
   * rescaled to source-image coordinates and the result is passed through [nms].
   */
  override fun postprocess(output: Array<EValue>): Array<Detection> {
    val scoresTensor = output[1].toTensor()
    val numel = scoresTensor.numel()
    val bboxes = output[0].toTensor().dataAsFloatArray
    val scores = scoresTensor.dataAsFloatArray
    val labels = output[2].toTensor().dataAsFloatArray

    val detections: MutableList<Detection> = mutableListOf()
    for (idx in 0 until numel.toInt()) {
      val score = scores[idx]
      if (score < detectionScoreThreshold) {
        continue
      }
      // An id with no CocoLabel entry cannot be reported to the caller; skip that single
      // detection rather than failing the whole inference with a NullPointerException.
      val label = CocoLabel.fromId(labels[idx].toInt()) ?: continue
      val bbox = Bbox(
        bboxes[idx * 4 + 0] * this.widthRatio,
        bboxes[idx * 4 + 1] * this.heightRatio,
        bboxes[idx * 4 + 2] * this.widthRatio,
        bboxes[idx * 4 + 3] * this.heightRatio
      )
      detections.add(Detection(bbox, score, label))
    }

    val detectionsPostNms = nms(detections, iouThreshold)
    return detectionsPostNms.toTypedArray()
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package com.swmansion.rnexecutorch.utils

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.WritableMap

/**
 * Per-label greedy non-maximum suppression.
 *
 * Detections are ordered by label and then by descending score. Within each label a detection
 * is kept only if its IoU with every already kept detection of that label does not exceed
 * [iouThreshold].
 *
 * @param detections candidate detections; the list itself is not modified
 * @param iouThreshold overlap above which a lower-scored box is suppressed
 * @return surviving detections, grouped by label and ordered by descending score within a label
 */
fun nms(
  detections: MutableList<Detection>,
  iouThreshold: Float
): List<Detection> {
  if (detections.isEmpty()) {
    return emptyList()
  }

  // Label first, then best score first.
  val ordered = detections.sortedWith(compareBy({ it.label }, { -it.score }))

  val survivors = mutableListOf<Detection>()
  // Sorting made equal labels contiguous, so grouping keeps both label and score order.
  for (sameLabel in ordered.groupBy { it.label }.values) {
    val kept = mutableListOf<Detection>()
    for (candidate in sameLabel) {
      // A candidate is suppressed as soon as one stronger kept box overlaps it too much.
      val suppressed = kept.any { winner ->
        calculateIoU(winner.bbox, candidate.bbox) > iouThreshold
      }
      if (!suppressed) {
        kept.add(candidate)
      }
    }
    survivors.addAll(kept)
  }

  return survivors
}

/**
 * Intersection-over-union of two boxes given by their (x1, y1) and (x2, y2) corners.
 *
 * @return intersection area divided by union area, or `0f` when the union area is exactly zero
 */
fun calculateIoU(bbox1: Bbox, bbox2: Bbox): Float {
  // Corners of the overlapping rectangle.
  val left = maxOf(bbox1.x1, bbox2.x1)
  val top = maxOf(bbox1.y1, bbox2.y1)
  val right = minOf(bbox1.x2, bbox2.x2)
  val bottom = minOf(bbox1.y2, bbox2.y2)

  // Clamp each side at zero so disjoint boxes yield an empty intersection.
  val overlapWidth = maxOf(0f, right - left)
  val overlapHeight = maxOf(0f, bottom - top)
  val intersectionArea = overlapWidth * overlapHeight

  val firstArea = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1)
  val secondArea = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1)
  val unionArea = firstArea + secondArea - intersectionArea

  if (unionArea == 0f) {
    return 0f
  }
  return intersectionArea / unionArea
}


/**
 * Axis-aligned bounding box described by two corners, (x1, y1) and (x2, y2).
 */
data class Bbox(
  val x1: Float,
  val y1: Float,
  val x2: Float,
  val y2: Float
) {
  /** Converts the box to a map with the double-valued keys `x1`, `x2`, `y1` and `y2`. */
  fun toWritableMap(): WritableMap =
    Arguments.createMap().also { box ->
      box.putDouble("x1", x1.toDouble())
      box.putDouble("x2", x2.toDouble())
      box.putDouble("y1", y1.toDouble())
      box.putDouble("y2", y2.toDouble())
    }
}


/**
 * One detected object: its bounding box, confidence score and class label.
 */
data class Detection(
  val bbox: Bbox,
  val score: Float,
  val label: CocoLabel,
) {
  /**
   * Converts the detection to a map with a nested `bbox` map, a double `score` and the
   * enum constant name of the label under `label`.
   */
  fun toWritableMap(): WritableMap =
    Arguments.createMap().also { entry ->
      entry.putMap("bbox", bbox.toWritableMap())
      entry.putDouble("score", score.toDouble())
      entry.putString("label", label.name)
    }
}

/**
 * Object categories keyed by the numeric class id produced by the detection model.
 *
 * The ids run contiguously from 1 to 91. [fromId] returns `null` for any id outside this set.
 */
enum class CocoLabel(val id: Int) {
  PERSON(1),
  BICYCLE(2),
  CAR(3),
  MOTORCYCLE(4),
  AIRPLANE(5),
  BUS(6),
  TRAIN(7),
  TRUCK(8),
  BOAT(9),
  TRAFFIC_LIGHT(10),
  FIRE_HYDRANT(11),
  STREET_SIGN(12),
  STOP_SIGN(13),
  PARKING(14),
  BENCH(15),
  BIRD(16),
  CAT(17),
  DOG(18),
  HORSE(19),
  SHEEP(20),
  COW(21),
  ELEPHANT(22),
  BEAR(23),
  ZEBRA(24),
  GIRAFFE(25),
  HAT(26),
  BACKPACK(27),
  UMBRELLA(28),
  SHOE(29),
  EYE(30),
  HANDBAG(31),
  TIE(32),
  SUITCASE(33),
  FRISBEE(34),
  SKIS(35),
  SNOWBOARD(36),
  SPORTS(37),
  KITE(38),
  BASEBALL(39),
  // Id 40 was previously missing, so a detection of this category had no label to map to.
  // NOTE(review): keep the label lists on the JS and iOS sides in sync with this entry.
  BASEBALL_GLOVE(40),
  SKATEBOARD(41),
  SURFBOARD(42),
  TENNIS_RACKET(43),
  BOTTLE(44),
  PLATE(45),
  WINE_GLASS(46),
  CUP(47),
  FORK(48),
  KNIFE(49),
  SPOON(50),
  BOWL(51),
  BANANA(52),
  APPLE(53),
  SANDWICH(54),
  ORANGE(55),
  BROCCOLI(56),
  CARROT(57),
  HOT_DOG(58),
  PIZZA(59),
  DONUT(60),
  CAKE(61),
  CHAIR(62),
  COUCH(63),
  POTTED_PLANT(64),
  BED(65),
  MIRROR(66),
  DINING_TABLE(67),
  WINDOW(68),
  DESK(69),
  TOILET(70),
  DOOR(71),
  TV(72),
  LAPTOP(73),
  MOUSE(74),
  REMOTE(75),
  KEYBOARD(76),
  CELL_PHONE(77),
  MICROWAVE(78),
  OVEN(79),
  TOASTER(80),
  SINK(81),
  REFRIGERATOR(82),
  BLENDER(83),
  BOOK(84),
  CLOCK(85),
  VASE(86),
  SCISSORS(87),
  TEDDY_BEAR(88),
  HAIR_DRIER(89),
  TOOTHBRUSH(90),
  HAIR_BRUSH(91);

  companion object {
    // Built once so that lookups by id do not scan the enum constants.
    private val idToLabelMap = values().associateBy(CocoLabel::id)

    /** Returns the label registered for [id], or `null` when no constant uses that id. */
    fun fromId(id: Int): CocoLabel? = idToLabelMap[id]
  }
}
Loading

0 comments on commit 1499364

Please sign in to comment.