diff --git a/recipes/people_line_counter/README.md b/recipes/people_line_counter/README.md
new file mode 100644
index 0000000..2a265f2
--- /dev/null
+++ b/recipes/people_line_counter/README.md
@@ -0,0 +1,50 @@
+# Count Objects Crossing the Line
+
+![Demo Video](demo.gif)
+
+This example demonstrates the power of `trolo` by using object detection to count objects in a video, together with line zone counting from `supervision`. It uses a pretrained model for detection and produces an annotated output video with detected objects, tracking IDs, and counts of objects crossing a defined line zone.
+
+## Features
+- **Object Detection**: Detects objects in video frames using a specified detection model.
+- **Object Tracking**: Tracks detected objects across video frames using the `ByteTrack` tracker from `supervision`.
+- **Line Zone Counting**: Counts the number of objects crossing a predefined line using `supervision`'s `LineZone`.
+
+## Installation
+
+- Set up a Python environment and activate it (optional)
+  ```shell
+  python3 -m venv venv
+  source venv/bin/activate
+  ```
+- Install the required dependencies
+  ```shell
+  pip install trolo supervision tqdm
+  ```
+
+## How It Works
+1. The script loads a detection model using `DetectionPredictor` from `trolo`.
+2. It reads video frames from the input video file.
+3. Each frame is processed to:
+   - Detect objects above the specified confidence threshold.
+   - Track objects across frames.
+   - Annotate the frame with bounding boxes, labels, traces, and line zone information.
+4. The processed frames are visualized in real time and/or written to an output video file (if specified); the sketch below shows the core of this loop.
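+
+The following is a minimal sketch of that loop (annotation and display omitted), using the same `trolo` and `supervision` calls as `line_counter.py` below; `LineZone` keeps running `in_count`/`out_count` totals as tracked objects cross the line:
+
+```python
+import supervision as sv
+from trolo import DetectionPredictor, to_sv
+
+predictor = DetectionPredictor(model="dfine-m")
+tracker = sv.ByteTrack()
+
+video_path = "people-walking.mp4"  # sample video bundled with this recipe
+info = sv.VideoInfo.from_video_path(video_path)
+
+# Counting line across the middle of the frame, as in line_counter.py
+line_zone = sv.LineZone(
+    start=sv.Point(0, info.height // 2),
+    end=sv.Point(info.width, info.height // 2),
+)
+
+for frame in sv.get_video_frames_generator(video_path):
+    # Frames arrive as BGR numpy arrays; trolo expects PIL images
+    results = predictor.predict([sv.cv2_to_pillow(frame)], conf_threshold=0.35)
+    detections = tracker.update_with_detections(to_sv(results[0]))
+    line_zone.trigger(detections)  # updates the crossing counters
+
+print(f"in: {line_zone.in_count}, out: {line_zone.out_count}")
+```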
+ """ + image = sv.cv2_to_pillow(image) + results = model_predictor.predict([image], conf_threshold=conf_threshold) + detections = to_sv(results[0]) + return detections + + +def main(video_path: str, + model_name: Optional[str]="dfine-m" , + output_path: Optional[str]=None, + vis: Optional[bool]=True, + conf_threshold: Optional[float]=0.35) -> int: + predictor = DetectionPredictor(model=model_name) + video_info = sv.VideoInfo.from_video_path(video_path) + + # Change as per your requirement + color_lookup = sv.ColorLookup.TRACK + START = sv.Point(0, video_info.height // 2) + END = sv.Point(video_info.width, video_info.height // 2) + + video_info.height = int(video_info.height * 0.5) + video_info.width = int(video_info.width * 0.5) + + frames_generator = sv.get_video_frames_generator(video_path) + if output_path: + video_sink = sv.VideoSink(target_path=str(output_path), video_info=video_info).__enter__() + + # Initialize the tracker and annotators + tracker = sv.ByteTrack() + box_annotator = sv.BoxAnnotator(color_lookup=color_lookup) + label_annotator = sv.LabelAnnotator(color_lookup=color_lookup) + track_annotator = sv.TraceAnnotator(color_lookup=color_lookup) + + # Initialize the line zone counter and annotator + line_zone = sv.LineZone(start=START, end=END) + line_zone_annotator = sv.LineZoneAnnotator( + thickness=4, + text_thickness=4, + text_scale=2) + for frame in tqdm(frames_generator, desc="Counting Objects in the video"): + # Detect objects in the frame + detected_objects = detect_objects(model_predictor=predictor, image=frame, conf_threshold=conf_threshold) + # Update the tracker with the detected objects + tracked_detections = tracker.update_with_detections(detected_objects) + # Update the line zone counter + line_zone.trigger(tracked_detections) + + # Annotate the frame + annotated_frame = frame.copy() + annotated_frame = box_annotator.annotate(annotated_frame, tracked_detections) + labels = [ + f"{track_id[0]}" + for track_id in zip(tracked_detections.tracker_id) + ] + annotated_frame = label_annotator.annotate(annotated_frame, tracked_detections, labels) + annotated_frame = track_annotator.annotate(annotated_frame, tracked_detections) + annotated_frame = line_zone_annotator.annotate(frame=annotated_frame, line_counter=line_zone) + + if vis: + cv2.imshow("Annotated Frame", annotated_frame) + key = cv2.waitKey(1) + if key == ord("q"): + break + if output_path: + annotated_frame = sv.resize_image(annotated_frame, (video_info.width, video_info.height)) + annotated_frame = sv.pillow_to_cv2(annotated_frame) + video_sink.write_frame(annotated_frame) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Object detection and tracking in video with line zone counting.") + parser.add_argument("--video_path", type=str, required=True, help="Path to the input video file.") + parser.add_argument("--model_name", type=str, default="dfine-m", help="Name of the detection model.") + parser.add_argument("--output_path", type=str, default="demo.mp4", help="Path to save the output annotated video.") + parser.add_argument("--vis", type=bool, default=True, + help="Whether to visualize the output frames (default: True).") + parser.add_argument("--conf_threshold", type=float, default=0.35, + help="Confidence threshold for detection (default: 0.35).") + + args = parser.parse_args() + main( + video_path=args.video_path, + model_name=args.model_name, + output_path=args.output_path, + vis=args.vis, + conf_threshold=float(args.conf_threshold) + ) diff --git 
+
+We thank Supervision for providing ready-to-use application APIs.
diff --git a/recipes/people_line_counter/demo.gif b/recipes/people_line_counter/demo.gif
new file mode 100644
index 0000000..15831b8
Binary files /dev/null and b/recipes/people_line_counter/demo.gif differ
diff --git a/recipes/people_line_counter/line_counter.py b/recipes/people_line_counter/line_counter.py
new file mode 100644
index 0000000..f71d78c
--- /dev/null
+++ b/recipes/people_line_counter/line_counter.py
@@ -0,0 +1,102 @@
+from typing import Optional
+import argparse
+
+from tqdm import tqdm
+import cv2
+import numpy as np
+import supervision as sv
+from trolo import DetectionPredictor, to_sv
+
+
+def detect_objects(model_predictor: DetectionPredictor,
+                   image: np.ndarray,
+                   conf_threshold: float = 0.35) -> sv.Detections:
+    """Detect objects in a BGR frame and return them as `sv.Detections`."""
+    pil_image = sv.cv2_to_pillow(image)
+    results = model_predictor.predict([pil_image], conf_threshold=conf_threshold)
+    return to_sv(results[0])
+
+
+def main(video_path: str,
+         model_name: str = "dfine-m",
+         output_path: Optional[str] = None,
+         vis: bool = True,
+         conf_threshold: float = 0.35) -> None:
+    predictor = DetectionPredictor(model=model_name)
+    video_info = sv.VideoInfo.from_video_path(video_path)
+
+    # Adjust to your needs: a horizontal line across the middle of the frame
+    color_lookup = sv.ColorLookup.TRACK
+    START = sv.Point(0, video_info.height // 2)
+    END = sv.Point(video_info.width, video_info.height // 2)
+
+    # The output video is written at half the input resolution
+    video_info.height = int(video_info.height * 0.5)
+    video_info.width = int(video_info.width * 0.5)
+
+    frames_generator = sv.get_video_frames_generator(video_path)
+    video_sink = None
+    if output_path:
+        video_sink = sv.VideoSink(target_path=str(output_path), video_info=video_info)
+        video_sink.__enter__()
+
+    # Initialize the tracker and annotators
+    tracker = sv.ByteTrack()
+    box_annotator = sv.BoxAnnotator(color_lookup=color_lookup)
+    label_annotator = sv.LabelAnnotator(color_lookup=color_lookup)
+    trace_annotator = sv.TraceAnnotator(color_lookup=color_lookup)
+
+    # Initialize the line zone counter and annotator
+    line_zone = sv.LineZone(start=START, end=END)
+    line_zone_annotator = sv.LineZoneAnnotator(
+        thickness=4,
+        text_thickness=4,
+        text_scale=2)
+
+    for frame in tqdm(frames_generator, desc="Counting objects in the video"):
+        # Detect objects in the frame
+        detected_objects = detect_objects(model_predictor=predictor, image=frame, conf_threshold=conf_threshold)
+        # Update the tracker with the detected objects
+        tracked_detections = tracker.update_with_detections(detected_objects)
+        # Update the line zone counter
+        line_zone.trigger(tracked_detections)
+
+        # Annotate the frame
+        annotated_frame = frame.copy()
+        annotated_frame = box_annotator.annotate(annotated_frame, tracked_detections)
+        labels = [f"{tracker_id}" for tracker_id in tracked_detections.tracker_id]
+        annotated_frame = label_annotator.annotate(annotated_frame, tracked_detections, labels)
+        annotated_frame = trace_annotator.annotate(annotated_frame, tracked_detections)
+        annotated_frame = line_zone_annotator.annotate(frame=annotated_frame, line_counter=line_zone)
+
+        if vis:
+            cv2.imshow("Annotated Frame", annotated_frame)
+            if cv2.waitKey(1) == ord("q"):
+                break
+        if video_sink:
+            # Annotators work on the full-size BGR frame; resize only for writing
+            resized_frame = sv.resize_image(annotated_frame, (video_info.width, video_info.height))
+            video_sink.write_frame(resized_frame)
+
+    if video_sink:
+        video_sink.__exit__(None, None, None)
+    if vis:
+        cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Object detection and tracking in video with line zone counting.")
+    parser.add_argument("--video_path", type=str, required=True, help="Path to the input video file.")
+    parser.add_argument("--model_name", type=str, default="dfine-m", help="Name of the detection model.")
+    parser.add_argument("--output_path", type=str, default=None,
+                        help="Optional path to save the annotated output video.")
+    parser.add_argument("--vis", action=argparse.BooleanOptionalAction, default=True,
+                        help="Visualize the output frames (default: True; use --no-vis to disable).")
+    parser.add_argument("--conf_threshold", type=float, default=0.35,
+                        help="Confidence threshold for detection (default: 0.35).")
+
+    args = parser.parse_args()
+    main(
+        video_path=args.video_path,
+        model_name=args.model_name,
+        output_path=args.output_path,
+        vis=args.vis,
+        conf_threshold=args.conf_threshold,
+    )
diff --git a/recipes/people_line_counter/people-walking.mp4 b/recipes/people_line_counter/people-walking.mp4
new file mode 100644
index 0000000..84bf38b
Binary files /dev/null and b/recipes/people_line_counter/people-walking.mp4 differ
diff --git a/trolo/__init__.py b/trolo/__init__.py
index 673faef..9679763 100644
--- a/trolo/__init__.py
+++ b/trolo/__init__.py
@@ -6,3 +6,5 @@
 from .configs import *
 from .loaders.registry import GLOBAL_CONFIG
+from .inference import DetectionPredictor
+from .utils.box_ops import to_sv
\ No newline at end of file