line-based object counter example (#34)
* added example of line counting

* update

* update

* Update README.md

* Add files via upload
hardikdava authored Dec 6, 2024
1 parent 2625884 commit b5f6eac
Showing 5 changed files with 154 additions and 0 deletions.
50 changes: 50 additions & 0 deletions recipes/people_line_counter/README.md
@@ -0,0 +1,50 @@
# Count Objects Crossing the Line

![Demo Video](demo.gif)

This example demonstrates the power of `trolo` by using object detection, together with line zone counting, to count objects in a video. It uses a pretrained model for detection and produces an annotated output video with detected objects, tracking IDs, and a count of objects crossing a defined line zone.
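
A minimal sketch of the detection step this recipe builds on, using the same `trolo` and `supervision` calls as `line_counter.py` below (the video path is the sample clip shipped with this recipe):

```python
import supervision as sv
from trolo import DetectionPredictor, to_sv

predictor = DetectionPredictor(model="dfine-m")
frame = next(sv.get_video_frames_generator("people-walking.mp4"))  # first BGR frame
results = predictor.predict([sv.cv2_to_pillow(frame)], conf_threshold=0.35)
detections = to_sv(results[0])  # trolo output converted to sv.Detections
print(f"{len(detections)} objects detected in the first frame")
```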

## Features
- **Object Detection**: Detects objects in video frames using a specified detection model.
- **Object Tracking**: Tracks detected objects across video frames using the `ByteTrack` tracker from `supervision`.
- **Line Zone Counting**: Counts the number of objects crossing a predefined line using `supervision`'s `LineZone`.

## Installation

- Set up a Python environment and activate it (optional)
```shell
python3 -m venv venv
source venv/bin/activate
```
- Install required dependencies
```shell
pip install trolo supervision tqdm
```

## How It Works
1. The script loads a detection model using `DetectionPredictor` from `trolo`.
2. It reads video frames from the input video file.
3. Each frame is processed to:
- Detect objects based on the specified confidence threshold.
- Track objects across frames.
- Annotate the frame with bounding boxes, labels, traces, and line zone information.
4. The processed frames are either visualized in real time or written to an output video file (if specified); the core of this loop is sketched below.
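
A condensed sketch of that loop, using the same `supervision` `ByteTrack` and `LineZone` calls as `line_counter.py` (annotation and display are omitted here):

```python
import supervision as sv
from trolo import DetectionPredictor, to_sv

video_path = "people-walking.mp4"
predictor = DetectionPredictor(model="dfine-m")
video_info = sv.VideoInfo.from_video_path(video_path)

tracker = sv.ByteTrack()
line_zone = sv.LineZone(
    start=sv.Point(0, video_info.height // 2),
    end=sv.Point(video_info.width, video_info.height // 2),
)

for frame in sv.get_video_frames_generator(video_path):
    results = predictor.predict([sv.cv2_to_pillow(frame)], conf_threshold=0.35)
    detections = tracker.update_with_detections(to_sv(results[0]))
    line_zone.trigger(detections)  # updates the in/out crossing counts

print("in:", line_zone.in_count, "out:", line_zone.out_count)
```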

## Command-Line Arguments

| Argument | Type | Default | Description |
|--------------------|----------|-------------|-----------------------------------------------------------------------------|
| `--video_path` | `str` | (Required) | Path to the input video file. |
| `--model_name` | `str` | `dfine-m` | Name of the detection model to use. |
| `--output_path` | `str` | `None` | Path to save the output annotated video. |
| `--vis`            | `bool`   | `True`      | Whether to visualize the annotated frames in real time.                      |
| `--conf_threshold` | `float` | `0.35` | Confidence threshold for filtering detections. |

### Example Usage

#### Basic Usage
```bash
python line_counter.py --video_path people-walking.mp4
```
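
The recipe can also be invoked programmatically. A small sketch that mirrors the CLI defaults above (it assumes you run it from the `recipes/people_line_counter/` directory so that `line_counter` is importable):

```python
from line_counter import main

# Equivalent to: python line_counter.py --video_path people-walking.mp4
main(
    video_path="people-walking.mp4",
    model_name="dfine-m",
    output_path=None,        # set to e.g. "demo.mp4" to write the annotated video
    vis=True,
    conf_threshold=0.35,
)
```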

We thank Supervision for providing ready-to-use application APIs.
Binary file added recipes/people_line_counter/demo.gif
102 changes: 102 additions & 0 deletions recipes/people_line_counter/line_counter.py
@@ -0,0 +1,102 @@
from typing import Optional
import argparse

from tqdm import tqdm
import cv2
import numpy as np
import supervision as sv
from trolo import DetectionPredictor, to_sv


def detect_objects(model_predictor: DetectionPredictor,
                   image: np.ndarray,
                   conf_threshold: float = 0.35) -> sv.Detections:
    """
    Detect objects in a BGR frame with the provided model predictor and
    return them as `sv.Detections` for downstream tracking and annotation.
    """
    image = sv.cv2_to_pillow(image)  # convert the BGR frame to a PIL image for the predictor
    results = model_predictor.predict([image], conf_threshold=conf_threshold)
    detections = to_sv(results[0])
    return detections


def main(video_path: str,
         model_name: str = "dfine-m",
         output_path: Optional[str] = None,
         vis: bool = True,
         conf_threshold: float = 0.35) -> None:
    predictor = DetectionPredictor(model=model_name)
    video_info = sv.VideoInfo.from_video_path(video_path)

    # Change as per your requirement
    color_lookup = sv.ColorLookup.TRACK
    START = sv.Point(0, video_info.height // 2)
    END = sv.Point(video_info.width, video_info.height // 2)

    # The output video is written at half the input resolution
    video_info.height = int(video_info.height * 0.5)
    video_info.width = int(video_info.width * 0.5)

    frames_generator = sv.get_video_frames_generator(video_path)
    video_sink = None
    if output_path:
        # Entered manually here; released explicitly after the processing loop
        video_sink = sv.VideoSink(target_path=str(output_path), video_info=video_info)
        video_sink.__enter__()

    # Initialize the tracker and annotators
    tracker = sv.ByteTrack()
    box_annotator = sv.BoxAnnotator(color_lookup=color_lookup)
    label_annotator = sv.LabelAnnotator(color_lookup=color_lookup)
    track_annotator = sv.TraceAnnotator(color_lookup=color_lookup)

    # Initialize the line zone counter and annotator
    line_zone = sv.LineZone(start=START, end=END)
    line_zone_annotator = sv.LineZoneAnnotator(
        thickness=4,
        text_thickness=4,
        text_scale=2)
    for frame in tqdm(frames_generator, desc="Counting Objects in the video"):
        # Detect objects in the frame
        detected_objects = detect_objects(model_predictor=predictor, image=frame, conf_threshold=conf_threshold)
        # Update the tracker with the detected objects
        tracked_detections = tracker.update_with_detections(detected_objects)
        # Update the line zone counter
        line_zone.trigger(tracked_detections)

        # Annotate the frame
        annotated_frame = frame.copy()
        annotated_frame = box_annotator.annotate(annotated_frame, tracked_detections)
        labels = [f"{tracker_id}" for tracker_id in tracked_detections.tracker_id]
        annotated_frame = label_annotator.annotate(annotated_frame, tracked_detections, labels)
        annotated_frame = track_annotator.annotate(annotated_frame, tracked_detections)
        annotated_frame = line_zone_annotator.annotate(frame=annotated_frame, line_counter=line_zone)

        if vis:
            cv2.imshow("Annotated Frame", annotated_frame)
            key = cv2.waitKey(1)
            if key == ord("q"):
                break
        if output_path:
            # Downscale the BGR frame to the sink's (halved) resolution before writing
            annotated_frame = sv.resize_image(annotated_frame, (video_info.width, video_info.height))
            video_sink.write_frame(annotated_frame)

    # Release resources
    if video_sink is not None:
        video_sink.__exit__(None, None, None)
    if vis:
        cv2.destroyAllWindows()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Object detection and tracking in video with line zone counting.")
    parser.add_argument("--video_path", type=str, required=True, help="Path to the input video file.")
    parser.add_argument("--model_name", type=str, default="dfine-m", help="Name of the detection model.")
    parser.add_argument("--output_path", type=str, default=None,
                        help="Path to save the output annotated video (default: None, nothing is written).")
    # `type=bool` would turn any non-empty string (including "False") into True,
    # so parse the flag value explicitly instead.
    parser.add_argument("--vis", type=lambda v: str(v).lower() not in ("false", "0", "no"), default=True,
                        help="Whether to visualize the output frames (default: True).")
    parser.add_argument("--conf_threshold", type=float, default=0.35,
                        help="Confidence threshold for detection (default: 0.35).")

    args = parser.parse_args()
    main(
        video_path=args.video_path,
        model_name=args.model_name,
        output_path=args.output_path,
        vis=args.vis,
        conf_threshold=float(args.conf_threshold),
    )
Binary file added recipes/people_line_counter/people-walking.mp4
Binary file not shown.
2 changes: 2 additions & 0 deletions trolo/__init__.py
@@ -6,3 +6,5 @@
from .configs import *

from .loaders.registry import GLOBAL_CONFIG
from .inference import DetectionPredictor
from .utils.box_ops import to_sv
