Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mikel.brostrom committed Jan 12, 2024
2 parents 9383516 + 96a326b commit b4cd344
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 75 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ jobs:
IMG: ./assets/MOT17-mini/train/MOT17-05-FRCNN/img1/000001.jpg
run: |
# deepocsort fro all supported yolo models
python examples/track.py --tracking-method deepocsort --source $IMG --imgsz 320 --reid-model examples/weights/clip_market1501.pt
python examples/track.py --tracking-method deepocsort --source $IMG --imgsz 320
python examples/track.py --yolo-model yolo_nas_s --tracking-method deepocsort --source $IMG --imgsz 320
python examples/track.py --yolo-model yolox_n --tracking-method deepocsort --source $IMG --imgsz 320
# python examples/track.py --yolo-model yolox_n --tracking-method deepocsort --source $IMG --imgsz 320
# hybridsort
python examples/track.py --tracking-method hybridsort --source $IMG --imgsz 320
Expand Down Expand Up @@ -93,7 +93,7 @@ jobs:
# test exported reid model
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17.torchscript --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17.onnx --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_saved_model/osnet_x0_25_msmt17_float16.tflite --source $IMG --imgsz 320
#python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_saved_model/osnet_x0_25_msmt17_float16.tflite --source $IMG --imgsz 320
python examples/track.py --reid-model examples/weights/osnet_x0_25_msmt17_openvino_model --source $IMG --imgsz 320
- name: Test tracking with seg models
Expand Down
2 changes: 1 addition & 1 deletion boxmot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

__version__ = '10.0.47'
__version__ = '10.0.50'

from boxmot.postprocessing.gsi import gsi
from boxmot.tracker_zoo import create_tracker, get_tracker_config
Expand Down
7 changes: 4 additions & 3 deletions boxmot/appearance/reid_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def export_torchscript(model, im, file, optimize):
def export_onnx(model, im, file, opset, dynamic, fp16, simplify):
# ONNX export
try:
__tr.check_packages(("onnx",))
# required by onnx2tf
__tr.check_packages(("onnx==1.14.0",))
import onnx

f = file.with_suffix(".onnx")
Expand Down Expand Up @@ -107,7 +108,7 @@ def export_onnx(model, im, file, opset, dynamic, fp16, simplify):

def export_openvino(file, half):
__tr.check_packages(
("openvino-dev",)
("openvino-dev>=2023.0",)
) # requires openvino-dev: https://pypi.org/project/openvino-dev/
import openvino.runtime as ov # noqa
from openvino.tools import mo # noqa
Expand Down Expand Up @@ -258,7 +259,7 @@ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose
parser.add_argument(
"--weights",
type=Path,
default=WEIGHTS / "mobilenetv2_x1_4_dukemtmcreid.pt",
default=WEIGHTS / "osnet_x0_25_msmt17.pt",
help="model.pt path(s)",
)
parser.add_argument(
Expand Down
3 changes: 2 additions & 1 deletion boxmot/appearance/reid_multibackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def __init__(
elif self.onnx: # ONNX Runtime
LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
cuda = torch.cuda.is_available() and device.type != "cpu"
tr.check_packages(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime", ))
# https://onnxruntime.ai/docs/reference/compatibility.html
tr.check_packages(("onnx", "onnxruntime-gpu==1.16.3" if cuda else "onnxruntime==1.16.3", ))
import onnxruntime

providers = (
Expand Down
2 changes: 1 addition & 1 deletion boxmot/tracker_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def create_tracker(tracker_type, tracker_config, reid_weights, device, half, per
det_thresh=cfg.det_thresh,
max_age=cfg.max_age,
min_hits=cfg.min_hits,
iou_threshold=cfg.iou_thresh,
asso_threshold=cfg.iou_thresh,
delta_t=cfg.delta_t,
asso_func=cfg.asso_func,
inertia=cfg.inertia,
Expand Down
3 changes: 3 additions & 0 deletions boxmot/trackers/deepocsort/deep_ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,10 +434,13 @@ def update(self, dets, img):
matched, unmatched_dets, unmatched_trks = associate(
dets[:, 0:5],
trks,
self.asso_func,
self.iou_threshold,
velocities,
k_observations,
self.inertia,
img.shape[1], # w
img.shape[0], # h
stage1_emb_cost,
self.w_association_emb,
self.aw_off,
Expand Down
24 changes: 13 additions & 11 deletions boxmot/trackers/ocsort/ocsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from boxmot.motion.kalman_filters.ocsort_kf import KalmanFilter
from boxmot.utils.association import associate, linear_assignment
from boxmot.utils.iou import get_asso_func
from boxmot.utils.iou import run_asso_func


def k_previous_obs(observations, cur_age, k):
Expand Down Expand Up @@ -193,7 +194,7 @@ def __init__(
det_thresh=0.2,
max_age=30,
min_hits=3,
iou_threshold=0.3,
asso_threshold=0.3,
delta_t=3,
asso_func="iou",
inertia=0.2,
Expand All @@ -204,7 +205,7 @@ def __init__(
"""
self.max_age = max_age
self.min_hits = min_hits
self.iou_threshold = iou_threshold
self.asso_threshold = asso_threshold
self.trackers = []
self.frame_count = 0
self.det_thresh = det_thresh
Expand All @@ -214,7 +215,7 @@ def __init__(
self.use_byte = use_byte
KalmanBoxTracker.count = 0

def update(self, dets, _):
def update(self, dets, img):
"""
Params:
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...]
Expand All @@ -235,6 +236,7 @@ def update(self, dets, _):
), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6"

self.frame_count += 1
h, w = img.shape[0:2]

dets = np.hstack([dets, np.arange(len(dets)).reshape(-1, 1)])
confs = dets[:, 4]
Expand Down Expand Up @@ -279,7 +281,7 @@ def update(self, dets, _):
First round of association
"""
matched, unmatched_dets, unmatched_trks = associate(
dets[:, 0:5], trks, self.iou_threshold, velocities, k_observations, self.inertia
dets[:, 0:5], trks, self.asso_func, self.asso_threshold, velocities, k_observations, self.inertia, w, h
)
for m in matched:
self.trackers[m[1]].update(dets[m[0], :5], dets[m[0], 5], dets[m[0], 6])
Expand All @@ -294,17 +296,17 @@ def update(self, dets, _):
dets_second, u_trks
) # iou between low score detections and unmatched tracks
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
if iou_left.max() > self.asso_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
matched_indices = linear_assignment(-iou_left)
to_remove_trk_indices = []
for m in matched_indices:
det_ind, trk_ind = m[0], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
if iou_left[m[0], m[1]] < self.asso_threshold:
continue
self.trackers[trk_ind].update(
dets_second[det_ind, :5], dets_second[det_ind, 5], dets_second[det_ind, 6]
Expand All @@ -317,11 +319,11 @@ def update(self, dets, _):
if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
left_dets = dets[unmatched_dets]
left_trks = last_boxes[unmatched_trks]
iou_left = self.asso_func(left_dets, left_trks)
iou_left = run_asso_func(self.asso_func, left_dets, left_trks, w, h)
iou_left = np.array(iou_left)
if iou_left.max() > self.iou_threshold:
if iou_left.max() > self.asso_threshold:
"""
NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
NOTE: by using a lower threshold, e.g., self.asso_threshold - 0.1, you may
get a higher performance especially on MOT17/MOT20 datasets. But we keep it
uniform here for simplicity
"""
Expand All @@ -330,7 +332,7 @@ def update(self, dets, _):
to_remove_trk_indices = []
for m in rematched_indices:
det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
if iou_left[m[0], m[1]] < self.iou_threshold:
if iou_left[m[0], m[1]] < self.asso_threshold:
continue
self.trackers[trk_ind].update(dets[det_ind, :5], dets[det_ind, 5], dets[det_ind, 6])
to_remove_det_indices.append(det_ind)
Expand Down
2 changes: 1 addition & 1 deletion boxmot/trackers/strongsort/sort/linear_assignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from scipy.optimize import linear_sum_assignment

from ....utils.matching import chi2inv95
from boxmot.utils.matching import chi2inv95

INFTY_COST = 1e5

Expand Down
9 changes: 7 additions & 2 deletions boxmot/utils/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from boxmot.utils.iou import iou_batch
from boxmot.utils.iou import iou_batch, centroid_batch, run_asso_func


def speed_direction_batch(dets, tracks):
Expand Down Expand Up @@ -111,14 +111,18 @@ def compute_aw_max_metric(emb_cost, w_association_emb, bottom=0.5):
def associate(
detections,
trackers,
asso_func,
iou_threshold,
velocities,
previous_obs,
vdc_weight,
w,
h,
emb_cost=None,
w_assoc_emb=None,
aw_off=None,
aw_param=None,

):
if len(trackers) == 0:
return (
Expand All @@ -139,7 +143,8 @@ def associate(
valid_mask = np.ones(previous_obs.shape[0])
valid_mask[np.where(previous_obs[:, 4] < 0)] = 0

iou_matrix = iou_batch(detections, trackers)
iou_matrix = run_asso_func(asso_func, detections, trackers, w, h)
#iou_matrix = iou_batch(detections, trackers)
scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
# iou_matrix = iou_matrix * scores # a trick sometiems works, we don't encourage this
valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)
Expand Down
60 changes: 56 additions & 4 deletions boxmot/utils/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np


def iou_batch(bboxes1, bboxes2):
def iou_batch(bboxes1, bboxes2) -> np.ndarray:
"""
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
"""
Expand All @@ -25,7 +25,7 @@ def iou_batch(bboxes1, bboxes2):
return o


def giou_batch(bboxes1, bboxes2):
def giou_batch(bboxes1, bboxes2) -> np.ndarray:
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
Expand Down Expand Up @@ -62,7 +62,7 @@ def giou_batch(bboxes1, bboxes2):
return giou


def diou_batch(bboxes1, bboxes2):
def diou_batch(bboxes1, bboxes2) -> np.ndarray:
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
Expand Down Expand Up @@ -105,7 +105,7 @@ def diou_batch(bboxes1, bboxes2):
return (diou + 1) / 2.0 # resize from (-1,1) to (0,1)


def ciou_batch(bboxes1, bboxes2):
def ciou_batch(bboxes1, bboxes2) -> np.ndarray:
"""
:param bbox_p: predict of bbox(N,4)(x1,y1,x2,y2)
:param bbox_g: groundtruth of bbox(N,4)(x1,y1,x2,y2)
Expand Down Expand Up @@ -161,12 +161,64 @@ def ciou_batch(bboxes1, bboxes2):
return (ciou + 1) / 2.0 # resize from (-1,1) to (0,1)


def centroid_batch(bboxes1, bboxes2, w, h) -> np.ndarray:
"""
Computes the normalized centroid distance between two sets of bounding boxes.
Bounding boxes are in the format [x1, y1, x2, y2].
`normalize_scale` is a tuple (width, height) to normalize the distance.
"""

# Calculate centroids
centroids1 = np.stack(((bboxes1[..., 0] + bboxes1[..., 2]) / 2,
(bboxes1[..., 1] + bboxes1[..., 3]) / 2), axis=-1)
centroids2 = np.stack(((bboxes2[..., 0] + bboxes2[..., 2]) / 2,
(bboxes2[..., 1] + bboxes2[..., 3]) / 2), axis=-1)

# Expand dimensions for broadcasting
centroids1 = np.expand_dims(centroids1, 1)
centroids2 = np.expand_dims(centroids2, 0)

# Calculate Euclidean distances
distances = np.sqrt(np.sum((centroids1 - centroids2) ** 2, axis=-1))

# Normalize distances
norm_factor = np.sqrt(w**2 + h**2)
normalized_distances = distances / norm_factor

return 1 - normalized_distances


def run_asso_func(func, *args):
"""
Wrapper function that checks the inputs to the association functions
and then call either one of the iou association functions or centroid.
Parameters:
func: The batch function to call (either *iou*_batch or centroid_batch).
*args: Variable length argument list, containing either bounding boxes and optionally size parameters.
"""
if func not in [iou_batch, giou_batch, diou_batch, ciou_batch, centroid_batch]:
raise ValueError("Invalid function specified. Must be either '(g,d,c, )iou_batch' or 'centroid_batch'.")

if func in (iou_batch, giou_batch, diou_batch, ciou_batch):
if len(args) != 4 or not all(isinstance(arg, (list, np.ndarray)) for arg in args[0:2]):
raise ValueError("Invalid arguments for iou_batch. Expected two bounding boxes.")
return func(*args[0:2])
elif func is centroid_batch:
if len(args) != 4 or not all(isinstance(arg, (list, np.ndarray)) for arg in args[:2]) or not all(isinstance(arg, (int)) for arg in args[2:]):
raise ValueError("Invalid arguments for centroid_batch. Expected two bounding boxes and two size parameters.")
return func(*args)
else:
raise ValueError("No such association method")


def get_asso_func(asso_mode):
ASSO_FUNCS = {
"iou": iou_batch,
"giou": giou_batch,
"ciou": ciou_batch,
"diou": diou_batch,
"centroid": centroid_batch
}

return ASSO_FUNCS[asso_mode]
Loading

0 comments on commit b4cd344

Please sign in to comment.