-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeep_sort.py
118 lines (96 loc) · 4.05 KB
/
deep_sort.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import torch
from deep_sort.deep.feature_extractor import Extractor
from deep_sort.sort.nn_matching import NearestNeighborDistanceMetric
from deep_sort.sort.preprocessing import non_max_suppression
from deep_sort.sort.detection import Detection
from deep_sort.sort.tracker import Tracker
__all__ = ['DeepSort']
class DeepSort(object):
    """Tracking front-end that combines an appearance extractor with a tracker.

    For every frame, :meth:`update` crops each detection out of the image,
    runs the crops through ``Extractor`` to obtain appearance features, wraps
    the confident detections in ``Detection`` objects, applies non-maximum
    suppression and hands the survivors to ``Tracker``.
    """

    def __init__(self, model_path, max_dist=0.2, min_confidence=0.3,
                 nms_max_overlap=1.0, max_iou_distance=0.7, max_age=70,
                 n_init=3, nn_budget=100, use_cuda=True, num_class=751):
        """Build the feature extractor, the distance metric and the tracker.

        Args:
            model_path: Forwarded to ``Extractor`` as its first argument.
            max_dist: Threshold given to the ``"cosine"``
                ``NearestNeighborDistanceMetric``.
            min_confidence: Detections whose confidence is not strictly
                greater than this value are dropped in :meth:`update`.
            nms_max_overlap: Overlap argument passed to
                ``non_max_suppression``.
            max_iou_distance: Forwarded to ``Tracker``.
            max_age: Forwarded to ``Tracker``.
            n_init: Forwarded to ``Tracker``.
            nn_budget: Budget argument of the distance metric.
            use_cuda: Forwarded to ``Extractor``.
            num_class: Forwarded to ``Extractor``.
        """
        self.min_confidence = min_confidence
        self.nms_max_overlap = nms_max_overlap
        self.extractor = Extractor(model_path, use_cuda=use_cuda, num_class=num_class)
        # The caller-supplied nn_budget is honoured here; it used to be
        # silently overwritten with a hard-coded 100 (still the default).
        metric = NearestNeighborDistanceMetric("cosine", max_dist, nn_budget)
        self.tracker = Tracker(
            metric,
            max_iou_distance=max_iou_distance,
            max_age=max_age,
            n_init=n_init,
        )

    def update(self, bbox_xywh, confidences, ori_img, count):
        """Advance the tracker by one frame and report the active tracks.

        Args:
            bbox_xywh: Detections as rows of ``(x_center, y_center, w, h)``.
            confidences: One confidence value per row of ``bbox_xywh``.
            ori_img: Frame the boxes refer to; only ``shape[:2]`` and slicing
                are used here.
            count: List that receives ``int(track_id)`` for every reported
                track. It is mutated in place and also returned.

        Returns:
            Tuple ``(outputs, count, detection_id)``. ``outputs`` is an
            integer array with one ``[x1, y1, x2, y2, track_id]`` row per
            reported track, or an empty list when no track is reported.
            ``detection_id`` is ``len(bbox_xywh)``.
        """
        self.height, self.width = ori_img.shape[:2]

        # Generate detections. The re-ID extractor yields one appearance
        # embedding per box (512-d according to the original comment).
        features = self._get_features(bbox_xywh, ori_img)
        # Centre-based boxes -> top-left based boxes.
        bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
        detections = [
            Detection(bbox_tlwh[i], conf, features[i])
            for i, conf in enumerate(confidences)
            if conf > self.min_confidence
        ]

        # Run non-maximum suppression; there is nothing to suppress when no
        # detection survived the confidence filter.
        if detections:
            boxes = np.array([d.tlwh for d in detections])
            scores = np.array([d.confidence for d in detections])
            indices = non_max_suppression(boxes, self.nms_max_overlap, scores)
            detections = [detections[i] for i in indices]

        # Update tracker.
        self.tracker.predict()
        self.tracker.update(detections)

        # Output bbox identities.
        outputs = []
        detection_id = len(bbox_xywh)
        for track in self.tracker.tracks:
            # Only confirmed tracks updated within the last frame are reported.
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            x1, y1, x2, y2 = self._tlwh_to_xyxy(track.to_tlwh())
            track_id = track.track_id
            count.append(int(track_id))
            # Builtin int replaces the np.int alias removed in NumPy 1.24.
            outputs.append(np.array([x1, y1, x2, y2, track_id], dtype=int))
        if outputs:
            outputs = np.stack(outputs, axis=0)
        return outputs, count, detection_id

    @staticmethod  # needs no instance state, so it can be called on the class
    def _xywh_to_tlwh(bbox_xywh):
        """Convert ``(x_center, y_center, w, h)`` rows to ``(x_left, y_top, w, h)``.

        Args:
            bbox_xywh: ``numpy.ndarray`` or ``torch.Tensor`` with one box per
                row. Any other array-like (e.g. a list of 4-tuples) is
                converted to a float ``(N, 4)`` NumPy array.

        Returns:
            A new container of the same kind as the input (NumPy array for
            array-likes); the input is not modified.
        """
        if isinstance(bbox_xywh, torch.Tensor):
            bbox_tlwh = bbox_xywh.clone()
        elif isinstance(bbox_xywh, np.ndarray):
            bbox_tlwh = bbox_xywh.copy()
        else:
            # Previously this path left bbox_tlwh unbound (UnboundLocalError).
            bbox_tlwh = np.array(bbox_xywh, dtype=float).reshape(-1, 4)
        # Plain assignment rather than ``-=`` so integer inputs keep being
        # cast on store instead of raising a casting error.
        bbox_tlwh[:, 0] = bbox_tlwh[:, 0] - bbox_tlwh[:, 2] / 2.
        bbox_tlwh[:, 1] = bbox_tlwh[:, 1] - bbox_tlwh[:, 3] / 2.
        return bbox_tlwh

    def _xywh_to_xyxy(self, bbox_xywh):
        """Convert one centre-based box to integer corners clipped to the frame.

        Reads ``self.width`` and ``self.height`` (set by :meth:`update`).

        Returns:
            ``(x1, y1, x2, y2)`` with ``x`` in ``[0, width - 1]`` and ``y`` in
            ``[0, height - 1]`` on the respective clipped side.
        """
        x, y, w, h = bbox_xywh
        x1 = max(int(x - w / 2), 0)
        x2 = min(int(x + w / 2), self.width - 1)
        y1 = max(int(y - h / 2), 0)
        y2 = min(int(y + h / 2), self.height - 1)
        return x1, y1, x2, y2

    def _tlwh_to_xyxy(self, bbox_tlwh):
        """Convert one ``(x_left, y_top, w, h)`` box to clipped integer corners.

        Reads ``self.width`` and ``self.height`` (set by :meth:`update`).

        Returns:
            ``(x1, y1, x2, y2)``; the top-left corner is clamped at 0 and the
            bottom-right corner at ``(width - 1, height - 1)``.
        """
        x, y, w, h = bbox_tlwh
        x1 = max(int(x), 0)
        x2 = min(int(x + w), self.width - 1)
        y1 = max(int(y), 0)
        y2 = min(int(y + h), self.height - 1)
        return x1, y1, x2, y2

    def _xyxy_to_tlwh(self, bbox_xyxy):
        """Convert corner coordinates to ``(x_left, y_top, w, h)``.

        The corner values are returned unchanged; width and height are
        truncated to ``int``.
        """
        x1, y1, x2, y2 = bbox_xyxy
        width = int(x2 - x1)
        height = int(y2 - y1)
        return x1, y1, width, height

    def _get_features(self, bbox_xywh, ori_img):
        """Crop every box out of ``ori_img`` and run the crops through the extractor.

        Returns:
            Whatever ``self.extractor`` returns for the list of crops, or an
            empty NumPy array when ``bbox_xywh`` contains no boxes.
        """
        im_crops = []
        for box in bbox_xywh:
            x1, y1, x2, y2 = self._xywh_to_xyxy(box)
            im_crops.append(ori_img[y1:y2, x1:x2])
        if im_crops:
            return self.extractor(im_crops)
        return np.array([])