-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
line based object counter example (#34)
* added example of line counting * update * update * Update README.md * Add files via upload
- Loading branch information
1 parent
2625884
commit b5f6eac
Showing
5 changed files
with
154 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Count Objects Crossing the Line | ||
|
||
![Demo Video](demo.gif) | ||
|
||
This example demonstrates the power of `trolo` by using object detection to count objects in a video, along with line zone counting functionality. It uses a pretrained model for detection and provides an annotated output video with detected objects, tracking IDs, and counts of objects crossing a defined line zone. | ||
|
||
## Features | ||
- **Object Detection**: Detect objects in video frames using a specified detection model. | ||
- **Object Tracking**: Tracks detected objects across video frames using the `ByteTrack` tracker from `supervision` | ||
- **Line Zone Counting**: Counts the number of objects crossing a predefined line in the video from `supervision` | ||
|
||
## Installation: | ||
|
||
- Setup python environment and activate it [optional] | ||
```shell | ||
python3 -m venv venv | ||
source venv/bin/activate | ||
``` | ||
- Install required dependencies | ||
```shell | ||
pip install trolo supervision tqdm | ||
``` | ||
|
||
## How It Works | ||
1. The script loads a detection model using `DetectionPredictor` from `trolo` | ||
2. It reads video frames from the input video file. | ||
3. Each frame is processed to: | ||
- Detect objects based on the specified confidence threshold. | ||
- Track objects across frames. | ||
- Annotate the frame with bounding boxes, labels, traces, and line zone information. | ||
4. The processed frames are visualized in real time (if `--vis` is enabled) and written to an output video file (if an output path is specified). | ||
|
||
## Command-Line Arguments | ||
|
||
| Argument | Type | Default | Description | | ||
|--------------------|----------|-------------|-----------------------------------------------------------------------------| | ||
| `--video_path` | `str` | (Required) | Path to the input video file. | | ||
| `--model_name` | `str` | `dfine-m` | Name of the detection model to use. | | ||
| `--output_path`    | `str`    | `demo.mp4`  | Path to save the output annotated video.                                   | | ||
| `--vis` | `bool` | `True` | Whether to visualize the annotated frames in real-time. | | ||
| `--conf_threshold` | `float` | `0.35` | Confidence threshold for filtering detections. | | ||
|
||
### Example Usage | ||
|
||
#### Basic Usage | ||
```bash | ||
python script.py --video_path people-walking.mp4 | ||
``` | ||
|
||
We thank Supervision for providing ready-to-use application APIs. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from typing import Optional | ||
import argparse | ||
|
||
from tqdm import tqdm | ||
import cv2 | ||
import numpy as np | ||
import supervision as sv | ||
from trolo import DetectionPredictor, to_sv | ||
|
||
|
||
def detect_objects(model_predictor: DetectionPredictor,
                   image: np.ndarray,
                   conf_threshold: Optional[float]=0.35) -> sv.Detections:
    """
    Run the model predictor on a single frame and return its detections.

    The frame is passed through ``sv.cv2_to_pillow`` before prediction, the
    predictor is called with a one-element batch, and the first result of
    that batch is converted with ``to_sv``.

    Args:
        model_predictor: Predictor whose ``predict`` method is called.
        image: Frame to run detection on.
        conf_threshold: Confidence threshold forwarded to ``predict``.

    Returns:
        The ``to_sv`` conversion of the prediction for ``image``.
    """
    pil_frame = sv.cv2_to_pillow(image)
    batch_results = model_predictor.predict(
        [pil_frame],
        conf_threshold=conf_threshold,
    )
    # Only one image was submitted, so the batch has exactly one entry of interest.
    return to_sv(batch_results[0])
|
||
|
||
def main(video_path: str,
         model_name: Optional[str]="dfine-m" ,
         output_path: Optional[str]=None,
         vis: Optional[bool]=True,
         conf_threshold: Optional[float]=0.35) -> None:
    """
    Detect, track and count objects crossing a horizontal line in a video.

    Every frame of ``video_path`` is run through the detector, the detections
    are fed to a ``sv.ByteTrack`` tracker, and the tracked detections trigger
    a ``sv.LineZone`` spanning the full frame width at half the frame height.
    Each frame is annotated with boxes, tracker-id labels, traces and the
    line zone.

    Args:
        video_path: Path to the input video file.
        model_name: Model name passed to ``DetectionPredictor``.
        output_path: If truthy, annotated frames are resized to half the input
            resolution and written to this path.
        vis: If truthy, annotated frames are shown in an OpenCV window;
            pressing ``q`` stops processing.
        conf_threshold: Confidence threshold forwarded to the detector.
    """
    # Local import: keeps this function self-contained without touching the
    # module-level import block.
    from contextlib import ExitStack

    predictor = DetectionPredictor(model=model_name)
    video_info = sv.VideoInfo.from_video_path(video_path)

    # Change as per your requirement
    color_lookup = sv.ColorLookup.TRACK
    # The counting line is defined in full-resolution frame coordinates, so it
    # must be computed before video_info is scaled down below.
    line_start = sv.Point(0, video_info.height // 2)
    line_end = sv.Point(video_info.width, video_info.height // 2)

    # The output video is written at half the input resolution.
    video_info.height = int(video_info.height * 0.5)
    video_info.width = int(video_info.width * 0.5)

    frames_generator = sv.get_video_frames_generator(video_path)

    # Initialize the tracker and annotators
    tracker = sv.ByteTrack()
    box_annotator = sv.BoxAnnotator(color_lookup=color_lookup)
    label_annotator = sv.LabelAnnotator(color_lookup=color_lookup)
    track_annotator = sv.TraceAnnotator(color_lookup=color_lookup)

    # Initialize the line zone counter and annotator
    line_zone = sv.LineZone(start=line_start, end=line_end)
    line_zone_annotator = sv.LineZoneAnnotator(
        thickness=4,
        text_thickness=4,
        text_scale=2)

    window_opened = False
    # ExitStack guarantees the video sink is exited (and the file finalized)
    # on normal completion, on an early "q" break, and on exceptions.
    with ExitStack() as stack:
        video_sink = None
        if output_path:
            video_sink = stack.enter_context(
                sv.VideoSink(target_path=str(output_path), video_info=video_info)
            )
        try:
            for frame in tqdm(frames_generator, desc="Counting Objects in the video"):
                # Detect objects in the frame
                detected_objects = detect_objects(
                    model_predictor=predictor,
                    image=frame,
                    conf_threshold=conf_threshold,
                )
                # Update the tracker with the detected objects
                tracked_detections = tracker.update_with_detections(detected_objects)
                # Update the line zone counter
                line_zone.trigger(tracked_detections)

                # Annotate a copy so the source frame stays untouched
                annotated_frame = frame.copy()
                annotated_frame = box_annotator.annotate(annotated_frame, tracked_detections)
                labels = [f"{track_id}" for track_id in tracked_detections.tracker_id]
                annotated_frame = label_annotator.annotate(annotated_frame, tracked_detections, labels)
                annotated_frame = track_annotator.annotate(annotated_frame, tracked_detections)
                annotated_frame = line_zone_annotator.annotate(frame=annotated_frame, line_counter=line_zone)

                if vis:
                    cv2.imshow("Annotated Frame", annotated_frame)
                    window_opened = True
                    key = cv2.waitKey(1)
                    if key == ord("q"):
                        break
                if video_sink is not None:
                    # The frame is already an OpenCV (BGR) array here; running it
                    # through a PIL->OpenCV conversion would swap the red and blue
                    # channels in the written video, so it is written as-is.
                    resized_frame = sv.resize_image(
                        annotated_frame,
                        (video_info.width, video_info.height),
                    )
                    video_sink.write_frame(resized_frame)
        finally:
            # Only tear down windows if one was actually created.
            if window_opened:
                cv2.destroyAllWindows()
|
||
|
||
def parse_bool_arg(value):
    """
    Convert a command-line token into a boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this parser interprets the text instead.

    Args:
        value: A string such as ``"true"``/``"false"``/``"1"``/``"0"``, or an
            actual ``bool`` (returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if normalized in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Object detection and tracking in video with line zone counting.")
    parser.add_argument("--video_path", type=str, required=True, help="Path to the input video file.")
    parser.add_argument("--model_name", type=str, default="dfine-m", help="Name of the detection model.")
    parser.add_argument("--output_path", type=str, default="demo.mp4", help="Path to save the output annotated video.")
    parser.add_argument("--vis", type=parse_bool_arg, default=True,
                        help="Whether to visualize the output frames (default: True).")
    parser.add_argument("--conf_threshold", type=float, default=0.35,
                        help="Confidence threshold for detection (default: 0.35).")

    args = parser.parse_args()
    main(
        video_path=args.video_path,
        model_name=args.model_name,
        output_path=args.output_path,
        vis=args.vis,
        # Already a float thanks to type=float above.
        conf_threshold=args.conf_threshold,
    )
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters