Filter Detection Classes #69

Open
zeumann opened this issue Dec 13, 2024 · 0 comments
zeumann commented Dec 13, 2024

Hi, I have modified the basic_pipelines/detection.py script for my purposes, but I want to make sure that only the classes I am looking for ("car", "bicycle", "truck", "bus") are considered. All other detected objects should be ignored. Any idea how to code this?
Thank you!
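
For reference, what I have in mind boils down to the small sketch below: keep only detections whose label is in an allow-list before doing anything else with them. The label strings are assumed to be the COCO class names the detection model reports, and ALLOWED_LABELS, MIN_CONFIDENCE and filter_detections are just names I made up for this sketch; my full script follows.

import hailo

# Hypothetical names used only for this sketch.
ALLOWED_LABELS = {"car", "bicycle", "truck", "bus"}  # classes to keep
MIN_CONFIDENCE = 0.4                                 # ignore weak detections

def filter_detections(roi):
    """Return only the detections whose label is in ALLOWED_LABELS."""
    detections = roi.get_objects_typed(hailo.HAILO_DETECTION)
    return [
        det for det in detections
        if det.get_label() in ALLOWED_LABELS and det.get_confidence() >= MIN_CONFIDENCE
    ]

In the callback this would replace the plain get_objects_typed() call, so everything after it only ever sees the four classes above.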


import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib
import os
import numpy as np
import cv2
import hailo
from hailo_rpi_common import (
    get_caps_from_pad,
    get_numpy_from_buffer,
    app_callback_class,
)
from detection_pipeline import GStreamerDetectionApp
import math

# -----------------------------------------------------------------------------------------------
# Helper Functions
# -----------------------------------------------------------------------------------------------
def calculate_distance(point1, point2):
    """Calculate Euclidean distance between two points."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(point1, point2)))

# -----------------------------------------------------------------------------------------------
# User-defined class to be used in the callback function
# -----------------------------------------------------------------------------------------------
class user_app_callback_class(app_callback_class):
    def __init__(self):
        super().__init__()
        self.new_variable = 42  # Example variable
        self.seen_cars = set()  # Set to store all unique car IDs
        self.seen_motorcycles = set()  # Set to store all unique motorcycle IDs
        self.seen_trucks = set()  # Set to store all unique truck IDs
        self.seen_buses = set()  # Set to store all unique bus IDs
        self.tracked_objects = {}  # Dictionary of currently tracked objects
        self.frame_count_since_last_seen = {}  # Tracks how many frames an object has been unseen
        self.next_object_id = 1  # Counter to generate unique object IDs

# -----------------------------------------------------------------------------------------------
# User-defined callback function
# -----------------------------------------------------------------------------------------------
def app_callback(pad, info, user_data):
    buffer = info.get_buffer()
    if buffer is None:
        return Gst.PadProbeReturn.OK

    # Increment frame count
    user_data.increment()

    # Get video frame dimensions and format
    format, width, height = get_caps_from_pad(pad)
    frame = None
    if user_data.use_frame and format and width and height:
        frame = get_numpy_from_buffer(buffer, format, width, height)

    # Get detections from the pipeline
    roi = hailo.get_roi_from_buffer(buffer)
    detections = roi.get_objects_typed(hailo.HAILO_DETECTION)

    # Store detected centers and assign IDs
    detection_centers = []  # Store detected centers in this frame
    current_frame_ids = []  # IDs of objects detected in this frame

    for detection in detections:
        label = detection.get_label()
        confidence = detection.get_confidence()
        bbox = detection.get_bbox()

        if label in ["car", "motorcycle", "truck", "bus"] and confidence > 0.4:  # Threshold for valid detections
            # Get bounding box details
            x_min = bbox.xmin()
            y_min = bbox.ymin()
            x_max = bbox.xmax()
            y_max = bbox.ymax()

            # Compute center of the bounding box
            cx = (x_min + x_max) / 2
            cy = (y_min + y_max) / 2
            detection_centers.append((cx, cy, label))
            #print(f"Detected {label} with confidence {confidence}")  # Debugging statement

    # Match detections to existing objects
    unmatched_centers = detection_centers[:]
    print(f"Unmatched Centers: {unmatched_centers}")
    for object_id, (last_center, last_label) in list(user_data.tracked_objects.items()):
        for center in unmatched_centers:
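            # NOTE (assumption): if the Hailo bbox coordinates are normalized to 0..1
            # rather than given in pixels, a threshold of 20 will match every center to
            # the first tracked object checked; in that case scale the centers by
            # width/height first or use a much smaller threshold (e.g. ~0.05).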
            if calculate_distance(center[:2], last_center) < 20:  # Distance threshold
                # Update position and reset frame counter
                # Here we are ensuring the label is updated to reflect the latest detection
                user_data.tracked_objects[object_id] = (center[:2], center[2])
                user_data.frame_count_since_last_seen[object_id] = 0
                current_frame_ids.append(object_id)
                unmatched_centers.remove(center)
                break

    # Assign new IDs to unmatched detections
    for center in unmatched_centers:
        new_object_id = user_data.next_object_id
        user_data.tracked_objects[new_object_id] = (center[:2], center[2])
        user_data.frame_count_since_last_seen[new_object_id] = 0
        label = center[2]  # Extract the label from the tuple
        print(f"Assigning new ID {new_object_id} to {label}")  # Debugging statement

        # Here is where the issue might lie. Make sure to count based on the current label.
        if label == "car":
            user_data.seen_cars.add(new_object_id)  # Count this car as seen
        elif label == "motorcycle":
            user_data.seen_motorcycles.add(new_object_id)  # Count this motorcycle as seen
        elif label == "truck":
            user_data.seen_trucks.add(new_object_id)  # Count this truck as seen
        elif label == "bus":
            user_data.seen_buses.add(new_object_id)  # Count this bus as seen
        current_frame_ids.append(new_object_id)
        user_data.next_object_id += 1

    # Remove objects that are no longer in the frame (based on frame count)
    objects_to_remove = []
    for object_id in user_data.tracked_objects:
        if object_id not in current_frame_ids:
            user_data.frame_count_since_last_seen[object_id] += 1
            if user_data.frame_count_since_last_seen[object_id] > 60:  # Threshold for removing
                objects_to_remove.append(object_id)

    for object_id in objects_to_remove:
        del user_data.tracked_objects[object_id]
        del user_data.frame_count_since_last_seen[object_id]

    # Overlay object count and IDs on the frame
    if user_data.use_frame and frame is not None:
        # Add a background rectangle for the overlay text
        cv2.rectangle(frame, (0, height - 50), (width, height), (0, 0, 0), -1)  # Black rectangle
        cv2.putText(frame, f"Cars Passed: {len(user_data.seen_cars)}", (10, height - 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        cv2.putText(frame, f"Motorcycles Passed: {len(user_data.seen_motorcycles)}", (10, height - 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        cv2.putText(frame, f"Trucks Passed: {len(user_data.seen_trucks)}", (10, height - 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        cv2.putText(frame, f"Buses Passed: {len(user_data.seen_buses)}", (10, height - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        # Draw bounding boxes and object IDs on the frame
        for object_id, ((cx, cy), label) in user_data.tracked_objects.items():
            # Draw object ID near the detected object
            cv2.putText(frame, f"ID: {object_id}", (int(cx) - 20, int(cy) - 10), 
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
            # Optionally draw a circle or bounding box for visualization
            cv2.circle(frame, (int(cx), int(cy)), 5, (255, 0, 0), -1)

        # Convert the frame back to OpenCV format
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        user_data.set_frame(frame)

    # Debug: Print the total objects and currently tracked objects
    print(f"Total cars passed: {len(user_data.seen_cars)}")
    print(f"Total motorcycles passed: {len(user_data.seen_motorcycles)}")
    print(f"Total trucks passed: {len(user_data.seen_trucks)}")
    print(f"Total buses passed: {len(user_data.seen_buses)}")
    print(f"Currently tracked objects: {user_data.tracked_objects}")

    return Gst.PadProbeReturn.OK  # Ensure this is inside the function

# -----------------------------------------------------------------------------------------------
# Main Application
# -----------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Create an instance of the user app callback class
    user_data = user_app_callback_class()

    # Initialize the detection app
    app = GStreamerDetectionApp(app_callback, user_data)

    # Run the app
    app.run()