Skip to content

Commit

Permalink
refactor(runtime): import supervision as sv
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseppeambrosio97 committed Dec 27, 2024
1 parent aa787e2 commit 9d5ec67
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions focoos/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
preprocess_image: Preprocesses an image for model input.
postprocess_image: Postprocesses the output image from the model.
image_to_byte_array: Converts a PIL image to a byte array.
det_postprocess: Postprocesses detection model outputs into Detections.
semseg_postprocess: Postprocesses semantic segmentation model outputs into Detections.
det_postprocess: Postprocesses detection model outputs into sv.Detections.
semseg_postprocess: Postprocesses semantic segmentation model outputs into sv.Detections.
get_runtime: Returns an ONNXRuntime instance configured for the given runtime type.
Classes:
Expand All @@ -25,8 +25,8 @@

import numpy as np
import onnxruntime as ort
import supervision as sv
from PIL import Image
from supervision import Detections

from focoos.ports import (
FocoosTask,
Expand Down Expand Up @@ -95,7 +95,7 @@ def image_to_byte_array(image: Image.Image) -> bytes:

def det_postprocess(
out: np.ndarray, im0_shape: Tuple[int, int], conf_threshold: float
) -> Detections:
) -> sv.Detections:
"""
Postprocesses the output of an object detection model and filters detections
based on a confidence threshold.
Expand All @@ -106,14 +106,14 @@ def det_postprocess(
conf_threshold (float): The confidence threshold for filtering detections.
Returns:
Detections: A Detections object containing the filtered bounding boxes, class ids, and confidences.
sv.Detections: A sv.Detections object containing the filtered bounding boxes, class ids, and confidences.
"""
cls_ids, boxes, confs = out
boxes[:, 0::2] *= im0_shape[1]
boxes[:, 1::2] *= im0_shape[0]
high_conf_indices = np.where(confs > conf_threshold)

return Detections(
return sv.Detections(
xyxy=boxes[high_conf_indices].astype(int),
class_id=cls_ids[high_conf_indices].astype(int),
confidence=confs[high_conf_indices].astype(float),
Expand All @@ -122,7 +122,7 @@ def det_postprocess(

def semseg_postprocess(
out: np.ndarray, im0_shape: Tuple[int, int], conf_threshold: float
) -> Detections:
) -> sv.Detections:
"""
Postprocesses the output of a semantic segmentation model and filters based
on a confidence threshold.
Expand All @@ -133,7 +133,7 @@ def semseg_postprocess(
conf_threshold (float): The confidence threshold for filtering detections.
Returns:
Detections: A Detections object containing the masks, class ids, and confidences.
sv.Detections: A sv.Detections object containing the masks, class ids, and confidences.
"""
cls_ids, mask, confs = out[0][0], out[1][0], out[2][0]
masks = np.zeros((len(cls_ids), *mask.shape), dtype=bool)
Expand All @@ -143,7 +143,7 @@ def semseg_postprocess(
masks = masks[high_conf_indices].astype(bool)
cls_ids = cls_ids[high_conf_indices].astype(int)
confs = confs[high_conf_indices].astype(float)
return Detections(
return sv.Detections(
mask=masks,
# xyxy is required from supervisio
xyxy=np.zeros(shape=(len(high_conf_indices), 4), dtype=np.uint8),
Expand Down Expand Up @@ -303,7 +303,7 @@ def __init__(

self.logger.info(f"⏱️ [onnxruntime] {self.name} WARMUP DONE")

def __call__(self, im: np.ndarray, conf_threshold: float) -> Detections:
def __call__(self, im: np.ndarray, conf_threshold: float) -> sv.Detections:
"""
Runs inference on the provided input image and returns the model's detections.
Expand All @@ -312,7 +312,7 @@ def __call__(self, im: np.ndarray, conf_threshold: float) -> Detections:
conf_threshold (float): The confidence threshold for filtering results.
Returns:
Detections: A Detections object containing the model's output detections.
sv.Detections: A sv.Detections object containing the model's output detections.
"""
out_name = None
input_name = self.ort_sess.get_inputs()[0].name
Expand Down

0 comments on commit 9d5ec67

Please sign in to comment.