diff --git a/docs/source/creators/creators_add_metadata_to_model.rst b/docs/source/creators/creators_add_metadata_to_model.rst index 2b8fbc0..e6d9740 100644 --- a/docs/source/creators/creators_add_metadata_to_model.rst +++ b/docs/source/creators/creators_add_metadata_to_model.rst @@ -53,6 +53,7 @@ Availeble detector types: - :code:`YOLO_v9` - :code:`YOLO_Ultralytics` - :code:`YOLO_Ultralytics_segmentation` +- :code:`YOLO_Ultralytics_obb` ======= Example diff --git a/src/deepness/common/processing_parameters/detection_parameters.py b/src/deepness/common/processing_parameters/detection_parameters.py index 1f48830..c96232c 100644 --- a/src/deepness/common/processing_parameters/detection_parameters.py +++ b/src/deepness/common/processing_parameters/detection_parameters.py @@ -23,6 +23,7 @@ class DetectorType(enum.Enum): YOLO_v9 = 'YOLO_v9' YOLO_ULTRALYTICS = 'YOLO_Ultralytics' YOLO_ULTRALYTICS_SEGMENTATION = 'YOLO_Ultralytics_segmentation' + YOLO_ULTRALYTICS_OBB = 'YOLO_Ultralytics_obb' def get_parameters(self): if self == DetectorType.YOLO_v5_v7_DEFAULT: @@ -36,7 +37,7 @@ def get_parameters(self): has_inverted_output_shape=True, skipped_objectness_probability=True, ) - elif self == DetectorType.YOLO_ULTRALYTICS or self == DetectorType.YOLO_ULTRALYTICS_SEGMENTATION: + elif self == DetectorType.YOLO_ULTRALYTICS or self == DetectorType.YOLO_ULTRALYTICS_SEGMENTATION or self == DetectorType.YOLO_ULTRALYTICS_OBB: return DetectorTypeParameters( has_inverted_output_shape=True, skipped_objectness_probability=True, diff --git a/src/deepness/metadata.txt b/src/deepness/metadata.txt index dcaa04a..f675b3f 100644 --- a/src/deepness/metadata.txt +++ b/src/deepness/metadata.txt @@ -6,7 +6,7 @@ name=Deepness: Deep Neural Remote Sensing qgisMinimumVersion=3.22 description=Inference of deep neural network models (ONNX) for segmentation, detection and regression -version=0.6.3 +version=0.6.4 author=PUT Vision email=przemyslaw.aszkowski@gmail.com diff --git a/src/deepness/processing/map_processor/map_processor_detection.py b/src/deepness/processing/map_processor/map_processor_detection.py index 967a334..6b9101b 100644 --- a/src/deepness/processing/map_processor/map_processor_detection.py +++ b/src/deepness/processing/map_processor/map_processor_detection.py @@ -13,6 +13,7 @@ from deepness.processing.map_processor.utils.ckdtree import cKDTree from deepness.processing.models.detector import Detection, Detector from deepness.processing.tile_params import TileParams +from deepness.processing.models.detector import DetectorType class MapProcessorDetection(MapProcessorWithModel): @@ -49,8 +50,10 @@ def _run(self) -> MapProcessingResult: bounding_boxes_in_tile_batched = self._process_tile(tile_img_batched, tile_params_batched) all_bounding_boxes += [d for det in bounding_boxes_in_tile_batched for d in det] + with_rot = self.detection_parameters.detector_type == DetectorType.YOLO_ULTRALYTICS_OBB + if len(all_bounding_boxes) > 0: - all_bounding_boxes_nms = self.remove_overlaping_detections(all_bounding_boxes, iou_threshold=self.detection_parameters.iou_threshold) + all_bounding_boxes_nms = self.remove_overlaping_detections(all_bounding_boxes, iou_threshold=self.detection_parameters.iou_threshold, with_rot=with_rot) all_bounding_boxes_restricted = self.limit_bounding_boxes_to_processed_area(all_bounding_boxes_nms) else: all_bounding_boxes_restricted = [] @@ -197,17 +200,20 @@ def add_to_gui(): return add_to_gui @staticmethod - def remove_overlaping_detections(bounding_boxes: List[Detection], iou_threshold: float) -> List[Detection]: + def remove_overlaping_detections(bounding_boxes: List[Detection], iou_threshold: float, with_rot: bool = False) -> List[Detection]: bboxes = [] probs = [] for det in bounding_boxes: - bboxes.append(det.get_bbox_xyxy()) + if with_rot: + bboxes.append(det.get_bbox_xyxy_rot()) + else: + bboxes.append(det.get_bbox_xyxy()) probs.append(det.conf) bboxes = np.array(bboxes) probs = np.array(probs) - pick_ids = Detector.non_max_suppression_fast(bboxes, probs, iou_threshold) + pick_ids = Detector.non_max_suppression_fast(boxes=bboxes, probs=probs, iou_threshold=iou_threshold, with_rot=with_rot) filtered_bounding_boxes = [x for i, x in enumerate(bounding_boxes) if i in pick_ids] filtered_bounding_boxes = sorted(filtered_bounding_boxes, reverse=True) diff --git a/src/deepness/processing/models/detector.py b/src/deepness/processing/models/detector.py index 84030d0..474802d 100644 --- a/src/deepness/processing/models/detector.py +++ b/src/deepness/processing/models/detector.py @@ -2,6 +2,7 @@ """ from dataclasses import dataclass from typing import List, Optional, Tuple +from qgis.core import Qgis, QgsGeometry, QgsRectangle, QgsPointXY import cv2 import numpy as np @@ -34,6 +35,7 @@ class of the detected object mask: Optional[np.ndarray] = None """np.ndarray: mask of the detected object""" mask_offsets: Optional[Tuple[int, int]] = None + """Tuple[int, int]: offsets of the mask""" def convert_to_global(self, offset_x: int, offset_y: int): """Apply (x,y) offset to bounding box coordinates @@ -59,6 +61,16 @@ def get_bbox_xyxy(self) -> np.ndarray: Array in (x1, y1, x2, y2) format """ return self.bbox.get_xyxy() + + def get_bbox_xyxy_rot(self) -> np.ndarray: + """Convert stored bounding box into x1y1x2y2r format + + Returns + ------- + np.ndarray + Array in (x1, y1, x2, y2, r) format + """ + return self.bbox.get_xyxy_rot() def get_bbox_center(self) -> Tuple[int, int]: """Get center of the bounding box @@ -146,11 +158,19 @@ def get_number_of_output_channels(self): shape_index = -2 if model_type_params.has_inverted_output_shape else -1 if len(self.outputs_layers) == 1: - if model_type_params.skipped_objectness_probability: + # YOLO_ULTRALYTICS_OBB + if self.model_type == DetectorType.YOLO_ULTRALYTICS_OBB: + return [self.outputs_layers[0].shape[shape_index] - 4 - 1] + + elif model_type_params.skipped_objectness_probability: return [self.outputs_layers[0].shape[shape_index] - 4] + return [self.outputs_layers[0].shape[shape_index] - 4 - 1] # shape - 4 bboxes - 1 conf + + # YOLO_ULTRALYTICS_SEGMENTATION elif len(self.outputs_layers) == 2 and self.model_type == DetectorType.YOLO_ULTRALYTICS_SEGMENTATION: return [self.outputs_layers[0].shape[shape_index] - 4 - self.outputs_layers[1].shape[1]] + else: raise NotImplementedError("Model with multiple output layer is not supported! Use only one output layer.") @@ -182,13 +202,12 @@ def postprocessing(self, model_output): batch_detection = [] outputs_range = len(model_output) - if self.model_type == DetectorType.YOLO_ULTRALYTICS_SEGMENTATION: - outputs_range = len(model_output[0]) - elif self.model_type == DetectorType.YOLO_v9: + if self.model_type == DetectorType.YOLO_ULTRALYTICS_SEGMENTATION or self.model_type == DetectorType.YOLO_v9: outputs_range = len(model_output[0]) for i in range(outputs_range): masks = None + rots = None detections = [] if self.model_type == DetectorType.YOLO_v5_v7_DEFAULT: @@ -201,18 +220,22 @@ def postprocessing(self, model_output): boxes, conf, classes = self._postprocessing_YOLO_ULTRALYTICS(model_output[0][i]) elif self.model_type == DetectorType.YOLO_ULTRALYTICS_SEGMENTATION: boxes, conf, classes, masks = self._postprocessing_YOLO_ULTRALYTICS_SEGMENTATION(model_output[0][i], model_output[1][i]) + elif self.model_type == DetectorType.YOLO_ULTRALYTICS_OBB: + boxes, conf, classes, rots = self._postprocessing_YOLO_ULTRALYTICS_OBB(model_output[0][i]) else: raise NotImplementedError(f"Model type not implemented! ('{self.model_type}')") masks = masks if masks is not None else [None] * len(boxes) + rots = rots if rots is not None else [0.0] * len(boxes) - for b, c, cl, m in zip(boxes, conf, classes, masks): + for b, c, cl, m, r in zip(boxes, conf, classes, masks, rots): det = Detection( bbox=BoundingBox( x_min=b[0], x_max=b[2], y_min=b[1], - y_max=b[3]), + y_max=b[3], + rot=r), conf=c, clss=cl, mask=m, @@ -360,6 +383,36 @@ def _postprocessing_YOLO_ULTRALYTICS_SEGMENTATION(self, detections, protos): return boxes, conf, classes, masks + def _postprocessing_YOLO_ULTRALYTICS_OBB(self, model_output): + model_output = np.transpose(model_output, (1, 0)) + + outputs_filtered = np.array( + list(filter(lambda x: np.max(x[4:-1]) >= self.confidence, model_output)) + ) + + if len(outputs_filtered.shape) < 2: + return [], [], [], [] + + probabilities = np.max(outputs_filtered[:, 4:-1], axis=1) + rotations = outputs_filtered[:, -1] + + outputs_x1y1x2y2_rot = self.xywhr2xyxyr(outputs_filtered, rotations) + + pick_indxs = self.non_max_suppression_fast( + outputs_x1y1x2y2_rot, + probs=probabilities, + iou_threshold=self.iou_threshold, + with_rot=True) + + outputs_nms = outputs_x1y1x2y2_rot[pick_indxs] + + boxes = np.array(outputs_nms[:, :4], dtype=int) + conf = np.max(outputs_nms[:, 4:-1], axis=1) + classes = np.argmax(outputs_nms[:, 4:-1], axis=1) + rots = outputs_nms[:, -1] + + return boxes, conf, classes, rots + # based on https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/ops.py#L638C1-L638C67 def process_mask(self, protos, masks_in, bboxes): c, mh, mw = protos.shape # CHW @@ -404,7 +457,7 @@ def xywh2xyxy(x: np.ndarray) -> np.ndarray: Parameters ---------- x : np.ndarray - Bounding box in (x,y,w,h) format + Bounding box in (x,y,w,h) format with classes' probabilities Returns ------- @@ -419,7 +472,33 @@ def xywh2xyxy(x: np.ndarray) -> np.ndarray: return y @staticmethod - def non_max_suppression_fast(boxes: np.ndarray, probs: np.ndarray, iou_threshold: float) -> List: + def xywhr2xyxyr(bbox: np.ndarray, rot: np.ndarray) -> np.ndarray: + """Convert bounding box from (x,y,w,h,r) to (x1,y1,x2,y2,r) format, keeping rotated boxes in range [0, pi/2] + + Parameters + ---------- + bbox : np.ndarray + Bounding box in (x,y,w,h) format with classes' probabilities and rotations + + Returns + ------- + np.ndarray + Bounding box in (x1,y1,x2,y2,r) format, keep classes' probabilities + """ + x, y, w, h = bbox[:, 0], bbox[:, 1], bbox[:, 2], bbox[:, 3] + + w_ = np.where(w > h, w, h) + h_ = np.where(w > h, h, w) + r_ = np.where(w > h, rot, rot + np.pi / 2) % np.pi + + new_bbox_xywh = np.stack([x, y, w_, h_], axis=1) + new_bbox_xyxy = Detector.xywh2xyxy(new_bbox_xywh) + + return np.concatenate([new_bbox_xyxy, bbox[:, 4:-1], r_[:, None]], axis=1) + + + @staticmethod + def non_max_suppression_fast(boxes: np.ndarray, probs: np.ndarray, iou_threshold: float, with_rot: bool = False) -> List: """Apply non-maximum suppression to bounding boxes Based on: @@ -428,18 +507,19 @@ def non_max_suppression_fast(boxes: np.ndarray, probs: np.ndarray, iou_threshold Parameters ---------- boxes : np.ndarray - Bounding boxes in (x1,y1,x2,y2) format + Bounding boxes in (x1,y1,x2,y2) format or (x1,y1,x2,y2,r) format if with_rot is True probs : np.ndarray Confidence scores iou_threshold : float IoU threshold + with_rot: bool + If True, use rotated IoU Returns ------- List List of indexes of bounding boxes to keep """ - # If no bounding boxes, return empty list if len(boxes) == 0: return [] @@ -453,6 +533,10 @@ def non_max_suppression_fast(boxes: np.ndarray, probs: np.ndarray, iou_threshold end_x = boxes[:, 2] end_y = boxes[:, 3] + if with_rot: + # Rotations of bounding boxes + rotations = boxes[:, 4] + # Confidence scores of bounding boxes score = np.array(probs) @@ -473,25 +557,137 @@ def non_max_suppression_fast(boxes: np.ndarray, probs: np.ndarray, iou_threshold # Pick the bounding box with largest confidence score picked_boxes.append(index) - # Compute ordinates of intersection-over-union(IOU) - x1 = np.maximum(start_x[index], start_x[order[:-1]]) - x2 = np.minimum(end_x[index], end_x[order[:-1]]) - y1 = np.maximum(start_y[index], start_y[order[:-1]]) - y2 = np.minimum(end_y[index], end_y[order[:-1]]) - - # Compute areas of intersection-over-union - w = np.maximum(0.0, x2 - x1 + 1) - h = np.maximum(0.0, y2 - y1 + 1) - intersection = w * h - - # Compute the ratio between intersection and union - ratio = intersection / (areas[index] + areas[order[:-1]] - intersection) + if not with_rot: + ratio = Detector.compute_iou(index, order, start_x, start_y, end_x, end_y, areas) + else: + ratio = Detector.compute_rotated_iou(index, order, start_x, start_y, end_x, end_y, rotations, areas) left = np.where(ratio < iou_threshold) order = order[left] return picked_boxes + @staticmethod + def compute_iou(index: int, order: np.ndarray, start_x: np.ndarray, start_y: np.ndarray, end_x: np.ndarray, end_y: np.ndarray, areas: np.ndarray) -> np.ndarray: + """Compute IoU for bounding boxes + + Parameters + ---------- + index : int + Index of the bounding box + order : np.ndarray + Order of bounding boxes + start_x : np.ndarray + Start x coordinate of bounding boxes + start_y : np.ndarray + Start y coordinate of bounding boxes + end_x : np.ndarray + End x coordinate of bounding boxes + end_y : np.ndarray + End y coordinate of bounding boxes + areas : np.ndarray + Areas of bounding boxes + + Returns + ------- + np.ndarray + IoU values + """ + + # Compute ordinates of intersection-over-union(IOU) + x1 = np.maximum(start_x[index], start_x[order[:-1]]) + x2 = np.minimum(end_x[index], end_x[order[:-1]]) + y1 = np.maximum(start_y[index], start_y[order[:-1]]) + y2 = np.minimum(end_y[index], end_y[order[:-1]]) + + # Compute areas of intersection-over-union + w = np.maximum(0.0, x2 - x1 + 1) + h = np.maximum(0.0, y2 - y1 + 1) + intersection = w * h + + # Compute the ratio between intersection and union + return intersection / (areas[index] + areas[order[:-1]] - intersection) + + + @staticmethod + def compute_rotated_iou(index: int, order: np.ndarray, start_x: np.ndarray, start_y: np.ndarray, end_x: np.ndarray, end_y: np.ndarray, rotations: np.ndarray, areas: np.ndarray) -> np.ndarray: + """Compute IoU for rotated bounding boxes + + Parameters + ---------- + index : int + Index of the bounding box + order : np.ndarray + Order of bounding boxes + start_x : np.ndarray + Start x coordinate of bounding boxes + start_y : np.ndarray + Start y coordinate of bounding boxes + end_x : np.ndarray + End x coordinate of bounding boxes + end_y : np.ndarray + End y coordinate of bounding boxes + rotations : np.ndarray + Rotations of bounding boxes (in radians, around the center) + areas : np.ndarray + Areas of bounding boxes + + Returns + ------- + np.ndarray + IoU values + """ + + def create_rotated_geom(x1, y1, x2, y2, rotation): + """Helper function to create a rotated QgsGeometry rectangle""" + # Define the corners of the box before rotation + center_x = (x1 + x2) / 2 + center_y = (y1 + y2) / 2 + + # Create a rectangle using QgsRectangle + rect = QgsRectangle(QgsPointXY(x1, y1), QgsPointXY(x2, y2)) + + # Convert to QgsGeometry + geom = QgsGeometry.fromRect(rect) + + # Rotate the geometry around its center + result = geom.rotate(np.degrees(rotation), QgsPointXY(center_x, center_y)) + + if result == Qgis.GeometryOperationResult.Success: + return geom + else: + return QgsGeometry() + + # Create the rotated geometry for the current bounding box + geom1 = create_rotated_geom(start_x[index], start_y[index], end_x[index], end_y[index], rotations[index]) + + iou_values = [] + + # Iterate over the rest of the boxes in order and calculate IoU + for i in range(len(order) - 1): + # Create the rotated geometry for the other boxes in the order + geom2 = create_rotated_geom(start_x[order[i]], start_y[order[i]], end_x[order[i]], end_y[order[i]], rotations[order[i]]) + + # Compute the intersection geometry + intersection_geom = geom1.intersection(geom2) + + # Check if intersection is empty + if intersection_geom.isEmpty(): + intersection_area = 0.0 + else: + # Compute the intersection area + intersection_area = intersection_geom.area() + + # Compute the union area + union_area = areas[index] + areas[order[i]] - intersection_area + + # Compute IoU + iou = intersection_area / union_area if union_area > 0 else 0.0 + iou_values.append(iou) + + return np.array(iou_values) + + def check_loaded_model_outputs(self): """Check if model outputs are valid. Valid model are: diff --git a/src/deepness/processing/processing_utils.py b/src/deepness/processing/processing_utils.py index 5c0fcb9..f68dc6b 100644 --- a/src/deepness/processing/processing_utils.py +++ b/src/deepness/processing/processing_utils.py @@ -261,6 +261,7 @@ class BoundingBox: x_max: int y_min: int y_max: int + rot: float = 0.0 def get_shape(self) -> Tuple[int, int]: """ Returns the shape of the bounding box as a tuple (height, width) @@ -289,6 +290,22 @@ def get_xyxy(self) -> Tuple[int, int, int, int]: self.x_max, self.y_max ] + + def get_xyxy_rot(self) -> Tuple[int, int, int, int, float]: + """ Returns the bounding box as a tuple (x_min, y_min, x_max, y_max, rotation) + + Returns + ------- + Tuple[int, int, int, int, float] + (x_min, y_min, x_max, y_max, rotation) + """ + return [ + self.x_min, + self.y_min, + self.x_max, + self.y_max, + self.rot + ] def get_xywh(self) -> Tuple[int, int, int, int]: """ Returns the bounding box as a tuple (x_min, y_min, width, height) @@ -407,12 +424,28 @@ def get_4_corners(self) -> List[Tuple]: List[Tuple] List of 4 rectangle corners in (x, y) format """ - return [ - (self.x_min, self.y_min), - (self.x_min, self.y_max), - (self.x_max, self.y_max), - (self.x_max, self.y_min), - ] + if np.isclose(self.rot, 0.0): + return [ + (self.x_min, self.y_min), + (self.x_min, self.y_max), + (self.x_max, self.y_max), + (self.x_max, self.y_min), + ] + else: + x_center = (self.x_min + self.x_max) / 2 + y_center = (self.y_min + self.y_max) / 2 + + corners = np.array([ + [self.x_min, self.y_min], + [self.x_min, self.y_max], + [self.x_max, self.y_max], + [self.x_max, self.y_min], + ]) + + xys = x_center + np.cos(self.rot) * (corners[:, 0] - x_center) - np.sin(self.rot) * (corners[:, 1] - y_center) + yys = y_center + np.sin(self.rot) * (corners[:, 0] - x_center) + np.cos(self.rot) * (corners[:, 1] - y_center) + + return [(int(x), int(y)) for x, y in zip(xys, yys)] def transform_polygon_with_rings_epsg_to_extended_xy_pixels( diff --git a/src/deepness/python_requirements/requirements.txt b/src/deepness/python_requirements/requirements.txt index bc56f03..7227014 100644 --- a/src/deepness/python_requirements/requirements.txt +++ b/src/deepness/python_requirements/requirements.txt @@ -4,5 +4,6 @@ # NOTE - for the time being - keep the same packages and versions in the `packages_installer_dialog.py` +numpy<2.0.0 onnxruntime-gpu>=1.12.1,<=1.17.0 opencv-python-headless>=4.5.5.64,<=4.9.0.80 diff --git a/test/manual_test_map_processor_obb_yolo_ultralytics.py b/test/manual_test_map_processor_obb_yolo_ultralytics.py new file mode 100644 index 0000000..2734a30 --- /dev/null +++ b/test/manual_test_map_processor_obb_yolo_ultralytics.py @@ -0,0 +1,60 @@ +import os +from pathlib import Path +from test.test_utils import create_default_input_channels_mapping_for_rgb_bands, create_rlayer_from_file, init_qgis +from unittest.mock import MagicMock + +from deepness.common.processing_overlap import ProcessingOverlap, ProcessingOverlapOptions +from deepness.common.processing_parameters.detection_parameters import DetectionParameters, DetectorType +from deepness.common.processing_parameters.map_processing_parameters import ProcessedAreaType +from deepness.processing.map_processor.map_processor_detection import MapProcessorDetection +from deepness.processing.models.detector import Detector + +# Files and model from github issue: https://github.com/PUTvision/qgis-plugin-deepness/discussions/101 + +HOME_DIR = Path(__file__).resolve().parents[1] +EXAMPLE_DATA_DIR = os.path.join(HOME_DIR, 'examples', 'manually_downloaded') +MODEL_FILE_PATH = os.path.join(EXAMPLE_DATA_DIR, 'yolo11m-obb.onnx') + +RASTER_FILE_PATH = os.path.join(HOME_DIR, 'examples', 'yolov7_planes_detection_google_earth', 'google_earth_planes_lawica.png') + +INPUT_CHANNELS_MAPPING = create_default_input_channels_mapping_for_rgb_bands() + + +def test_map_processor_obb_yolo_ultralytics(): + qgs = init_qgis() + + rlayer = create_rlayer_from_file(RASTER_FILE_PATH) + model_wrapper = Detector(MODEL_FILE_PATH) + + params = DetectionParameters( + resolution_cm_per_px=50, + tile_size_px=model_wrapper.get_input_size_in_pixels()[0], # same x and y dimensions, so take x + batch_size=1, + local_cache=False, + processed_area_type=ProcessedAreaType.ENTIRE_LAYER, + mask_layer_id=None, + input_layer_id=rlayer.id(), + input_channels_mapping=INPUT_CHANNELS_MAPPING, + processing_overlap=ProcessingOverlap(ProcessingOverlapOptions.OVERLAP_IN_PERCENT, percentage=15), + model=model_wrapper, + confidence=0.5, + iou_threshold=0.4, + detector_type=DetectorType.YOLO_ULTRALYTICS_OBB, + ) + + map_processor = MapProcessorDetection( + rlayer=rlayer, + vlayer_mask=None, + map_canvas=MagicMock(), + params=params, + ) + + map_processor.run() + + assert len(map_processor.get_all_detections()) == 2 + + + +if __name__ == '__main__': + test_map_processor_obb_yolo_ultralytics() + print('Done') diff --git a/test/test_nms_funtion.py b/test/test_nms_function.py similarity index 100% rename from test/test_nms_funtion.py rename to test/test_nms_function.py diff --git a/test/test_nms_with_rotation_function.py b/test/test_nms_with_rotation_function.py new file mode 100644 index 0000000..b5d855b --- /dev/null +++ b/test/test_nms_with_rotation_function.py @@ -0,0 +1,50 @@ +from unittest.mock import MagicMock +import numpy as np + +from deepness.processing.models.detector import Detection, Detector +from deepness.processing.processing_utils import BoundingBox +from test.test_utils import init_qgis + + +def test_nms_human_case_with_rotation(): + detections = [ + Detection(bbox=BoundingBox(x_min=10, x_max=80, y_min=20, y_max=150, rot=0), conf=0.99999, clss=0), + Detection(bbox=BoundingBox(x_min=0, x_max=90, y_min=100, y_max=190, rot=0), conf=0.44444, clss=0), + Detection(bbox=BoundingBox(x_min=0, x_max=100, y_min=0, y_max=200, rot=0), conf=0.88888, clss=0), + ] + + bboxes = [] + confs = [] + for detection in detections: + bboxes.append(detection.bbox.get_xyxy_rot()) + confs.append(detection.conf) + + picks = Detector.non_max_suppression_fast(np.array(bboxes), np.array(confs), 0.2, with_rot=True) + + detections = [d for i, d in enumerate(detections) if i in picks] + assert len(detections) == 1 + + +def test_nms_human_case_with_rotation_v2(): + detections = [ + Detection(bbox=BoundingBox(x_min=100, x_max=200, y_min=100, y_max=200, rot=0), conf=0.99999, clss=0), + Detection(bbox=BoundingBox(x_min=100, x_max=200, y_min=100, y_max=200, rot=np.pi/2), conf=0.99999, clss=0), + Detection(bbox=BoundingBox(x_min=100, x_max=200, y_min=100, y_max=200, rot=-np.pi/2), conf=0.99999, clss=0), + Detection(bbox=BoundingBox(x_min=100, x_max=200, y_min=100, y_max=200, rot=np.pi), conf=0.99999, clss=0), + ] + + bboxes = [] + confs = [] + for detection in detections: + bboxes.append(detection.bbox.get_xyxy_rot()) + confs.append(detection.conf) + + picks = Detector.non_max_suppression_fast(np.array(bboxes), np.array(confs), 0.2, with_rot=True) + + detections = [d for i, d in enumerate(detections) if i in picks] + assert len(detections) == 1 + +if __name__ == '__main__': + test_nms_human_case_with_rotation() + test_nms_human_case_with_rotation_v2() +