diff --git a/supervision/annotators/core.py b/supervision/annotators/core.py index 1910ac9f4..02ab47d6f 100644 --- a/supervision/annotators/core.py +++ b/supervision/annotators/core.py @@ -16,7 +16,7 @@ ) from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES from supervision.detection.core import Detections -from supervision.detection.utils import clip_boxes, mask_to_polygons +from supervision.detection.utils import clip_boxes, mask_to_polygons, spread_out_boxes from supervision.draw.color import Color, ColorPalette from supervision.draw.utils import draw_polygon from supervision.geometry.core import Position @@ -32,6 +32,8 @@ ) from supervision.utils.internal import deprecated +CV2_FONT = cv2.FONT_HERSHEY_SIMPLEX + class BoxAnnotator(BaseAnnotator): """ @@ -1054,6 +1056,7 @@ def __init__( text_position: Position = Position.TOP_LEFT, color_lookup: ColorLookup = ColorLookup.CLASS, border_radius: int = 0, + smart_position: bool = False, ): """ Args: @@ -1070,6 +1073,7 @@ def __init__( Options are `INDEX`, `CLASS`, `TRACK`. border_radius (int): The radius to apply round edges. If the selected value is higher than the lower dimension, width or height, is clipped. + smart_position (bool): Spread out the labels to avoid overlapping. """ self.border_radius: int = border_radius self.color: Union[Color, ColorPalette] = color @@ -1079,6 +1083,7 @@ def __init__( self.text_padding: int = text_padding self.text_anchor: Position = text_position self.color_lookup: ColorLookup = color_lookup + self.smart_position = smart_position @ensure_cv2_image_for_annotation def annotate( @@ -1128,11 +1133,29 @@ def annotate( ![label-annotator-example](https://media.roboflow.com/ supervision-annotator-examples/label-annotator-example-purple.png) """ + assert isinstance(scene, np.ndarray) - font = cv2.FONT_HERSHEY_SIMPLEX - anchors_coordinates = detections.get_anchors_coordinates( - anchor=self.text_anchor - ).astype(int) + self._validate_labels(labels, detections) + + labels = self._get_labels_text(detections, labels) + label_properties = self._get_label_properties(detections, labels) + + if self.smart_position: + xyxy = label_properties[:, :4] + xyxy = spread_out_boxes(xyxy) + label_properties[:, :4] = xyxy + + self._draw_labels( + scene=scene, + labels=labels, + label_properties=label_properties, + detections=detections, + custom_color_lookup=custom_color_lookup, + ) + + return scene + + def _validate_labels(self, labels: Optional[List[str]], detections: Detections): if labels is not None and len(labels) != len(detections): raise ValueError( f"The number of labels ({len(labels)}) does not match the " @@ -1140,72 +1163,121 @@ def annotate( f"should have exactly 1 label." ) - for detection_idx, center_coordinates in enumerate(anchors_coordinates): - color = resolve_color( - color=self.color, - detections=detections, - detection_idx=detection_idx, - color_lookup=( - self.color_lookup - if custom_color_lookup is None - else custom_color_lookup - ), - ) - - text_color = resolve_color( - color=self.text_color, - detections=detections, - detection_idx=detection_idx, - color_lookup=( - self.color_lookup - if custom_color_lookup is None - else custom_color_lookup - ), - ) + def _get_label_properties( + self, + detections: Detections, + labels: List[str], + ) -> np.ndarray: + """ + Calculate the numerical properties required to draw the labels on the image. - if labels is not None: - text = labels[detection_idx] - elif CLASS_NAME_DATA_FIELD in detections.data: - text = detections.data[CLASS_NAME_DATA_FIELD][detection_idx] - elif detections.class_id is not None: - text = str(detections.class_id[detection_idx]) - else: - text = str(detection_idx) + Returns: + (np.ndarray): An array of label properties, containing columns: + `min_x`, `min_y`, `max_x`, `max_y`, `padded_text_height`. + """ + label_properties = [] + anchors_coordinates = detections.get_anchors_coordinates( + anchor=self.text_anchor + ).astype(int) - text_w, text_h = cv2.getTextSize( - text=text, - fontFace=font, + for label, center_coords in zip(labels, anchors_coordinates): + (text_w, text_h) = cv2.getTextSize( + text=label, + fontFace=CV2_FONT, fontScale=self.text_scale, thickness=self.text_thickness, )[0] - text_w_padded = text_w + 2 * self.text_padding - text_h_padded = text_h + 2 * self.text_padding + + width_padded = text_w + 2 * self.text_padding + height_padded = text_h + 2 * self.text_padding + text_background_xyxy = resolve_text_background_xyxy( - center_coordinates=tuple(center_coordinates), - text_wh=(text_w_padded, text_h_padded), + center_coordinates=tuple(center_coords), + text_wh=(width_padded, height_padded), position=self.text_anchor, ) - text_x = text_background_xyxy[0] + self.text_padding - text_y = text_background_xyxy[1] + self.text_padding + text_h + label_properties.append( + [ + *text_background_xyxy, + text_h, + ] + ) + + return np.array(label_properties).reshape(-1, 5) + + @staticmethod + def _get_labels_text( + detections: Detections, custom_labels: Optional[List[str]] + ) -> List[str]: + if custom_labels is not None: + return custom_labels + + labels = [] + for idx in range(len(detections)): + if CLASS_NAME_DATA_FIELD in detections.data: + labels.append(detections.data[CLASS_NAME_DATA_FIELD][idx]) + elif detections.class_id is not None: + labels.append(str(detections.class_id[idx])) + else: + labels.append(str(idx)) + return labels + + def _draw_labels( + self, + scene: np.ndarray, + labels: List[str], + label_properties: np.ndarray, + detections: Detections, + custom_color_lookup: Optional[np.ndarray], + ) -> None: + assert len(labels) == len(label_properties) == len(detections), ( + f"Number of label properties ({len(label_properties)}), " + f"labels ({len(labels)}) and detections ({len(detections)}) " + "do not match." + ) + + color_lookup = ( + custom_color_lookup + if custom_color_lookup is not None + else self.color_lookup + ) + + for idx, label_property in enumerate(label_properties): + background_color = resolve_color( + color=self.color, + detections=detections, + detection_idx=idx, + color_lookup=color_lookup, + ) + text_color = resolve_color( + color=self.text_color, + detections=detections, + detection_idx=idx, + color_lookup=color_lookup, + ) + box_xyxy = label_property[:4] + text_height_padded = label_property[4] self.draw_rounded_rectangle( scene=scene, - xyxy=text_background_xyxy, - color=color.as_bgr(), + xyxy=box_xyxy, + color=background_color.as_bgr(), border_radius=self.border_radius, ) + + text_x = box_xyxy[0] + self.text_padding + text_y = box_xyxy[1] + self.text_padding + text_height_padded cv2.putText( img=scene, - text=text, + text=labels[idx], org=(text_x, text_y), - fontFace=font, + fontFace=CV2_FONT, fontScale=self.text_scale, color=text_color.as_bgr(), thickness=self.text_thickness, lineType=cv2.LINE_AA, ) - return scene @staticmethod def draw_rounded_rectangle( @@ -1266,6 +1338,7 @@ def __init__( text_position: Position = Position.TOP_LEFT, color_lookup: ColorLookup = ColorLookup.CLASS, border_radius: int = 0, + smart_position: bool = False, ): """ Args: @@ -1282,6 +1355,7 @@ def __init__( Options are `INDEX`, `CLASS`, `TRACK`. border_radius (int): The radius to apply round edges. If the selected value is higher than the lower dimension, width or height, is clipped. + smart_position (bool): Spread out the labels to avoid overlapping. """ self.color = color self.text_color = text_color @@ -1289,14 +1363,8 @@ def __init__( self.text_anchor = text_position self.color_lookup = color_lookup self.border_radius = border_radius - if font_path is not None: - try: - self.font = ImageFont.truetype(font_path, font_size) - except OSError: - print(f"Font path '{font_path}' not found. Using PIL's default font.") - self.font = self._load_default_font(font_size) - else: - self.font = self._load_default_font(font_size) + self.smart_position = smart_position + self.font = self._load_font(font_size, font_path) @ensure_pil_image_for_annotation def annotate( @@ -1346,88 +1414,157 @@ def annotate( """ assert isinstance(scene, Image.Image) + self._validate_labels(labels, detections) + draw = ImageDraw.Draw(scene) - anchors_coordinates = detections.get_anchors_coordinates( - anchor=self.text_anchor - ).astype(int) + labels = self._get_labels_text(detections, labels) + label_properties = self._get_label_properties(draw, detections, labels) + + if self.smart_position: + xyxy = label_properties[:, :4] + xyxy = spread_out_boxes(xyxy) + label_properties[:, :4] = xyxy + + self._draw_labels( + draw=draw, + labels=labels, + label_properties=label_properties, + detections=detections, + custom_color_lookup=custom_color_lookup, + ) + + return scene + + def _validate_labels(self, labels: Optional[List[str]], detections: Detections): if labels is not None and len(labels) != len(detections): raise ValueError( - f"The number of labels provided ({len(labels)}) does not match the " - f"number of detections ({len(detections)}). Each detection should have " - f"a corresponding label." + f"The number of labels ({len(labels)}) does not match the " + f"number of detections ({len(detections)}). Each detection " + f"should have exactly 1 label." ) - for detection_idx, center_coordinates in enumerate(anchors_coordinates): - color = resolve_color( - color=self.color, - detections=detections, - detection_idx=detection_idx, - color_lookup=( - self.color_lookup - if custom_color_lookup is None - else custom_color_lookup - ), + + def _get_label_properties( + self, draw, detections: Detections, labels: List[str] + ) -> np.ndarray: + """ + Calculate the numerical properties required to draw the labels on the image. + + Returns: + (np.ndarray): An array of label properties, containing columns: + `min_x`, `min_y`, `max_x`, `max_y`, `text_left_coordinate`, + `text_top_coordinate`. The first 4 values are already padded + with `text_padding`. + """ + label_properties = [] + + anchor_coordinates = detections.get_anchors_coordinates( + anchor=self.text_anchor + ).astype(int) + + for label, center_coords in zip(labels, anchor_coordinates): + text_left, text_top, text_right, text_bottom = draw.textbbox( + (0, 0), label, font=self.font ) + text_width = text_right - text_left + text_height = text_bottom - text_top + width_padded = text_width + 2 * self.text_padding + height_padded = text_height + 2 * self.text_padding - text_color = resolve_color( - color=self.text_color, - detections=detections, - detection_idx=detection_idx, - color_lookup=( - self.color_lookup - if custom_color_lookup is None - else custom_color_lookup - ), + text_background_xyxy = resolve_text_background_xyxy( + center_coordinates=tuple(center_coords), + text_wh=(width_padded, height_padded), + position=self.text_anchor, ) - if labels is not None: - text = labels[detection_idx] - elif CLASS_NAME_DATA_FIELD in detections.data: - text = detections.data[CLASS_NAME_DATA_FIELD][detection_idx] + label_properties.append([*text_background_xyxy, text_left, text_top]) + + return np.array(label_properties).reshape(-1, 6) + + @staticmethod + def _get_labels_text( + detections: Detections, custom_labels: Optional[List[str]] + ) -> List[str]: + if custom_labels is not None: + return custom_labels + + labels = [] + for idx in range(len(detections)): + if CLASS_NAME_DATA_FIELD in detections.data: + labels.append(detections.data[CLASS_NAME_DATA_FIELD][idx]) elif detections.class_id is not None: - text = str(detections.class_id[detection_idx]) + labels.append(str(detections.class_id[idx])) else: - text = str(detection_idx) + labels.append(str(idx)) + return labels - left, top, right, bottom = draw.textbbox((0, 0), text, font=self.font) - text_width = right - left - text_height = bottom - top - text_w_padded = text_width + 2 * self.text_padding - text_h_padded = text_height + 2 * self.text_padding - text_background_xyxy = resolve_text_background_xyxy( - center_coordinates=tuple(center_coordinates), - text_wh=(text_w_padded, text_h_padded), - position=self.text_anchor, + def _draw_labels( + self, + draw, + labels: List[str], + label_properties: np.ndarray, + detections: Detections, + custom_color_lookup: Optional[np.ndarray], + ) -> None: + assert len(labels) == len(label_properties) == len(detections), ( + f"Number of label properties ({len(label_properties)}), " + f"labels ({len(labels)}) and detections ({len(detections)}) " + "do not match." + ) + color_lookup = ( + custom_color_lookup + if custom_color_lookup is not None + else self.color_lookup + ) + + for idx, label_property in enumerate(label_properties): + background_color = resolve_color( + color=self.color, + detections=detections, + detection_idx=idx, + color_lookup=color_lookup, + ) + text_color = resolve_color( + color=self.text_color, + detections=detections, + detection_idx=idx, + color_lookup=color_lookup, ) - text_x = text_background_xyxy[0] + self.text_padding - left - text_y = text_background_xyxy[1] + self.text_padding - top + box_xyxy = label_property[:4] + text_left = label_property[4] + text_top = label_property[5] + label_x_position = box_xyxy[0] + self.text_padding - text_left + label_y_position = box_xyxy[1] + self.text_padding - text_top draw.rounded_rectangle( - text_background_xyxy, + tuple(box_xyxy), radius=self.border_radius, - fill=color.as_rgb(), + fill=background_color.as_rgb(), outline=None, ) draw.text( - xy=(text_x, text_y), - text=text, + xy=(label_x_position, label_y_position), + text=labels[idx], font=self.font, fill=text_color.as_rgb(), ) - return scene @staticmethod - def _load_default_font(size): - """ - PIL either loads a font that accepts a size (e.g. on my machine) - or raises an error saying `load_default` does not accept arguments - (e.g. in Colab). - """ + def _load_font(font_size: int, font_path: Optional[str]): + def load_default_font(size): + try: + return ImageFont.load_default(size) + except TypeError: + return ImageFont.load_default() + + if font_path is None: + return load_default_font(font_size) + try: - font = ImageFont.load_default(size) - except TypeError: - font = ImageFont.load_default() - return font + return ImageFont.truetype(font_path, font_size) + except OSError: + print(f"Font path '{font_path}' not found. Using PIL's default font.") + return load_default_font(font_size) class IconAnnotator(BaseAnnotator): diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index 43fcec5a0..c6c63286d 100644 --- a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -1039,3 +1039,59 @@ def cross_product(anchors: np.ndarray, vector: Vector) -> np.ndarray: ) vector_start = np.array([vector.start.x, vector.start.y]) return np.cross(vector_at_zero, anchors - vector_start) + + +def spread_out_boxes( + xyxy: np.ndarray, + max_iterations: int = 100, +) -> np.ndarray: + """ + Spread out boxes that overlap with each other. + + Args: + xyxy: Numpy array of shape (N, 4) where N is the number of boxes. + max_iterations: Maximum number of iterations to run the algorithm for. + """ + if len(xyxy) == 0: + return xyxy + + xyxy_padded = pad_boxes(xyxy, px=1) + for _ in range(max_iterations): + # NxN + iou = box_iou_batch(xyxy_padded, xyxy_padded) + np.fill_diagonal(iou, 0) + if np.all(iou == 0): + break + + overlap_mask = iou > 0 + + # Nx2 + centers = (xyxy_padded[:, :2] + xyxy_padded[:, 2:]) / 2 + + # NxNx2 + delta_centers = centers[:, np.newaxis, :] - centers[np.newaxis, :, :] + delta_centers *= overlap_mask[:, :, np.newaxis] + + # Nx2 + delta_sum = np.sum(delta_centers, axis=1) + delta_magnitude = np.linalg.norm(delta_sum, axis=1, keepdims=True) + direction_vectors = np.divide( + delta_sum, + delta_magnitude, + out=np.zeros_like(delta_sum), + where=delta_magnitude != 0, + ) + + force_vectors = np.sum(iou, axis=1) + force_vectors = force_vectors[:, np.newaxis] * direction_vectors + + force_vectors *= 10 + force_vectors[(force_vectors > 0) & (force_vectors < 2)] = 2 + force_vectors[(force_vectors < 0) & (force_vectors > -2)] = -2 + + force_vectors = force_vectors.astype(int) + + xyxy_padded[:, [0, 1]] += force_vectors + xyxy_padded[:, [2, 3]] += force_vectors + + return pad_boxes(xyxy_padded, px=-1) diff --git a/supervision/keypoint/annotators.py b/supervision/keypoint/annotators.py index 559bfa921..7537b264a 100644 --- a/supervision/keypoint/annotators.py +++ b/supervision/keypoint/annotators.py @@ -5,10 +5,11 @@ import cv2 import numpy as np -from supervision import Rect, pad_boxes from supervision.annotators.base import ImageType +from supervision.detection.utils import pad_boxes, spread_out_boxes from supervision.draw.color import Color from supervision.draw.utils import draw_rounded_rectangle +from supervision.geometry.core import Rect from supervision.keypoint.core import KeyPoints from supervision.keypoint.skeletons import SKELETONS_BY_VERTEX_COUNT from supervision.utils.conversion import ensure_cv2_image_for_annotation @@ -201,6 +202,7 @@ def __init__( text_thickness: int = 1, text_padding: int = 10, border_radius: int = 0, + smart_position: bool = False, ): """ Args: @@ -215,6 +217,7 @@ def __init__( text_padding (int): The padding around the text. border_radius (int): The radius of the rounded corners of the boxes. Set to a high value to produce circles. + smart_position (bool): Spread out the labels to avoid overlap. """ self.border_radius: int = border_radius self.color: Union[Color, List[Color]] = color @@ -222,6 +225,7 @@ def __init__( self.text_scale: float = text_scale self.text_thickness: int = text_thickness self.text_padding: int = text_padding + self.smart_position = smart_position def annotate( self, @@ -356,9 +360,12 @@ def annotate( for anchor, label in zip(anchors, labels) ] ) - xyxy_padded = pad_boxes(xyxy=xyxy, px=self.text_padding) + if self.smart_position: + xyxy_padded = spread_out_boxes(xyxy_padded) + xyxy = pad_boxes(xyxy=xyxy_padded, px=-self.text_padding) + for text, color, text_color, box, box_padded in zip( labels, colors, text_colors, xyxy, xyxy_padded ):