-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add object detection (android) (#52)
## Description <!-- Provide a concise and descriptive summary of the changes implemented in this PR. --> ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [ ] iOS - [x] Android ### Testing instructions <!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. --> ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
- Loading branch information
Showing
13 changed files
with
610 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
android/src/main/java/com/swmansion/rnexecutorch/ObjectDetection.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package com.swmansion.rnexecutorch | ||
|
||
import android.util.Log | ||
import com.facebook.react.bridge.Arguments | ||
import com.facebook.react.bridge.Promise | ||
import com.facebook.react.bridge.ReactApplicationContext | ||
import com.facebook.react.bridge.WritableArray | ||
import com.swmansion.rnexecutorch.models.BaseModel | ||
import com.swmansion.rnexecutorch.utils.ETError | ||
import com.swmansion.rnexecutorch.utils.ImageProcessor | ||
import org.opencv.android.OpenCVLoader | ||
import com.swmansion.rnexecutorch.models.objectdetection.SSDLiteLargeModel | ||
import org.opencv.core.Mat | ||
|
||
class ObjectDetection(reactContext: ReactApplicationContext) : | ||
NativeObjectDetectionSpec(reactContext) { | ||
|
||
private lateinit var ssdLiteLarge: SSDLiteLargeModel | ||
|
||
companion object { | ||
const val NAME = "ObjectDetection" | ||
} | ||
|
||
init { | ||
if(!OpenCVLoader.initLocal()){ | ||
Log.d("rn_executorch", "OpenCV not loaded") | ||
} else { | ||
Log.d("rn_executorch", "OpenCV loaded") | ||
} | ||
} | ||
|
||
override fun loadModule(modelSource: String, promise: Promise) { | ||
try { | ||
ssdLiteLarge = SSDLiteLargeModel(reactApplicationContext) | ||
ssdLiteLarge.loadModel(modelSource) | ||
promise.resolve(0) | ||
} catch (e: Exception) { | ||
promise.reject(e.message!!, ETError.InvalidModelPath.toString()) | ||
} | ||
} | ||
|
||
override fun forward(input: String, promise: Promise) { | ||
try { | ||
val inputImage = ImageProcessor.readImage(input) | ||
val output = ssdLiteLarge.runModel(inputImage) | ||
val outputWritableArray: WritableArray = Arguments.createArray() | ||
output.map { detection -> | ||
detection.toWritableMap() | ||
}.forEach { writableMap -> | ||
outputWritableArray.pushMap(writableMap) | ||
} | ||
promise.resolve(outputWritableArray) | ||
} catch(e: Exception){ | ||
promise.reject(e.message!!, e.message) | ||
} | ||
} | ||
|
||
override fun getName(): String { | ||
return NAME | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
...oid/src/main/java/com/swmansion/rnexecutorch/models/object_detection/SSDLiteLargeModel.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package com.swmansion.rnexecutorch.models.objectdetection | ||
|
||
import com.facebook.react.bridge.ReactApplicationContext | ||
import com.swmansion.rnexecutorch.utils.ImageProcessor | ||
import org.opencv.core.Mat | ||
import org.opencv.core.Size | ||
import org.opencv.imgproc.Imgproc | ||
import com.swmansion.rnexecutorch.models.BaseModel | ||
import com.swmansion.rnexecutorch.utils.Bbox | ||
import com.swmansion.rnexecutorch.utils.CocoLabel | ||
import com.swmansion.rnexecutorch.utils.Detection | ||
import com.swmansion.rnexecutorch.utils.nms | ||
import org.pytorch.executorch.EValue | ||
|
||
const val detectionScoreThreshold = .7f | ||
const val iouThreshold = .55f | ||
|
||
class SSDLiteLargeModel(reactApplicationContext: ReactApplicationContext) : BaseModel<Mat, Array<Detection>>(reactApplicationContext) { | ||
private var heightRatio: Float = 1.0f | ||
private var widthRatio: Float = 1.0f | ||
|
||
private fun getModelImageSize(): Size { | ||
val inputShape = module.getInputShape(0) | ||
val width = inputShape[inputShape.lastIndex] | ||
val height = inputShape[inputShape.lastIndex - 1] | ||
|
||
return Size(height.toDouble(), width.toDouble()) | ||
} | ||
|
||
override fun preprocess(input: Mat): EValue { | ||
this.widthRatio = (input.size().width / getModelImageSize().width).toFloat() | ||
this.heightRatio = (input.size().height / getModelImageSize().height).toFloat() | ||
Imgproc.resize(input, input, getModelImageSize()) | ||
return ImageProcessor.matToEValue(input, module.getInputShape(0)) | ||
} | ||
|
||
override fun runModel(input: Mat): Array<Detection> { | ||
val modelInput = preprocess(input) | ||
val modelOutput = forward(modelInput) | ||
return postprocess(modelOutput) | ||
} | ||
|
||
override fun postprocess(output: Array<EValue>): Array<Detection> { | ||
val scoresTensor = output[1].toTensor() | ||
val numel = scoresTensor.numel() | ||
val bboxes = output[0].toTensor().dataAsFloatArray | ||
val scores = scoresTensor.dataAsFloatArray | ||
val labels = output[2].toTensor().dataAsFloatArray | ||
|
||
val detections: MutableList<Detection> = mutableListOf(); | ||
for (idx in 0 until numel.toInt()) { | ||
val score = scores[idx] | ||
if (score < detectionScoreThreshold) { | ||
continue | ||
} | ||
val bbox = Bbox( | ||
bboxes[idx * 4 + 0] * this.widthRatio, | ||
bboxes[idx * 4 + 1] * this.heightRatio, | ||
bboxes[idx * 4 + 2] * this.widthRatio, | ||
bboxes[idx * 4 + 3] * this.heightRatio | ||
) | ||
val label = labels[idx] | ||
detections.add( | ||
Detection(bbox, score, CocoLabel.fromId(label.toInt())!!) | ||
) | ||
} | ||
|
||
val detectionsPostNms = nms(detections, iouThreshold); | ||
return detectionsPostNms.toTypedArray() | ||
} | ||
} |
196 changes: 196 additions & 0 deletions
196
android/src/main/java/com/swmansion/rnexecutorch/utils/ObjectDetectionUtils.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
package com.swmansion.rnexecutorch.utils | ||
|
||
import com.facebook.react.bridge.Arguments | ||
import com.facebook.react.bridge.WritableMap | ||
|
||
fun nms( | ||
detections: MutableList<Detection>, | ||
iouThreshold: Float | ||
): List<Detection> { | ||
if (detections.isEmpty()) { | ||
return emptyList() | ||
} | ||
|
||
// Sort detections first by label, then by score (descending) | ||
val sortedDetections = detections.sortedWith(compareBy({ it.label }, { -it.score })) | ||
|
||
val result = mutableListOf<Detection>() | ||
|
||
// Process NMS for each label group | ||
var i = 0 | ||
while (i < sortedDetections.size) { | ||
val currentLabel = sortedDetections[i].label | ||
|
||
// Collect detections for the current label | ||
val labelDetections = mutableListOf<Detection>() | ||
while (i < sortedDetections.size && sortedDetections[i].label == currentLabel) { | ||
labelDetections.add(sortedDetections[i]) | ||
i++ | ||
} | ||
|
||
// Filter out detections with high IoU | ||
val filteredLabelDetections = mutableListOf<Detection>() | ||
while (labelDetections.isNotEmpty()) { | ||
val current = labelDetections.removeAt(0) | ||
filteredLabelDetections.add(current) | ||
|
||
// Remove detections that overlap with the current detection above the IoU threshold | ||
val iterator = labelDetections.iterator() | ||
while (iterator.hasNext()) { | ||
val other = iterator.next() | ||
if (calculateIoU(current.bbox, other.bbox) > iouThreshold) { | ||
iterator.remove() // Remove detection if IoU is above threshold | ||
} | ||
} | ||
} | ||
|
||
// Add the filtered detections to the result | ||
result.addAll(filteredLabelDetections) | ||
} | ||
|
||
return result | ||
} | ||
|
||
fun calculateIoU(bbox1: Bbox, bbox2: Bbox): Float { | ||
val x1 = maxOf(bbox1.x1, bbox2.x1) | ||
val y1 = maxOf(bbox1.y1, bbox2.y1) | ||
val x2 = minOf(bbox1.x2, bbox2.x2) | ||
val y2 = minOf(bbox1.y2, bbox2.y2) | ||
|
||
val intersectionArea = maxOf(0f, x2 - x1) * maxOf(0f, y2 - y1) | ||
val bbox1Area = (bbox1.x2 - bbox1.x1) * (bbox1.y2 - bbox1.y1) | ||
val bbox2Area = (bbox2.x2 - bbox2.x1) * (bbox2.y2 - bbox2.y1) | ||
|
||
val unionArea = bbox1Area + bbox2Area - intersectionArea | ||
return if (unionArea == 0f) 0f else intersectionArea / unionArea | ||
} | ||
|
||
|
||
data class Bbox( | ||
val x1: Float, | ||
val y1: Float, | ||
val x2: Float, | ||
val y2: Float | ||
) { | ||
fun toWritableMap(): WritableMap { | ||
val map = Arguments.createMap() | ||
map.putDouble("x1", x1.toDouble()) | ||
map.putDouble("x2", x2.toDouble()) | ||
map.putDouble("y1", y1.toDouble()) | ||
map.putDouble("y2", y2.toDouble()) | ||
return map | ||
} | ||
} | ||
|
||
|
||
data class Detection( | ||
val bbox: Bbox, | ||
val score: Float, | ||
val label: CocoLabel, | ||
) { | ||
fun toWritableMap(): WritableMap { | ||
val map = Arguments.createMap() | ||
map.putMap("bbox", bbox.toWritableMap()) | ||
map.putDouble("score", score.toDouble()) | ||
map.putString("label", label.name) | ||
return map | ||
} | ||
} | ||
|
||
enum class CocoLabel(val id: Int) { | ||
PERSON(1), | ||
BICYCLE(2), | ||
CAR(3), | ||
MOTORCYCLE(4), | ||
AIRPLANE(5), | ||
BUS(6), | ||
TRAIN(7), | ||
TRUCK(8), | ||
BOAT(9), | ||
TRAFFIC_LIGHT(10), | ||
FIRE_HYDRANT(11), | ||
STREET_SIGN(12), | ||
STOP_SIGN(13), | ||
PARKING(14), | ||
BENCH(15), | ||
BIRD(16), | ||
CAT(17), | ||
DOG(18), | ||
HORSE(19), | ||
SHEEP(20), | ||
COW(21), | ||
ELEPHANT(22), | ||
BEAR(23), | ||
ZEBRA(24), | ||
GIRAFFE(25), | ||
HAT(26), | ||
BACKPACK(27), | ||
UMBRELLA(28), | ||
SHOE(29), | ||
EYE(30), | ||
HANDBAG(31), | ||
TIE(32), | ||
SUITCASE(33), | ||
FRISBEE(34), | ||
SKIS(35), | ||
SNOWBOARD(36), | ||
SPORTS(37), | ||
KITE(38), | ||
BASEBALL(39), | ||
SKATEBOARD(41), | ||
SURFBOARD(42), | ||
TENNIS_RACKET(43), | ||
BOTTLE(44), | ||
PLATE(45), | ||
WINE_GLASS(46), | ||
CUP(47), | ||
FORK(48), | ||
KNIFE(49), | ||
SPOON(50), | ||
BOWL(51), | ||
BANANA(52), | ||
APPLE(53), | ||
SANDWICH(54), | ||
ORANGE(55), | ||
BROCCOLI(56), | ||
CARROT(57), | ||
HOT_DOG(58), | ||
PIZZA(59), | ||
DONUT(60), | ||
CAKE(61), | ||
CHAIR(62), | ||
COUCH(63), | ||
POTTED_PLANT(64), | ||
BED(65), | ||
MIRROR(66), | ||
DINING_TABLE(67), | ||
WINDOW(68), | ||
DESK(69), | ||
TOILET(70), | ||
DOOR(71), | ||
TV(72), | ||
LAPTOP(73), | ||
MOUSE(74), | ||
REMOTE(75), | ||
KEYBOARD(76), | ||
CELL_PHONE(77), | ||
MICROWAVE(78), | ||
OVEN(79), | ||
TOASTER(80), | ||
SINK(81), | ||
REFRIGERATOR(82), | ||
BLENDER(83), | ||
BOOK(84), | ||
CLOCK(85), | ||
VASE(86), | ||
SCISSORS(87), | ||
TEDDY_BEAR(88), | ||
HAIR_DRIER(89), | ||
TOOTHBRUSH(90), | ||
HAIR_BRUSH(91); | ||
|
||
companion object { | ||
private val idToLabelMap = values().associateBy(CocoLabel::id) | ||
fun fromId(id: Int): CocoLabel? = idToLabelMap[id] | ||
} | ||
} |
Oops, something went wrong.