From a8f2af08cb512bea776aa406d2fe695d6ee71ca2 Mon Sep 17 00:00:00 2001 From: BingfengYan <1039333912@qq.com> Date: Wed, 31 May 2023 13:34:08 +0800 Subject: [PATCH] add CO-MOT for multi object tracking (#266) * add CO-MOT for multi-object tracking * add CO-MOT for multi-object tracking * Simplify the code of CO_MOT * merge data+dataloader to co-mot --------- Co-authored-by: yangmasheng --- README.md | 9 +- demo/mot_demo.py | 219 +++++ demo/mot_predictors.py | 339 +++++++ detrex/data/__init__.py | 2 +- projects/co_mot/README.md | 96 ++ .../configs/common/dancetrack_schedule.py | 38 + .../configs/common/data/dancetrack_mot.py | 67 ++ projects/co_mot/configs/mot_r50.py | 137 +++ .../co_mot/configs/mot_r50_4scale_10ep.py | 73 ++ projects/co_mot/data/__init__.py | 28 + projects/co_mot/data/datasets/__init__.py | 31 + .../data/datasets/register_dancetrack_mot.py | 223 +++++ projects/co_mot/data/mot_build.py | 659 +++++++++++++ projects/co_mot/data/mot_dataset_mapper.py | 248 +++++ projects/co_mot/data/transforms/__init__.py | 26 + .../co_mot/data/transforms/mot_transforms.py | 617 +++++++++++++ projects/co_mot/evaluation/__init__.py | 4 + .../evaluation/dancetrack_evaluation.py | 270 ++++++ projects/co_mot/modeling/__init__.py | 37 + projects/co_mot/modeling/matcher.py | 128 +++ projects/co_mot/modeling/mot.py | 868 ++++++++++++++++++ projects/co_mot/modeling/mot_transformer.py | 574 ++++++++++++ projects/co_mot/modeling/qim.py | 209 +++++ projects/co_mot/train_net.py | 283 ++++++ projects/co_mot/util/__init__.py | 10 + projects/co_mot/util/checkpoint.py | 40 + projects/co_mot/util/misc.py | 164 ++++ requirements.txt | 3 +- 28 files changed, 5399 insertions(+), 3 deletions(-) create mode 100644 demo/mot_demo.py create mode 100644 demo/mot_predictors.py create mode 100644 projects/co_mot/README.md create mode 100644 projects/co_mot/configs/common/dancetrack_schedule.py create mode 100644 projects/co_mot/configs/common/data/dancetrack_mot.py create mode 100644 projects/co_mot/configs/mot_r50.py create mode 100644 projects/co_mot/configs/mot_r50_4scale_10ep.py create mode 100644 projects/co_mot/data/__init__.py create mode 100644 projects/co_mot/data/datasets/__init__.py create mode 100644 projects/co_mot/data/datasets/register_dancetrack_mot.py create mode 100644 projects/co_mot/data/mot_build.py create mode 100644 projects/co_mot/data/mot_dataset_mapper.py create mode 100644 projects/co_mot/data/transforms/__init__.py create mode 100755 projects/co_mot/data/transforms/mot_transforms.py create mode 100644 projects/co_mot/evaluation/__init__.py create mode 100644 projects/co_mot/evaluation/dancetrack_evaluation.py create mode 100644 projects/co_mot/modeling/__init__.py create mode 100755 projects/co_mot/modeling/matcher.py create mode 100644 projects/co_mot/modeling/mot.py create mode 100755 projects/co_mot/modeling/mot_transformer.py create mode 100644 projects/co_mot/modeling/qim.py create mode 100644 projects/co_mot/train_net.py create mode 100755 projects/co_mot/util/__init__.py create mode 100644 projects/co_mot/util/checkpoint.py create mode 100755 projects/co_mot/util/misc.py diff --git a/README.md b/README.md index cdbdb395..141f7db2 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ Results and models are available in [model zoo](https://detrex.readthedocs.io/en - [x] [DINO (ICLR'2023)](./projects/dino/) - [x] [H-Deformable-DETR (CVPR'2023)](./projects/h_deformable_detr/) - [x] [MaskDINO (CVPR'2023)](./projects/maskdino/) - +- [x] [CO-MOT (ArXiv'2023)](./projects/co_mot/) Please see [projects](./projects/) for the details about projects that are built based on detrex. @@ -222,8 +222,15 @@ relevant publications: archivePrefix={arXiv}, primaryClass={cs.CV} } +@article{yan2023bridging, + title={Bridging the Gap Between End-to-end and Non-End-to-end Multi-Object Tracking}, + author={Yan, Feng and Luo, Weixin and Zhong, Yujie and Gan, Yiyang and Ma, Lin}, + journal={arXiv preprint arXiv:2305.12724}, + year={2023} +} ``` + diff --git a/demo/mot_demo.py b/demo/mot_demo.py new file mode 100644 index 00000000..d6e0b022 --- /dev/null +++ b/demo/mot_demo.py @@ -0,0 +1,219 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import argparse +import glob +import multiprocessing as mp +import numpy as np +import os +import sys +import tempfile +import time +import warnings +import cv2 +import tqdm + +sys.path.insert(0, "./") # noqa +from demo.mot_predictors import VisualizationDemo +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, instantiate +from detectron2.data.detection_utils import read_image +from detectron2.utils.logger import setup_logger + + +# constants +WINDOW_NAME = "MOT" + + +def setup(args): + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + return cfg + + +def get_parser(): + parser = argparse.ArgumentParser(description="detrex demo for visualizing customized inputs") + parser.add_argument( + "--config-file", + default="projects/dino/configs/dino_r50_4scale_12ep.py", + metavar="FILE", + help="path to config file", + ) + parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.") + parser.add_argument("--video-input", help="Path to video file.") + parser.add_argument( + "--input", + nargs="+", + help="A list of space separated input images; " + "or a single glob pattern such as 'directory/*.jpg'", + ) + parser.add_argument( + "--output", + help="A file or directory to save output visualizations. " + "If not given, will show output in an OpenCV window.", + ) + parser.add_argument( + "--min_size_test", + type=int, + default=800, + help="Size of the smallest side of the image during testing. Set to zero to disable resize in testing.", + ) + parser.add_argument( + "--max_size_test", + type=float, + default=1333, + help="Maximum size of the side of the image during testing.", + ) + parser.add_argument( + "--img_format", + type=str, + default="RGB", + help="The format of the loading images.", + ) + parser.add_argument( + "--metadata_dataset", + type=str, + default="coco_2017_val", + help="The metadata infomation to be used. Default to COCO val metadata.", + ) + parser.add_argument( + "--confidence-threshold", + type=float, + default=0.5, + help="Minimum score for instance predictions to be shown", + ) + parser.add_argument( + "--opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + return parser + + +def test_opencv_video_format(codec, file_ext): + with tempfile.TemporaryDirectory(prefix="video_format_test") as dir: + filename = os.path.join(dir, "test_file" + file_ext) + writer = cv2.VideoWriter( + filename=filename, + fourcc=cv2.VideoWriter_fourcc(*codec), + fps=float(30), + frameSize=(10, 10), + isColor=True, + ) + [writer.write(np.zeros((10, 10, 3), np.uint8)) for _ in range(30)] + writer.release() + if os.path.isfile(filename): + return True + return False + + +if __name__ == "__main__": + mp.set_start_method("spawn", force=True) + args = get_parser().parse_args() + setup_logger(name="fvcore") + logger = setup_logger() + logger.info("Arguments: " + str(args)) + + cfg = setup(args) + + model = instantiate(cfg.model) + model.to(cfg.train.device) + checkpointer = DetectionCheckpointer(model) + checkpointer.load(cfg.train.init_checkpoint) + + model.eval() + + demo = VisualizationDemo( + model=model, + min_size_test=args.min_size_test, + max_size_test=args.max_size_test, + img_format=args.img_format, + metadata_dataset=args.metadata_dataset, + ) + + if args.input: + if len(args.input) == 1: + args.input = glob.glob(os.path.expanduser(args.input[0])) + assert args.input, "The input path(s) was not found" + args.input = sorted(args.input) + for path in tqdm.tqdm(args.input, disable=not args.output): + # use PIL, to be consistent with evaluation + img = read_image(path, format="BGR") + start_time = time.time() + predictions, visualized_output = demo.run_on_image(img, args.confidence_threshold) + logger.info( + "{}: {} in {:.2f}s".format( + path, + "detected {} instances".format(len(predictions["instances"])) + if "instances" in predictions + else "finished", + time.time() - start_time, + ) + ) + + if args.output: + if os.path.isdir(args.output): + assert os.path.isdir(args.output), args.output + out_filename = os.path.join(args.output, os.path.basename(path)) + else: + assert len(args.input) == 1, "Please specify a directory with args.output" + out_filename = args.output + visualized_output.save(out_filename) + else: + cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) + cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1]) + if cv2.waitKey(0) == 27: + break # esc to quit + elif args.webcam: + assert args.input is None, "Cannot have both --input and --webcam!" + assert args.output is None, "output not yet supported with --webcam!" + cam = cv2.VideoCapture(0) + for vis in tqdm.tqdm(demo.run_on_video(cam)): + cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL) + cv2.imshow(WINDOW_NAME, vis) + if cv2.waitKey(1) == 27: + break # esc to quit + cam.release() + cv2.destroyAllWindows() + elif args.video_input: + video = cv2.VideoCapture(args.video_input) + width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) + frames_per_second = video.get(cv2.CAP_PROP_FPS) + num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + basename = os.path.basename(args.video_input) + codec, file_ext = ( + ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4") + ) + if codec == ".mp4v": + warnings.warn("x264 codec not available, switching to mp4v") + if args.output: + if os.path.isdir(args.output): + output_fname = os.path.join(args.output, basename) + output_fname = os.path.splitext(output_fname)[0] + file_ext + else: + output_fname = args.output + + # assert not os.path.isfile(output_fname), output_fname + output_file = cv2.VideoWriter( + filename=output_fname, + # some installation of opencv may not support x264 (due to its license), + # you can try other format (e.g. MPEG) + fourcc=cv2.VideoWriter_fourcc(*codec), + fps=float(frames_per_second), + frameSize=(width, height), + isColor=True, + ) + assert os.path.isfile(args.video_input) + for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames): + if args.output: + output_file.write(vis_frame) + else: + cv2.namedWindow(basename, cv2.WINDOW_NORMAL) + cv2.imshow(basename, vis_frame) + if cv2.waitKey(1) == 27: + break # esc to quit + video.release() + if args.output: + output_file.release() + else: + cv2.destroyAllWindows() diff --git a/demo/mot_predictors.py b/demo/mot_predictors.py new file mode 100644 index 00000000..6b0553d7 --- /dev/null +++ b/demo/mot_predictors.py @@ -0,0 +1,339 @@ +import atexit +import bisect +from copy import copy +import multiprocessing as mp +from collections import deque +from copy import deepcopy +import cv2 +import torch +import torchvision.transforms.functional as F + +import detectron2.data.transforms as T +from detectron2.data import MetadataCatalog +from detectron2.structures import Instances +from detectron2.utils.visualizer import ( + ColorMode, + Visualizer, + _create_text_labels, + ) +from detectron2.utils.video_visualizer import ( + _DetectedInstance, + VideoVisualizer, +) + +class MOTVideoVisualizer(VideoVisualizer): + + def draw_instance_track(self, frame, predictions): + """ + Draw instance-level prediction results on an image. + + Args: + frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255]. + predictions (Instances): the output of an instance detection/segmentation + model. Following fields will be used to draw: + "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). + + Returns: + output (VisImage): image object with visualizations. + """ + frame_visualizer = Visualizer(frame, self.metadata) + num_instances = len(predictions) + if num_instances == 0: + return frame_visualizer.output + + boxes = predictions.boxes.numpy() if predictions.has("boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = predictions.labels.numpy() if predictions.has("labels") else None + keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None + colors = predictions.COLOR if predictions.has("COLOR") else [None] * len(predictions) + periods = predictions.obj_idxes if predictions.has("obj_idxes") else None + period_threshold = self.metadata.get("period_threshold", -1) + visibilities = ( + [True] * len(predictions) + if periods is None + else [x > period_threshold for x in periods] + ) + + if predictions.has("pred_masks"): + masks = predictions.pred_masks + # mask IOU is not yet enabled + # masks_rles = mask_util.encode(np.asarray(masks.permute(1, 2, 0), order="F")) + # assert len(masks_rles) == num_instances + else: + masks = None + + if not predictions.has("COLOR"): + if predictions.has("obj_idxes"): + predictions.ID = predictions.obj_idxes.numpy() + colors = self._assign_colors_by_id(predictions) + else: + # ToDo: clean old assign color method and use a default tracker to assign id + detected = [ + _DetectedInstance(classes[i], boxes[i], mask_rle=None, color=colors[i], ttl=8) + for i in range(num_instances) + ] + colors = self._assign_colors(detected) + + labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) + + if self._instance_mode == ColorMode.IMAGE_BW: + # any() returns uint8 tensor + frame_visualizer.output.reset_image( + frame_visualizer._create_grayscale_image( + (masks.any(dim=0) > 0).numpy() if masks is not None else None + ) + ) + alpha = 0.3 + else: + alpha = 0.5 + + labels = ( + None + if labels is None + else [y[0] for y in filter(lambda x: x[1], zip(labels, visibilities))] + ) # noqa + assigned_colors = ( + None + if colors is None + else [y[0] for y in filter(lambda x: x[1], zip(colors, visibilities))] + ) # noqa + frame_visualizer.overlay_instances( + boxes=None if masks is not None else boxes[visibilities], # boxes are a bit distracting + masks=None if masks is None else masks[visibilities], + labels=labels, + keypoints=None if keypoints is None else keypoints[visibilities], + assigned_colors=assigned_colors, + alpha=alpha, + ) + + return frame_visualizer.output + + +def filter_predictions_with_area(predictions, area_threshold=100): + if "track_instances" in predictions: + preds = predictions["track_instances"] + wh = preds.boxes[:, 2:4] - preds.boxes[:, 0:2] + areas = wh[:, 0] * wh[:, 1] + keep_idxs = areas > area_threshold + predictions = copy(predictions) # don't modify the original + predictions["track_instances"] = preds[keep_idxs] + return predictions + +def filter_predictions_with_confidence(predictions, confidence_threshold=0.5): + if "track_instances" in predictions: + preds = predictions["track_instances"] + keep_idxs = preds.scores > confidence_threshold + predictions = copy(predictions) # don't modify the original + predictions["track_instances"] = preds[keep_idxs] + return predictions + + +class VisualizationDemo(object): + def __init__( + self, + model, + min_size_test=800, + max_size_test=1333, + img_format="RGB", + metadata_dataset="coco_2017_val", + instance_mode=ColorMode.IMAGE, + parallel=False, + ): + """ + Args: + cfg (CfgNode): + instance_mode (ColorMode): + parallel (bool): whether to run the model in different processes from visualization. + Useful since the visualization logic can be slow. + """ + self.metadata = MetadataCatalog.get( + metadata_dataset if metadata_dataset is not None else "__unused" + ) + self.cpu_device = torch.device("cpu") + self.instance_mode = instance_mode + + self.parallel = parallel + if parallel: + assert False + else: + self.predictor = DefaultPredictor( + model=model, + min_size_test=min_size_test, + max_size_test=max_size_test, + img_format=img_format, + metadata_dataset=metadata_dataset, + ) + + def run_on_image(self, image, threshold=0.5): + """ + Args: + image (np.ndarray): an image of shape (H, W, C) (in BGR order). + This is the format used by OpenCV. + + Returns: + predictions (dict): the output of the model. + vis_output (VisImage): the visualized image output. + """ + vis_output = None + predictions = self.predictor(image) + predictions = filter_predictions_with_confidence(predictions, threshold) + predictions = filter_predictions_with_area(predictions) + # Convert image from OpenCV BGR format to Matplotlib RGB format. + image = image[:, :, ::-1] + visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) + if "panoptic_seg" in predictions: + panoptic_seg, segments_info = predictions["panoptic_seg"] + vis_output = visualizer.draw_panoptic_seg_predictions( + panoptic_seg.to(self.cpu_device), segments_info + ) + else: + if "sem_seg" in predictions: + vis_output = visualizer.draw_sem_seg( + predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) + ) + if "instances" in predictions: + instances = predictions["instances"].to(self.cpu_device) + vis_output = visualizer.draw_instance_predictions(predictions=instances) + if "track_instances" in predictions: + predictions = predictions["track_instances"].to(self.cpu_device) + vis_output = visualizer.draw_instance_predictions(predictions) + + return predictions, vis_output + + def _frame_from_video(self, video): + while video.isOpened(): + success, frame = video.read() + if success: + yield frame + else: + break + + def run_on_video(self, video, threshold=0.5): + """ + Visualizes predictions on frames of the input video. + + Args: + video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be + either a webcam or a video file. + + Yields: + ndarray: BGR visualizations of each video frame. + """ + video_visualizer = MOTVideoVisualizer(self.metadata, self.instance_mode) + + def process_predictions(frame, predictions, threshold): + predictions = filter_predictions_with_confidence(predictions, threshold) + predictions = filter_predictions_with_area(predictions) + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + if "panoptic_seg" in predictions: + panoptic_seg, segments_info = predictions["panoptic_seg"] + vis_frame = video_visualizer.draw_panoptic_seg_predictions( + frame, panoptic_seg.to(self.cpu_device), segments_info + ) + elif "instances" in predictions: + predictions = predictions["instances"].to(self.cpu_device) + vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) + elif "sem_seg" in predictions: + vis_frame = video_visualizer.draw_sem_seg( + frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) + ) + elif "track_instances" in predictions: + predictions = predictions["track_instances"].to(self.cpu_device) + vis_frame = video_visualizer.draw_instance_track(frame, predictions) + + # Converts Matplotlib RGB format to OpenCV BGR format + vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) + return vis_frame + + frame_gen = self._frame_from_video(video) + if self.parallel: + buffer_size = self.predictor.default_buffer_size + + frame_data = deque() + + for cnt, frame in enumerate(frame_gen): + frame_data.append(frame) + self.predictor.put(frame) + + if cnt >= buffer_size: + frame = frame_data.popleft() + predictions = self.predictor.get() + yield process_predictions(frame, predictions, threshold) + + while len(frame_data): + frame = frame_data.popleft() + predictions = self.predictor.get() + yield process_predictions(frame, predictions, threshold) + else: + for frame in frame_gen: + yield process_predictions(frame, self.predictor(frame), threshold) + + +class DefaultPredictor: + def __init__( + self, + model, + min_size_test=800, + max_size_test=1536, + img_format="RGB", + mean = [0.485, 0.456, 0.406], + std = [0.229, 0.224, 0.225], + metadata_dataset="coco_2017_val", + ): + self.model = model + # self.model.eval() + self.mean = mean + self.std = std + + self.metadata = MetadataCatalog.get(metadata_dataset) + + # checkpointer = DetectionCheckpointer(self.model) + # checkpointer.load(init_checkpoint) + + # self.aug = T.ResizeShortestEdge([min_size_test, min_size_test], max_size_test) + self.img_height = min_size_test + self.img_width = max_size_test + + self.input_format = img_format + self.track_instances = None + assert self.input_format in ["RGB", "BGR"], self.input_format + + def __call__(self, original_image): + """ + Args: + original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). + + Returns: + predictions (dict): + the output of the model for one image only. + See :doc:`/tutorials/models` for details about the format. + """ + with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 + # Apply pre-processing to image. + if self.input_format == "RGB": + # whether the model expects BGR inputs or RGB + original_image = original_image[:, :, ::-1] + height, width = original_image.shape[:2] + self.seq_h, self.seq_w = original_image.shape[:2] + scale = self.img_height / min(self.seq_h, self.seq_w) + if max(self.seq_h, self.seq_w) * scale > self.img_width: + scale = self.img_width / max(self.seq_h, self.seq_w) + target_h = int(self.seq_h * scale) + target_w = int(self.seq_w * scale) + image = cv2.resize(original_image, (target_w, target_h)) + + image = F.normalize(F.to_tensor(image), self.mean, self.std) + image = image.to(self.model.device) + image = image.unsqueeze(0) + + res = self.model.inference_single_image(image, (height, width), self.track_instances) + + self.track_instances = res['track_instances'] + predictions = deepcopy(res) + if len(predictions['track_instances']): + scores = predictions['track_instances'].scores.reshape(-1, self.model.g_size) + keep_idxs = torch.arange(len(predictions['track_instances']), device=scores.device).reshape(-1, self.model.g_size) + keep_idxs = keep_idxs.gather(1, scores.max(-1)[1].reshape(-1, 1)).reshape(-1) + predictions['track_instances'] = predictions['track_instances'][keep_idxs] + + return predictions diff --git a/detrex/data/__init__.py b/detrex/data/__init__.py index 4add9476..94e9cabc 100644 --- a/detrex/data/__init__.py +++ b/detrex/data/__init__.py @@ -22,4 +22,4 @@ MaskFormerPanopticDatasetMapper, ) from . import datasets -from .transforms import ColorAugSSDTransform \ No newline at end of file +from .transforms import ColorAugSSDTransform diff --git a/projects/co_mot/README.md b/projects/co_mot/README.md new file mode 100644 index 00000000..00c76348 --- /dev/null +++ b/projects/co_mot/README.md @@ -0,0 +1,96 @@ + +# CO-MOT: Bridging the Gap Between End-to-end and Non-End-to-end Multi-Object Tracking + + +[![arXiv]](https://arxiv.org/abs/2305.12724) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bridging-the-gap-between-end-to-end-and-non/multi-object-tracking-on-dancetrack)](https://paperswithcode.com/sota/multi-object-tracking-on-dancetrack?p=bridging-the-gap-between-end-to-end-and-non) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bridging-the-gap-between-end-to-end-and-non/multi-object-tracking-on-bdd100k)](https://paperswithcode.com/sota/multi-object-tracking-on-bdd100k?p=bridging-the-gap-between-end-to-end-and-non) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bridging-the-gap-between-end-to-end-and-non/multi-object-tracking-on-mot17)](https://paperswithcode.com/sota/multi-object-tracking-on-mot17?p=bridging-the-gap-between-end-to-end-and-non) + + +This repository is an official implementation of [CO-MOT](https://arxiv.org/abs/2305.12724). + +**TO DO** +1. release bdd100K, MOT17 model. +2. add DINO backbone + +## Introduction + +Bridging the Gap Between End-to-end and Non-End-to-end Multi-Object Tracking. + +**Abstract.** Existing end-to-end Multi-Object Tracking (e2e-MOT) methods have not surpassed non-end-to-end tracking-by-detection methods. One potential reason is its label assignment strategy during training that consistently binds the tracked objects with tracking queries and then assigns the few newborns to detection queries. With one-to-one bipartite matching, such an assignment will yield unbalanced training, i.e., scarce positive samples for detection queries, especially for an enclosed scene, as the majority of the newborns come on stage at the beginning of videos. Thus, e2e-MOT will be easier to yield a tracking terminal without renewal or re-initialization, compared to other tracking-by-detection methods. To alleviate this problem, we present Co-MOT, a simple and effective method to facilitate e2e-MOT by a novel coopetition label assignment with a shadow concept. Specifically, we add tracked objects to the matching targets for detection queries when performing the label assignment for training the intermediate decoders. For query initialization, we expand each query by a set of shadow counterparts with limited disturbance to itself. With extensive ablations, Co-MOT achieves superior performance without extra costs, e.g., 69.4% HOTA on DanceTrack and 52.8% TETA on BDD100K. Impressively, Co-MOT only requires 38\% FLOPs of MOTRv2 to attain a similar performance, resulting in the 1.4× faster inference speed. + + +## Main Results + +### DanceTrack + +| **HOTA** | **DetA** | **AssA** | **MOTA** | **IDF1** | **URL** | +| :------: | :------: | :------: | :------: | :------: | :-----------------------------------------------------------------------------------------: | +| 69.9 | 82.1 | 58.9 | 91.2 | 71.9 | [model](https://drive.google.com/file/d/15HOnAUlYRjFBQVIsek1Qbgf18Pkffy-A/view?usp=share_link) | + + + +## Usage + +### Dataset preparation + +1. Please download [DanceTrack](https://dancetrack.github.io/) and [CrowdHuman](https://www.crowdhuman.org/) and unzip them as follows: + +``` +/data/Dataset/mot +├── crowdhuman +│ ├── annotation_train.odgt +│ ├── annotation_trainval.odgt +│ ├── annotation_val.odgt +│ └── Images +├── DanceTrack +│ ├── test +│ ├── train +│ └── val +``` + + +## Evaluation +Model evaluation can be done as follows: +```bash +python tools/train_net.py --config-file projects/co_mot/configs/mot_r50_4scale_10ep.py --eval-only train.init_checkpoint=./co_mot_dancetrack.pth train.device=cuda +``` + +## Demo +Demo can be done as follows: +```bash +python tools/train_net.py --config-file projects/co_mot/configs/mot_r50.py --video-input ./demo_video.avi --output visualize_video_results.mp4 --opts train.init_checkpoint=./co_mot_dancetrack.pth train.device=cuda +``` + +## Citing DINO +If you find our work helpful for your research, please consider citing the following BibTeX entry. + +```BibTex +@article{yan2023bridging, + title={Bridging the Gap Between End-to-end and Non-End-to-end Multi-Object Tracking}, + author={Yan, Feng and Luo, Weixin and Zhong, Yujie and Gan, Yiyang and Ma, Lin}, + journal={arXiv preprint arXiv:2305.12724}, + year={2023} +} +``` + + +## Acknowledgements + +- [MOTR](https://github.com/megvii-research/MOTR) +- [ByteTrack](https://github.com/ifzhang/ByteTrack) +- [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX) +- [OC-SORT](https://github.com/noahcao/OC_SORT) +- [DanceTrack](https://github.com/DanceTrack/DanceTrack) +- [BDD100K](https://github.com/bdd100k/bdd100k) +- [MOTRv2](https://github.com/megvii-research/MOTRv2) \ No newline at end of file diff --git a/projects/co_mot/configs/common/dancetrack_schedule.py b/projects/co_mot/configs/common/dancetrack_schedule.py new file mode 100644 index 00000000..8bee1f21 --- /dev/null +++ b/projects/co_mot/configs/common/dancetrack_schedule.py @@ -0,0 +1,38 @@ +from fvcore.common.param_scheduler import MultiStepParamScheduler + +from detectron2.config import LazyCall as L +from detectron2.solver import WarmupParamScheduler + + +def default_dancetrack_scheduler(epochs=50, decay_epochs=40, warmup_epochs=0, max_iter_epoch=5225): + """ + Returns the config for a default multi-step LR scheduler such as "50epochs", + commonly referred to in papers, where every 1x has the total length of 1440k + training images (~12 COCO epochs). LR is decayed once at the end of training. + + Args: + epochs (int): total training epochs. + decay_epochs (int): lr decay steps. + warmup_epochs (int): warmup epochs. + + Returns: + DictConfig: configs that define the multiplier for LR during training + """ + # total number of iterations assuming 8 batch size, using 41796/8=5225 + total_steps_16bs = epochs * max_iter_epoch + decay_steps = decay_epochs * max_iter_epoch + warmup_steps = warmup_epochs * max_iter_epoch + scheduler = L(MultiStepParamScheduler)( + values=[1.0, 0.1], + milestones=[decay_steps, total_steps_16bs], + ) + return L(WarmupParamScheduler)( + scheduler=scheduler, + warmup_length=warmup_steps / total_steps_16bs, + warmup_method="linear", + warmup_factor=0.001, + ) + + +# default scheduler for detr +lr_multiplier_12ep = default_dancetrack_scheduler(12, 11, 0, 5225) diff --git a/projects/co_mot/configs/common/data/dancetrack_mot.py b/projects/co_mot/configs/common/data/dancetrack_mot.py new file mode 100644 index 00000000..f0472161 --- /dev/null +++ b/projects/co_mot/configs/common/data/dancetrack_mot.py @@ -0,0 +1,67 @@ +''' +Author: 颜峰 && bphengyan@163.com +Date: 2023-05-25 10:10:31 +LastEditors: 颜峰 && bphengyan@163.com +LastEditTime: 2023-05-26 15:33:44 +FilePath: /detrex/configs/common/data/dancetrack_mot.py +Description: + +Copyright (c) 2023 by ${git_name_email}, All Rights Reserved. +''' +from omegaconf import OmegaConf + +import detectron2.data.transforms as T +from detectron2.config import LazyCall as L +from detectron2.data import get_detection_dataset_dicts + + +from projects.co_mot.data import MotDatasetMapper, MotDatasetInferenceMapper, build_mot_test_loader, build_mot_train_loader, mot_collate_fn +from projects.co_mot.data.transforms import mot_transforms as TMOT +from projects.co_mot.evaluation import DancetrackEvaluator + + +dataloader = OmegaConf.create() + +dataloader.train = L(build_mot_train_loader)( + dataset=L(get_detection_dataset_dicts)(names="dancetrack_train"), + mapper=L(MotDatasetMapper)( + augmentation=TMOT.MotCompose([ + TMOT.MotRandomHorizontalFlip(), + TMOT.MotRandomSelect( + TMOT.MotRandomResize([608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992], max_size=1536), + TMOT.MotCompose([ + TMOT.MotRandomResize([800, 1000, 1200]), + TMOT.FixedMotRandomCrop(800, 1200), + TMOT.MotRandomResize([608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992], max_size=1536), + ]) + ), + TMOT.MOTHSV(), + TMOT.MotCompose([ + TMOT.MotToTensor(), + TMOT.MotNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + ]), + augmentation_with_crop=None, + is_train=True, + mask_on=False, + img_format="RGB", + sample_mode='random_interval', + sample_interval=10, + num_frames_per_batch=5, + ), + total_batch_size=16, + num_workers=4, + collate_fn=mot_collate_fn, +) + +dataloader.test = L(build_mot_test_loader)( + dataset=L(get_detection_dataset_dicts)(names="dancetrack_val", filter_empty=False), + mapper=L(MotDatasetInferenceMapper)(), + batch_size=1, + num_workers=4, + collate_fn=None, +) + +dataloader.evaluator = L(DancetrackEvaluator)( + dataset_name="${..test.dataset.names}", +) diff --git a/projects/co_mot/configs/mot_r50.py b/projects/co_mot/configs/mot_r50.py new file mode 100644 index 00000000..62231d75 --- /dev/null +++ b/projects/co_mot/configs/mot_r50.py @@ -0,0 +1,137 @@ +import copy +import torch.nn as nn +from easydict import EasyDict + +from detectron2.modeling.backbone import ResNet, BasicStem +from detectron2.layers import ShapeSpec +from detectron2.config import LazyCall as L + +from detrex.modeling.neck import ChannelMapper +from detrex.layers import PositionEmbeddingSine + +from projects.co_mot.modeling import ( + MOT, + MOTDeformableTransformer, + MOTHungarianMatcherGroup, + MOTQueryInteractionModuleGroup, + MOTClipMatcher, + MOTTrackerPostProcess, + MOTRuntimeTrackerBase, +) +num_frames_per_batch=5 +cls_loss_coef=2 +bbox_loss_coef=5 +giou_loss_coef=2 +aux_loss=True +dec_layers=6 +g_size=3 +weight_dict = {} +for i in range(num_frames_per_batch): + weight_dict.update({"frame_{}_loss_ce".format(i): cls_loss_coef, + 'frame_{}_loss_bbox'.format(i): bbox_loss_coef, + 'frame_{}_loss_giou'.format(i): giou_loss_coef, + }) +# TODO this is a hack +if aux_loss: + for i in range(num_frames_per_batch): + for j in range(dec_layers - 1): + weight_dict.update({"frame_{}_aux{}_loss_ce".format(i, j): cls_loss_coef, + 'frame_{}_aux{}_loss_bbox'.format(i, j): bbox_loss_coef, + 'frame_{}_aux{}_loss_giou'.format(i, j): giou_loss_coef, + }) + for j in range(dec_layers): + weight_dict.update({"frame_{}_ps{}_loss_ce".format(i, j): cls_loss_coef, + 'frame_{}_ps{}_loss_bbox'.format(i, j): bbox_loss_coef, + 'frame_{}_ps{}_loss_giou'.format(i, j): giou_loss_coef, + }) + +model = L(MOT)( + backbone=L(ResNet)( + stem=L(BasicStem)(in_channels=3, out_channels=64, norm="FrozenBN"), + stages=L(ResNet.make_default_stages)( + depth=50, + stride_in_1x1=False, + norm="FrozenBN", + ), + out_features=["res3", "res4", "res5"], + freeze_at=1, + ), + position_embedding=L(PositionEmbeddingSine)( + num_pos_feats=128, + temperature=10000, + normalize=True, + offset=-0.5, + ), + neck=L(ChannelMapper)( + input_shapes={ + "res3": ShapeSpec(channels=512), + "res4": ShapeSpec(channels=1024), + "res5": ShapeSpec(channels=2048), + }, + in_features=["res3", "res4", "res5"], + out_channels=256, + num_outs=4, + kernel_size=1, + norm_layer=L(nn.GroupNorm)(num_groups=32, num_channels=256), + ), + transformer=L(MOTDeformableTransformer)( + d_model=256, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=dec_layers, + dim_feedforward=1024, + dropout=0.0, + activation="relu", + return_intermediate_dec=True, + num_feature_levels=4, + dec_n_points=4, + enc_n_points=4, + two_stage=False, + two_stage_num_proposals=60, + decoder_self_cross=not False, + sigmoid_attn=False, + extra_track_attn=False, + memory_bank=None, + ), + track_embed=L(MOTQueryInteractionModuleGroup)( + args=EasyDict(random_drop=0.1, + fp_ratio=0.3, + update_query_pos=False, + merger_dropout=0.0, + ), + dim_in=256, + hidden_dim=1024, + dim_out=256*2, + ), + embed_dim=256, + num_classes=1, + num_queries=60, + aux_loss=True, + track_base=L(MOTRuntimeTrackerBase)(score_thresh=0.5, filter_score_thresh=0.5, miss_tolerance=20), + post_process=L(MOTTrackerPostProcess)(g_size=g_size), + criterion=L(MOTClipMatcher)( + num_classes=1, + matcher=L(MOTHungarianMatcherGroup)( + cost_class=cls_loss_coef, + cost_bbox=bbox_loss_coef, + cost_giou=giou_loss_coef, + ), + weight_dict=weight_dict, + losses=['labels', 'boxes'], + g_size=g_size + ), + g_size = g_size, +) + +model.device="cuda" + +# # set aux loss weight dict +# base_weight_dict = copy.deepcopy(model.criterion.weight_dict) +# if model.aux_loss: +# weight_dict = model.criterion.weight_dict +# aux_weight_dict = {} +# aux_weight_dict.update({k + "_enc": v for k, v in base_weight_dict.items()}) +# for i in range(model.transformer.decoder.num_layers - 1): +# aux_weight_dict.update({k + f"_{i}": v for k, v in base_weight_dict.items()}) +# weight_dict.update(aux_weight_dict) +# model.criterion.weight_dict = weight_dict diff --git a/projects/co_mot/configs/mot_r50_4scale_10ep.py b/projects/co_mot/configs/mot_r50_4scale_10ep.py new file mode 100644 index 00000000..86f75ccf --- /dev/null +++ b/projects/co_mot/configs/mot_r50_4scale_10ep.py @@ -0,0 +1,73 @@ +''' +Author: 颜峰 && bphengyan@163.com +Date: 2023-05-25 09:54:44 +LastEditors: 颜峰 && bphengyan@163.com +LastEditTime: 2023-05-31 09:38:52 +FilePath: /detrex/projects/co_mot/configs/mot_r50_4scale_10ep.py +Description: + +Copyright (c) 2023 by ${git_name_email}, All Rights Reserved. +''' +from detectron2.config import LazyConfig +from detrex.config import get_config +from .mot_r50 import model + +# get default config +dataloader = LazyConfig.load("projects/co_mot/configs/common/data/dancetrack_mot.py").dataloader +optimizer = get_config("common/optim.py").AdamW +lr_multiplier = get_config("common/coco_schedule.py").lr_multiplier_12ep # 这个需要改 +# lr_multiplier = +train = get_config("common/train.py").train + +# modify training config +train.init_checkpoint = "detectron2://ImageNetPretrained/torchvision/R-50.pkl" +train.output_dir = "/mnt/dolphinfs/hdd_pool/docker/user/hadoop-vacv/yanfeng/project/MOTRv2/detrex/output/mot_r50_4scale_12ep" + +# dancetrack 41796 imgs +# max training iterations +train.max_iter = 90000 +train.eval_period = 5000 +train.log_period = 100 +train.checkpointer.period = 5000 + +# gradient clipping for training +train.clip_grad.enabled = True +train.clip_grad.params.max_norm = 0.1 +train.clip_grad.params.norm_type = 2 + +# set training devices +train.device = "cuda" +model.device = train.device + +# +train.lr_backbone_names = ['backbone.0'] +train.lr_linear_proj_names = ['reference_points', 'sampling_offsets',] + +# for ddp +train.ddp=dict( + broadcast_buffers=False, + find_unused_parameters=True, + fp16_compression=False, + ) + +# modify optimizer config +optimizer.lr = 2e-4 +optimizer.lr_backbone = 2e-5 +optimizer.lr_linear_proj_mult = 0.1 + +optimizer.sgd=False +optimizer.weight_decay = 1e-4 + +optimizer.betas = (0.9, 0.999) +optimizer.params.lr_factor_func = lambda module_name: 0.1 if "backbone" in module_name else 1 + +# modify dataloader config +dataloader.train.num_workers = 16 + +# please notice that this is total batch size. +# surpose you're using 4 gpus for training and the batch size for +# each gpu is 16/4 = 4 +dataloader.train.total_batch_size = 8 + +# dump the testing results into output_dir for visualization +dataloader.evaluator.output_dir = train.output_dir diff --git a/projects/co_mot/data/__init__.py b/projects/co_mot/data/__init__.py new file mode 100644 index 00000000..429f8a15 --- /dev/null +++ b/projects/co_mot/data/__init__.py @@ -0,0 +1,28 @@ +''' +Author: 颜峰 && bphengyan@163.com +Date: 2023-05-31 09:24:33 +LastEditors: 颜峰 && bphengyan@163.com +LastEditTime: 2023-05-31 09:24:33 +FilePath: /detrex/projects/co_mot/data/__init__.py +Description: + +Copyright (c) 2023 by ${git_name_email}, All Rights Reserved. +''' +# coding=utf-8 +# Copyright 2022 The IDEA Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .mot_dataset_mapper import MotDatasetMapper, MotDatasetInferenceMapper +from . import datasets +from .mot_build import build_mot_train_loader, build_mot_test_loader, mot_collate_fn \ No newline at end of file diff --git a/projects/co_mot/data/datasets/__init__.py b/projects/co_mot/data/datasets/__init__.py new file mode 100644 index 00000000..4df48166 --- /dev/null +++ b/projects/co_mot/data/datasets/__init__.py @@ -0,0 +1,31 @@ +''' +Author: 颜峰 && bphengyan@163.com +Date: 2023-05-31 09:41:04 +LastEditors: 颜峰 && bphengyan@163.com +LastEditTime: 2023-05-31 09:41:05 +FilePath: /detrex/projects/co_mot/data/datasets/__init__.py +Description: + +Copyright (c) 2023 by ${git_name_email}, All Rights Reserved. +''' +# coding=utf-8 +# Copyright 2022 The IDEA Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------------------------------ +# Copyright (c) Facebook, Inc. and its affiliates. +# ------------------------------------------------------------------------------------------------ + +from . import ( + register_dancetrack_mot, +) diff --git a/projects/co_mot/data/datasets/register_dancetrack_mot.py b/projects/co_mot/data/datasets/register_dancetrack_mot.py new file mode 100644 index 00000000..1a18a2e9 --- /dev/null +++ b/projects/co_mot/data/datasets/register_dancetrack_mot.py @@ -0,0 +1,223 @@ +''' +Author: 颜峰 && bphengyan@163.com +Date: 2023-05-25 11:00:08 +LastEditors: 颜峰 && bphengyan@163.com +LastEditTime: 2023-05-31 10:19:30 +FilePath: /detrex/projects/co_mot/data/datasets/register_dancetrack_mot.py +Description: + +Copyright (c) 2023 by ${git_name_email}, All Rights Reserved. +''' +# coding=utf-8 +# Copyright 2022 The IDEA Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------------------------------ +# Copyright (c) Facebook, Inc. and its affiliates. +# ------------------------------------------------------------------------------------------------ +# Modified from: +# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_coco_panoptic_annos_semseg.py +# ------------------------------------------------------------------------------------------------ + +import json +import os +import logging +import torch +from PIL import Image +from fvcore.common.timer import Timer +from collections import defaultdict + +from detectron2.data import DatasetCatalog, MetadataCatalog + +logger = logging.getLogger(__name__) + +DANCETRACK_CATEGORIES = [ + {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, +] +def get_dancetrack_mot_instances_meta(dataset_name, seqmap): + thing_classes = [k["name"][0] for k in DANCETRACK_CATEGORIES] + meta = {"thing_classes": thing_classes} + meta['seqmap_txt'] = seqmap + return meta + + +def load_dancetrack_mot(image_root, dataset_name=None, extra_annotation_keys=None): + """ + Load a json file in LVIS's annotation format. + + Args: + json_file (str): full path to the LVIS json annotation file. + image_root (str): the directory where the images in this json file exists. + dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train"). + If provided, this function will put "thing_classes" into the metadata + associated with this dataset. + extra_annotation_keys (list[str]): list of per-annotation keys that should also be + loaded into the dataset dict (besides "bbox", "bbox_mode", "category_id", + "segmentation"). The values for these keys will be returned as-is. + + Returns: + list[dict]: a list of dicts in Detectron2 standard format. (See + `Using Custom Datasets `_ ) + + Notes: + 1. This function does not read the image files. + The results do not have the "image" field. + """ + + def _add_mot_folder(image_root): + logger.info('YF: Adding {} not exists'.format(image_root)) + labels_full = defaultdict(lambda : defaultdict(list)) + for vid in os.listdir(image_root): + vid = os.path.join(image_root, vid) + gt_path = os.path.join(vid, 'gt', 'gt.txt') + if not os.path.exists(gt_path): + logger.warning('YF: {} not exists'.format(gt_path)) + continue + for l in open(gt_path): + t, i, *xywh, mark, label = l.strip().split(',')[:8] + t, i, mark, label = map(int, (t, i, mark, label)) + if mark == 0: + continue + if label in [3, 4, 5, 6, 9, 10, 11]: # Non-person + continue + else: + crowd = False + x, y, w, h = map(float, (xywh)) + labels_full[vid][t].append([x, y, w, h, i, crowd]) + + return labels_full + + timer = Timer() + labels_full = _add_mot_folder(image_root) + vid_files = list(labels_full.keys()) + + dataset_dicts = [] + image_id = 0 + obj_idx_offset = 0 + for vid in vid_files: + t_min = min(labels_full[vid].keys()) + t_max = max(labels_full[vid].keys()) + 1 # 最大帧+1 + obj_idx_offset += 100000 # 100000 unique ids is enough for a video. + for idx in range(t_min, t_max): + + record = {} + record["file_name"] = os.path.join(image_root, vid, 'img1', f'{idx:08d}.jpg') + record["not_exhaustive_category_ids"] = [] + record["neg_category_ids"] = [] + + record['dataset'] = 'DanceTrack' + image_id += 1 # imageid必须从1开始 + record['image_id'] = image_id + record['frame_id'] = torch.as_tensor(idx) + record['video_name'] = vid + record['t_min'] = t_min + record['t_max'] = t_max + if idx == t_min: + img = Image.open(record["file_name"]) + w, h = img._size + record["height"] = h + record["width"] = w + record['size'] = torch.as_tensor([h, w]) + record['orig_size'] = torch.as_tensor([h, w]) + + record['boxes'] = [] + record['iscrowd'] = [] + record['labels'] = [] + record['obj_ids'] = [] + record['scores'] = [] + record['boxes_type'] = "x0y0wh" + for *xywh, id, crowd in labels_full[vid][idx]: + record['boxes'].append(xywh) + assert not crowd + record['iscrowd'].append(crowd) + record['labels'].append(0) + record['obj_ids'].append(id + obj_idx_offset) + record['scores'].append(1.) + record['iscrowd'] = torch.as_tensor(record['iscrowd']) + record['labels'] = torch.as_tensor(record['labels']) + record['obj_ids'] = torch.as_tensor(record['obj_ids'], dtype=torch.float64) + record['scores'] = torch.as_tensor(record['scores']) + record['boxes'] = torch.as_tensor(record['boxes'], dtype=torch.float32).reshape(-1, 4) + + dataset_dicts.append(record) + + logger.info("Loading {} takes {:.2f} seconds.".format(image_root, timer.seconds())) + + return dataset_dicts + + +def register_dancetrack_mot_instances(name, metadata, image_root): + """ + Register a dataset in dancetrack's json annotation format for instance detection and segmentation. + + Args: + name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train". + metadata (dict): extra metadata associated with this dataset. It can be an empty dict. + image_root (str or path-like): directory which contains all the images. + """ + DatasetCatalog.register(name, lambda: load_dancetrack_mot(image_root, name)) + MetadataCatalog.get(name).set(image_root=image_root, evaluator_type="mot17", **metadata) + + + +_PREDEFINED_SPLITS_DANCETRACK_MOT = { + "dancetrack": { + "dancetrack_train": ("train/", "train_seqmap.txt"), + "dancetrack_val": ("val/", 'val_seqmap.txt'), + "dancetrack_test": ("test/", "test_seqmap.txt"), + }, +} + + +def register_dancetrack_mot(root): + for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_DANCETRACK_MOT.items(): + for key, (image_root, seqmap) in splits_per_dataset.items(): + register_dancetrack_mot_instances( + key, + get_dancetrack_mot_instances_meta(key, os.path.join(root, seqmap)), + os.path.join(root, image_root), + ) + + +_root = os.getenv("DETECTRON2_DATASETS", "/mnt/dolphinfs/ssd_pool/docker/user/hadoop-vacv/yanfeng/data/dancetrack") +register_dancetrack_mot(_root) + + +if __name__ == "__main__": + """ + Test the dataset loader. + + Usage: + python -m detectron2.data.datasets.lvis \ + path/to/json path/to/image_root dataset_name vis_limit + """ + import sys + import numpy as np + from detectron2.utils.logger import setup_logger + import detectron2.data.datasets # noqa # add pre-defined metadata + from detectron2.utils.visualizer import Visualizer + + logger = setup_logger(name=__name__) + meta = MetadataCatalog.get('dancetrack_train') + + dicts = load_dancetrack_mot(meta.image_root, meta.name) + logger.info("Done loading {} samples.".format(len(dicts))) + + dirname = "tmp" + os.makedirs(dirname, exist_ok=True) + for d in dicts: + img = np.array(Image.open(d["file_name"])) + visualizer = Visualizer(img, metadata=meta) + vis = visualizer.draw_dataset_dict(d) + fpath = os.path.join(dirname, os.path.basename(d["file_name"])) + vis.save(fpath) diff --git a/projects/co_mot/data/mot_build.py b/projects/co_mot/data/mot_build.py new file mode 100644 index 00000000..1c2de9a5 --- /dev/null +++ b/projects/co_mot/data/mot_build.py @@ -0,0 +1,659 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import logging +import numpy as np +import operator +import pickle +from typing import Any, Callable, Dict, List, Optional, Union +import torch +import random +import torch.utils.data as torchdata +from tabulate import tabulate +from termcolor import colored + +from detectron2.config import configurable +from detectron2.structures import BoxMode +from detectron2.utils.comm import get_world_size +from detectron2.utils.env import seed_all_rng +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import _log_api_usage, log_first_n + +from detectron2.data.catalog import DatasetCatalog, MetadataCatalog +from detectron2.data.common import AspectRatioGroupedDataset, DatasetFromList, ToIterableDataset +from detectron2.data.dataset_mapper import DatasetMapper +from detectron2.data.detection_utils import check_metadata_consistency +from detectron2.data.samplers import ( + InferenceSampler, + RandomSubsetTrainingSampler, + RepeatFactorTrainingSampler, + TrainingSampler, +) +from detectron2.utils.serialize import PicklableWrapper + +""" +This file contains the default logic to build a dataloader for training or testing. +""" + +__all__ = [ + "build_mot_train_loader", + "build_mot_test_loader", + "mot_collate_fn", +] + +def mot_collate_fn(batch: List[dict]) -> dict: + ret_dict = {} + for key in list(batch[0].keys()): + assert not isinstance(batch[0][key], torch.Tensor) + ret_dict[key] = [img_info[key] for img_info in batch] + if isinstance(ret_dict[key][0], list): + ret_dict[key] = list(map(list, zip(*ret_dict[key]))) + return ret_dict + +class _MapIterableDataset(torchdata.IterableDataset): + """ + Map a function over elements in an IterableDataset. + + Similar to pytorch's MapIterDataPipe, but support filtering when map_func + returns None. + + This class is not public-facing. Will be called by `MotMapDataset`. + """ + + def __init__(self, dataset, map_func): + self._dataset = dataset + self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work + + def __len__(self): + return len(self._dataset) + + def __iter__(self): + for x in map(self._map_func, self._dataset): + if x is not None: + yield x + + +class MotMapDataset(torchdata.Dataset): + """ + Map a function over the elements in a dataset. + """ + + def __init__(self, dataset, map_func): + """ + Args: + dataset: a dataset where map function is applied. Can be either + map-style or iterable dataset. When given an iterable dataset, + the returned object will also be an iterable dataset. + map_func: a callable which maps the element in dataset. map_func can + return None to skip the data (e.g. in case of errors). + How None is handled depends on the style of `dataset`. + If `dataset` is map-style, it randomly tries other elements. + If `dataset` is iterable, it skips the data and tries the next. + """ + self._dataset = dataset + self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work + + self._rng = random.Random(42) + self._fallback_candidates = set(range(len(dataset))) + + def __new__(cls, dataset, map_func): + is_iterable = isinstance(dataset, torchdata.IterableDataset) + if is_iterable: + return _MapIterableDataset(dataset, map_func) + else: + return super().__new__(cls) + + def __getnewargs__(self): + return self._dataset, self._map_func + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, idx): + retry_count = 0 + cur_idx = int(idx) + + while True: + data = self._map_func(self._dataset, cur_idx) + if data is not None: + self._fallback_candidates.add(cur_idx) + return data + + # _map_func fails for this idx, use a random new index from the pool + retry_count += 1 + self._fallback_candidates.discard(cur_idx) + cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0] + + if retry_count >= 3: + logger = logging.getLogger(__name__) + logger.warning( + "Failed to apply `_map_func` for idx: {}, retry count: {}".format( + idx, retry_count + ) + ) + +def filter_images_with_only_crowd_annotations(dataset_dicts): + """ + Filter out images with none annotations or only crowd annotations + (i.e., images without non-crowd annotations). + A common training-time preprocessing on COCO dataset. + + Args: + dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. + + Returns: + list[dict]: the same format, but filtered. + """ + num_before = len(dataset_dicts) + + def valid(anns): + for ann in anns: + if ann.get("iscrowd", 0) == 0: + return True + return False + + dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])] + num_after = len(dataset_dicts) + logger = logging.getLogger(__name__) + logger.info( + "Removed {} images with no usable annotations. {} images left.".format( + num_before - num_after, num_after + ) + ) + return dataset_dicts + + +def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image): + """ + Filter out images with too few number of keypoints. + + Args: + dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. + + Returns: + list[dict]: the same format as dataset_dicts, but filtered. + """ + num_before = len(dataset_dicts) + + def visible_keypoints_in_image(dic): + # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility + annotations = dic["annotations"] + return sum( + (np.array(ann["keypoints"][2::3]) > 0).sum() + for ann in annotations + if "keypoints" in ann + ) + + dataset_dicts = [ + x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image + ] + num_after = len(dataset_dicts) + logger = logging.getLogger(__name__) + logger.info( + "Removed {} images with fewer than {} keypoints.".format( + num_before - num_after, min_keypoints_per_image + ) + ) + return dataset_dicts + + +def load_proposals_into_dataset(dataset_dicts, proposal_file): + """ + Load precomputed object proposals into the dataset. + + The proposal file should be a pickled dict with the following keys: + + - "ids": list[int] or list[str], the image ids + - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id + - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores + corresponding to the boxes. + - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``. + + Args: + dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. + proposal_file (str): file path of pre-computed proposals, in pkl format. + + Returns: + list[dict]: the same format as dataset_dicts, but added proposal field. + """ + logger = logging.getLogger(__name__) + logger.info("Loading proposals from: {}".format(proposal_file)) + + with PathManager.open(proposal_file, "rb") as f: + proposals = pickle.load(f, encoding="latin1") + + # Rename the key names in D1 proposal files + rename_keys = {"indexes": "ids", "scores": "objectness_logits"} + for key in rename_keys: + if key in proposals: + proposals[rename_keys[key]] = proposals.pop(key) + + # Fetch the indexes of all proposals that are in the dataset + # Convert image_id to str since they could be int. + img_ids = set({str(record["image_id"]) for record in dataset_dicts}) + id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids} + + # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS' + bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS + + for record in dataset_dicts: + # Get the index of the proposal + i = id_to_index[str(record["image_id"])] + + boxes = proposals["boxes"][i] + objectness_logits = proposals["objectness_logits"][i] + # Sort the proposals in descending order of the scores + inds = objectness_logits.argsort()[::-1] + record["proposal_boxes"] = boxes[inds] + record["proposal_objectness_logits"] = objectness_logits[inds] + record["proposal_bbox_mode"] = bbox_mode + + return dataset_dicts + + +def print_instances_class_histogram(dataset_dicts, class_names): + """ + Args: + dataset_dicts (list[dict]): list of dataset dicts. + class_names (list[str]): list of class names (zero-indexed). + """ + num_classes = len(class_names) + hist_bins = np.arange(num_classes + 1) + histogram = np.zeros((num_classes,), dtype=np.int) + for entry in dataset_dicts: + annos = entry["annotations"] + classes = np.asarray( + [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=np.int + ) + if len(classes): + assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}" + assert ( + classes.max() < num_classes + ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes" + histogram += np.histogram(classes, bins=hist_bins)[0] + + N_COLS = min(6, len(class_names) * 2) + + def short_name(x): + # make long class names shorter. useful for lvis + if len(x) > 13: + return x[:11] + ".." + return x + + data = list( + itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]) + ) + total_num_instances = sum(data[1::2]) + data.extend([None] * (N_COLS - (len(data) % N_COLS))) + if num_classes > 1: + data.extend(["total", total_num_instances]) + data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)]) + table = tabulate( + data, + headers=["category", "#instances"] * (N_COLS // 2), + tablefmt="pipe", + numalign="left", + stralign="center", + ) + log_first_n( + logging.INFO, + "Distribution of instances among all {} categories:\n".format(num_classes) + + colored(table, "cyan"), + key="message", + ) + + +def get_detection_dataset_dicts( + names, + filter_empty=True, + min_keypoints=0, + proposal_files=None, + check_consistency=True, +): + """ + Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation. + + Args: + names (str or list[str]): a dataset name or a list of dataset names + filter_empty (bool): whether to filter out images without instance annotations + min_keypoints (int): filter out images with fewer keypoints than + `min_keypoints`. Set to 0 to do nothing. + proposal_files (list[str]): if given, a list of object proposal files + that match each dataset in `names`. + check_consistency (bool): whether to check if datasets have consistent metadata. + + Returns: + list[dict]: a list of dicts following the standard dataset dict format. + """ + if isinstance(names, str): + names = [names] + assert len(names), names + dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names] + + if isinstance(dataset_dicts[0], torchdata.Dataset): + if len(dataset_dicts) > 1: + # ConcatDataset does not work for iterable style dataset. + # We could support concat for iterable as well, but it's often + # not a good idea to concat iterables anyway. + return torchdata.ConcatDataset(dataset_dicts) + return dataset_dicts[0] + + for dataset_name, dicts in zip(names, dataset_dicts): + assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + + if proposal_files is not None: + assert len(names) == len(proposal_files) + # load precomputed proposals from proposal files + dataset_dicts = [ + load_proposals_into_dataset(dataset_i_dicts, proposal_file) + for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files) + ] + + dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) + + has_instances = "annotations" in dataset_dicts[0] + if filter_empty and has_instances: + dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) + if min_keypoints > 0 and has_instances: + dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) + + if check_consistency and has_instances: + try: + class_names = MetadataCatalog.get(names[0]).thing_classes + check_metadata_consistency("thing_classes", names) + print_instances_class_histogram(dataset_dicts, class_names) + except AttributeError: # class names are not available for this dataset + pass + + assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names)) + return dataset_dicts + + +def build_batch_data_loader( + dataset, + sampler, + total_batch_size, + *, + aspect_ratio_grouping=False, + num_workers=0, + collate_fn=None, +): + """ + Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are: + 1. support aspect ratio grouping options + 2. use no "batch collation", because this is common for detection training + + Args: + dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset. + sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices. + Must be provided iff. ``dataset`` is a map-style dataset. + total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see + :func:`build_detection_train_loader`. + + Returns: + iterable[list]. Length of each list is the batch size of the current + GPU. Each element in the list comes from the dataset. + """ + world_size = get_world_size() + assert ( + total_batch_size > 0 and total_batch_size % world_size == 0 + ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format( + total_batch_size, world_size + ) + batch_size = total_batch_size // world_size + + if isinstance(dataset, torchdata.IterableDataset): + assert sampler is None, "sampler must be None if dataset is IterableDataset" + else: + dataset = ToIterableDataset(dataset, sampler) + + if aspect_ratio_grouping: + data_loader = torchdata.DataLoader( + dataset, + num_workers=num_workers, + collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements + worker_init_fn=worker_init_reset_seed, + ) # yield individual mapped dict + data_loader = AspectRatioGroupedDataset(data_loader, batch_size) + if collate_fn is None: + return data_loader + return MotMapDataset(data_loader, collate_fn) + else: + return torchdata.DataLoader( + dataset, + batch_size=batch_size, + drop_last=True, + num_workers=num_workers, + collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, + worker_init_fn=worker_init_reset_seed, + ) + + +def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): + if dataset is None: + dataset = get_detection_dataset_dicts( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON + else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, + ) + _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0]) + + if mapper is None: + mapper = DatasetMapper(cfg, True) + + if sampler is None: + sampler_name = cfg.DATALOADER.SAMPLER_TRAIN + logger = logging.getLogger(__name__) + if isinstance(dataset, torchdata.IterableDataset): + logger.info("Not using any sampler since the dataset is IterableDataset.") + sampler = None + else: + logger.info("Using training sampler {}".format(sampler_name)) + if sampler_name == "TrainingSampler": + sampler = TrainingSampler(len(dataset)) + elif sampler_name == "RepeatFactorTrainingSampler": + repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset, cfg.DATALOADER.REPEAT_THRESHOLD + ) + sampler = RepeatFactorTrainingSampler(repeat_factors) + elif sampler_name == "RandomSubsetTrainingSampler": + sampler = RandomSubsetTrainingSampler( + len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO + ) + else: + raise ValueError("Unknown training sampler: {}".format(sampler_name)) + + return { + "dataset": dataset, + "sampler": sampler, + "mapper": mapper, + "total_batch_size": cfg.SOLVER.IMS_PER_BATCH, + "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, + "num_workers": cfg.DATALOADER.NUM_WORKERS, + } + + +@configurable(from_config=_train_loader_from_config) +def build_mot_train_loader( + dataset, + *, + mapper, + sampler=None, + total_batch_size, + aspect_ratio_grouping=True, + num_workers=0, + collate_fn=None, +): + """ + Build a dataloader for object detection with some default features. + + Args: + dataset (list or torch.utils.data.Dataset): a list of dataset dicts, + or a pytorch dataset (either map-style or iterable). It can be obtained + by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. + mapper (callable): a callable which takes a sample (dict) from dataset and + returns the format to be consumed by the model. + When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``. + sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces + indices to be applied on ``dataset``. + If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`, + which coordinates an infinite random shuffle sequence across all workers. + Sampler must be None if ``dataset`` is iterable. + total_batch_size (int): total batch size across all workers. + aspect_ratio_grouping (bool): whether to group images with similar + aspect ratio for efficiency. When enabled, it requires each + element in dataset be a dict with keys "width" and "height". + num_workers (int): number of parallel data loading workers + collate_fn: a function that determines how to do batching, same as the argument of + `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of + data. No collation is OK for small batch size and simple data structures. + If your batch size is large and each sample contains too many small tensors, + it's more efficient to collate them in data loader. + + Returns: + torch.utils.data.DataLoader: + a dataloader. Each output from it is a ``list[mapped_element]`` of length + ``total_batch_size / num_workers``, where ``mapped_element`` is produced + by the ``mapper``. + """ + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False) + if mapper is not None: + dataset = MotMapDataset(dataset, mapper) + + if isinstance(dataset, torchdata.IterableDataset): + assert sampler is None, "sampler must be None if dataset is IterableDataset" + else: + if sampler is None: + sampler = TrainingSampler(len(dataset)) + assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}" + return build_batch_data_loader( + dataset, + sampler, + total_batch_size, + aspect_ratio_grouping=aspect_ratio_grouping, + num_workers=num_workers, + collate_fn=collate_fn, + ) + + +def _test_loader_from_config(cfg, dataset_name, mapper=None): + """ + Uses the given `dataset_name` argument (instead of the names in cfg), because the + standard practice is to evaluate each test set individually (not combining them). + """ + if isinstance(dataset_name, str): + dataset_name = [dataset_name] + + dataset = get_detection_dataset_dicts( + dataset_name, + filter_empty=False, + proposal_files=[ + cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name + ] + if cfg.MODEL.LOAD_PROPOSALS + else None, + ) + if mapper is None: + mapper = DatasetMapper(cfg, False) + return { + "dataset": dataset, + "mapper": mapper, + "num_workers": cfg.DATALOADER.NUM_WORKERS, + "sampler": InferenceSampler(len(dataset)) + if not isinstance(dataset, torchdata.IterableDataset) + else None, + } + + +@configurable(from_config=_test_loader_from_config) +def build_mot_test_loader( + dataset: Union[List[Any], torchdata.Dataset], + *, + mapper: Callable[[Dict[str, Any]], Any], + sampler: Optional[torchdata.Sampler] = None, + batch_size: int = 1, + num_workers: int = 0, + collate_fn: Optional[Callable[[List[Any]], Any]] = None, +) -> torchdata.DataLoader: + """ + Similar to `build_detection_train_loader`, with default batch size = 1, + and sampler = :class:`InferenceSampler`. This sampler coordinates all workers + to produce the exact set of all samples. + + Args: + dataset: a list of dataset dicts, + or a pytorch dataset (either map-style or iterable). They can be obtained + by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. + mapper: a callable which takes a sample (dict) from dataset + and returns the format to be consumed by the model. + When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. + sampler: a sampler that produces + indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, + which splits the dataset across all workers. Sampler must be None + if `dataset` is iterable. + batch_size: the batch size of the data loader to be created. + Default to 1 image per worker since this is the standard when reporting + inference time in papers. + num_workers: number of parallel data loading workers + collate_fn: same as the argument of `torch.utils.data.DataLoader`. + Defaults to do no collation and return a list of data. + + Returns: + DataLoader: a torch DataLoader, that loads the given detection + dataset, with test-time transformation and batching. + + Examples: + :: + data_loader = build_detection_test_loader( + DatasetRegistry.get("my_test"), + mapper=DatasetMapper(...)) + + # or, instantiate with a CfgNode: + data_loader = build_detection_test_loader(cfg, "my_test") + """ + + # 按视频分组 + dataset_ = [] + vid_name = "" + data_vid = [] + for d in dataset: + if vid_name != d['video_name']: + vid_name = d['video_name'] + if len(data_vid): dataset_.append(data_vid) + data_vid = [] + data_vid.append(d) + if len(data_vid): dataset_.append(data_vid) + dataset = dataset_ + + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False) + if mapper is not None: + dataset = MotMapDataset(dataset, mapper) + if isinstance(dataset, torchdata.IterableDataset): + assert sampler is None, "sampler must be None if dataset is IterableDataset" + else: + if sampler is None: + sampler = InferenceSampler(len(dataset)) + return torchdata.DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + drop_last=False, + num_workers=num_workers, + collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, + ) + + +def trivial_batch_collator(batch): + """ + A batch collator that does nothing. + """ + return batch + + +def worker_init_reset_seed(worker_id): + initial_seed = torch.initial_seed() % 2**31 + seed_all_rng(initial_seed + worker_id) diff --git a/projects/co_mot/data/mot_dataset_mapper.py b/projects/co_mot/data/mot_dataset_mapper.py new file mode 100644 index 00000000..c23a587d --- /dev/null +++ b/projects/co_mot/data/mot_dataset_mapper.py @@ -0,0 +1,248 @@ +# coding=utf-8 +# Copyright 2022 The IDEA Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ------------------------------------------------------------------------------------------------ +# Copyright (c) Facebook, Inc. and its affiliates. +# All Rights Reserved +# ------------------------------------------------------------------------------------------------ +# Modified from: +# https://github.com/facebookresearch/detr/blob/main/d2/detr/dataset_mapper.py +# ------------------------------------------------------------------------------------------------ + +import copy +import logging +import numpy as np +import torch +import itertools +from PIL import Image +from typing import Optional +from random import choice, randint +from torch.utils.data.sampler import Sampler +from torch.utils.data import Dataset, DataLoader + +from detectron2.utils import comm +from detectron2.data import detection_utils as utils +from detectron2.data import transforms as T +from detectron2.structures import Instances + +__all__ = ["MotDatasetMapper", "MotDatasetInferenceMapper"] + + +class MotDatasetMapper: + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into the format used by DETR. + + The callable currently does the following: + + 1. Read the image from "file_name" + 2. Applies geometric transforms to the image and annotation + 3. Find and applies suitable cropping to the image and annotation + 4. Prepare image and annotation to Tensors + + Args: + augmentation (list[detectron.data.Transforms]): The geometric transforms for + the input raw image and annotations. + augmentation_with_crop (list[detectron.data.Transforms]): The geometric transforms with crop. + is_train (bool): Whether to load train set or val set. Default: True. + mask_on (bool): Whether to return the mask annotations. Default: False. + img_format (str): The format of the input raw images. Default: RGB. + + Because detectron2 did not implement `RandomSelect` augmentation. So we provide both `augmentation` and + `augmentation_with_crop` here and randomly apply one of them to the input raw images. + """ + + def __init__( + self, + augmentation, + augmentation_with_crop, + is_train=True, + mask_on=False, + img_format="RGB", + sample_mode='random_interval', + sample_interval=10, + num_frames_per_batch=5, + ): + self.mask_on = mask_on + self.augmentation = augmentation + self.augmentation_with_crop = augmentation_with_crop + logging.getLogger(__name__).info( + "Full TransformGens used in training: {}, crop: {}".format( + str(self.augmentation), str(self.augmentation_with_crop) + ) + ) + + self.img_format = img_format + self.is_train = is_train + + self.sample_mode = sample_mode + assert self.sample_mode == 'random_interval' + self.sample_interval=sample_interval + self.num_frames_per_batch=num_frames_per_batch + + @staticmethod + def _targets_to_instances(targets: dict, img_shape): + gt_instances = Instances(tuple(img_shape)) + n_gt = len(targets['labels']) + gt_instances.boxes = targets['boxes'][:n_gt] + gt_instances.labels = targets['labels'] + gt_instances.obj_ids = targets['obj_ids'] + return gt_instances + + def __call__(self, dataset, cur_idx): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = dataset[cur_idx] + + rate = randint(1, self.sample_interval + 1) + tmax = dataset_dict['t_max'] - dataset_dict['frame_id'] - 1 + indexes = [min(rate * i, tmax) + cur_idx for i in range(self.num_frames_per_batch)] + + images, targets = [], [] + for cur_idx in indexes: + dataset_dict = dataset[cur_idx] + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + assert dataset_dict['boxes_type'] == "x0y0wh" + + img = Image.open(dataset_dict["file_name"]) + w, h = img._size + assert self.img_format == 'RGB' and w == dataset_dict['width'] and h == dataset_dict['height'] + + images.append(img) + targets.append(dataset_dict) + + if self.augmentation is not None: + images, targets = self.augmentation(images, targets) + + gt_instances = [] + for img_i, targets_i in zip(images, targets): + gt_instances_i = self._targets_to_instances(targets_i, img_i.shape[1:3]) + gt_instances.append(gt_instances_i) + + return { + 'imgs': images, + 'gt_instances': gt_instances, + 'width': w, + 'height': h, + } + + +import os +import cv2 +import torchvision.transforms.functional as TransF +from torch.utils.data import Dataset, DataLoader +class ListImgDataset(Dataset): + def __init__(self, mot_path, img_list, det_db) -> None: + super().__init__() + self.mot_path = mot_path + self.img_list = img_list + self.det_db = det_db + + ''' + common settings + ''' + self.img_height = 800 + self.img_width = 1536 + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + + def load_img_from_file(self, f_path): + cur_img = cv2.imread(os.path.join(self.mot_path, f_path)) + assert cur_img is not None, f_path + cur_img = cv2.cvtColor(cur_img, cv2.COLOR_BGR2RGB) + proposals = [] + im_h, im_w = cur_img.shape[:2] + if len(self.det_db): + for line in self.det_db[f_path[:-4].replace('dancetrack/', 'DanceTrack/') + '.txt']: + l, t, w, h, s = list(map(float, line.split(','))) + proposals.append([(l + w / 2) / im_w, + (t + h / 2) / im_h, + w / im_w, + h / im_h, + s]) + return cur_img, torch.as_tensor(proposals).reshape(-1, 5), f_path + + def init_img(self, img, proposals): + ori_img = img.copy() + self.seq_h, self.seq_w = img.shape[:2] + scale = self.img_height / min(self.seq_h, self.seq_w) + if max(self.seq_h, self.seq_w) * scale > self.img_width: + scale = self.img_width / max(self.seq_h, self.seq_w) + target_h = int(self.seq_h * scale) + target_w = int(self.seq_w * scale) + img = cv2.resize(img, (target_w, target_h)) + img = TransF.normalize(TransF.to_tensor(img), self.mean, self.std) + img = img.unsqueeze(0) + return img, ori_img, proposals + + def __len__(self): + return len(self.img_list) + + def __getitem__(self, index): # 加载图像和proposal。并对图像颜色通道转换+resize+normalize+to_tensor。 + img, proposals, f_path = self.load_img_from_file(self.img_list[index]) + img, ori_img, proposals = self.init_img(img, proposals) + return img, ori_img, proposals, f_path + +class MotDatasetInferenceMapper: + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into the format used by DETR. + + The callable currently does the following: + + 1. Read the image from "file_name" + 2. Applies geometric transforms to the image and annotation + 3. Find and applies suitable cropping to the image and annotation + 4. Prepare image and annotation to Tensors + + Args: + augmentation (list[detectron.data.Transforms]): The geometric transforms for + the input raw image and annotations. + augmentation_with_crop (list[detectron.data.Transforms]): The geometric transforms with crop. + is_train (bool): Whether to load train set or val set. Default: True. + mask_on (bool): Whether to return the mask annotations. Default: False. + img_format (str): The format of the input raw images. Default: RGB. + + Because detectron2 did not implement `RandomSelect` augmentation. So we provide both `augmentation` and + `augmentation_with_crop` here and randomly apply one of them to the input raw images. + """ + + def __init__( + self + ): + pass + + def __call__(self, dataset, cur_idx): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = dataset[cur_idx] + + img_list = [d['file_name'] for d in dataset_dict] + img_list = sorted(img_list) + loader = DataLoader(ListImgDataset('', img_list, ''), 1, num_workers=2) + + return { + 'data_loader': loader, + "dataset_dict": dataset_dict, + } diff --git a/projects/co_mot/data/transforms/__init__.py b/projects/co_mot/data/transforms/__init__.py new file mode 100644 index 00000000..fbc5bfdf --- /dev/null +++ b/projects/co_mot/data/transforms/__init__.py @@ -0,0 +1,26 @@ +''' +Author: 颜峰 && bphengyan@163.com +Date: 2023-05-31 09:41:55 +LastEditors: 颜峰 && bphengyan@163.com +LastEditTime: 2023-05-31 09:41:56 +FilePath: /detrex/projects/co_mot/data/transforms/__init__.py +Description: + +Copyright (c) 2023 by ${git_name_email}, All Rights Reserved. +''' +# coding=utf-8 +# Copyright 2022 The IDEA Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import mot_transforms \ No newline at end of file diff --git a/projects/co_mot/data/transforms/mot_transforms.py b/projects/co_mot/data/transforms/mot_transforms.py new file mode 100755 index 00000000..c9d45874 --- /dev/null +++ b/projects/co_mot/data/transforms/mot_transforms.py @@ -0,0 +1,617 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-research. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + +""" +Transforms and data augmentation for both image + bbox. +""" +import copy +import random +import PIL +import cv2 +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as F +from PIL import Image, ImageDraw +import numpy as np +import os + +def box_xywh_to_cxcywh(x): + x0, y0, w, h = x.unbind(-1) + b = [x0 + w / 2, y0 + h / 2, w, h] + return torch.stack(b, dim=-1) + +def crop_mot(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "iscrowd", "obj_ids", "scores"] + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + target["boxes"] = cropped_boxes.reshape(-1, 4) + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + n_size = len(target[field]) + target[field] = target[field][keep[:n_size]] + + return cropped_image, target + + +def random_shift(image, target, region, sizes): + oh, ow = sizes + # step 1, shift crop and re-scale image firstly + cropped_image = F.crop(image, *region) + cropped_image = F.resize(cropped_image, sizes) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "scores", "iscrowd", "obj_ids"] + + if "boxes" in target: + boxes = target["boxes"] + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes *= torch.as_tensor([ow / w, oh / h, ow / w, oh / h]) + target["boxes"] = cropped_boxes.reshape(-1, 4) + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + n_size = len(target[field]) + target[field] = target[field][keep[:n_size]] + + return cropped_image, target + + +def crop(image, target, region): + cropped_image = F.crop(image, *region) + + target = target.copy() + i, j, h, w = region + + # should we do something wrt the original size? + target["size"] = torch.tensor([h, w]) + + fields = ["labels", "area", "iscrowd"] + if 'obj_ids' in target: + fields.append('obj_ids') + + if "boxes" in target: + boxes = target["boxes"] + max_size = torch.as_tensor([w, h], dtype=torch.float32) + cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) + cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) + cropped_boxes = cropped_boxes.clamp(min=0) + + area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) + target["boxes"] = cropped_boxes.reshape(-1, 4) + target["area"] = area + fields.append("boxes") + + if "masks" in target: + # FIXME should we update the area here if there are no boxes? + target['masks'] = target['masks'][:, i:i + h, j:j + w] + fields.append("masks") + + # remove elements for which the boxes or masks that have zero area + if "boxes" in target or "masks" in target: + # favor boxes selection when defining which elements to keep + # this is compatible with previous implementation + if "boxes" in target: + cropped_boxes = target['boxes'].reshape(-1, 2, 2) + keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) + else: + keep = target['masks'].flatten(1).any(1) + + for field in fields: + target[field] = target[field][keep] + + return cropped_image, target + + +def hflip(image, target): + flipped_image = F.hflip(image) + + w, h = image.size + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) + target["boxes"] = boxes + + if "masks" in target: + target['masks'] = target['masks'].flip(-1) + + return flipped_image, target + + +def resize(image, target, size, max_size=None): + # size can be min_size (scalar) or (w, h) tuple + + def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (h, w) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (oh, ow) + + def get_size(image_size, size, max_size=None): + if isinstance(size, (list, tuple)): + return size[::-1] + else: + return get_size_with_aspect_ratio(image_size, size, max_size) + + size = get_size(image.size, size, max_size) + rescaled_image = F.resize(image, size) + + if target is None: + return rescaled_image, None + + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratio_width, ratio_height = ratios + + target = target.copy() + if "boxes" in target: + boxes = target["boxes"] + scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) + target["boxes"] = scaled_boxes + + if "area" in target: + area = target["area"] + scaled_area = area * (ratio_width * ratio_height) + target["area"] = scaled_area + + h, w = size + target["size"] = torch.tensor([h, w]) + + if "masks" in target: + target['masks'] = interpolate( + target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 + + return rescaled_image, target + + +def pad(image, target, padding): + # assumes that we only pad on the bottom right corners + padded_image = F.pad(image, (0, 0, padding[0], padding[1])) + if target is None: + return padded_image, None + target = target.copy() + # should we do something wrt the original size? + target["size"] = torch.tensor(padded_image[::-1]) + if "masks" in target: + target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) + return padded_image, target + + +class MOTHSV: + def __init__(self, hgain=5, sgain=30, vgain=30) -> None: + self.hgain = hgain + self.sgain = sgain + self.vgain = vgain + + def __call__(self, imgs: list, targets: list): + hsv_augs = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] # random gains + hsv_augs *= np.random.randint(0, 2, 3) # random selection of h, s, v + hsv_augs = hsv_augs.astype(np.int16) + for i in range(len(imgs)): + img = np.array(imgs[i]) + img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.int16) + + img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180 + img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255) + img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255) + + imgs[i] = cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2RGB) # no return needed + return imgs, targets + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + region = T.RandomCrop.get_params(img, self.size) + return crop(img, target, region) + + +class MotRandomCrop(RandomCrop): + def __call__(self, imgs: list, targets: list): + ret_imgs = [] + ret_targets = [] + region = T.RandomCrop.get_params(imgs[0], self.size) + for img_i, targets_i in zip(imgs, targets): + img_i, targets_i = crop(img_i, targets_i, region) + ret_imgs.append(img_i) + ret_targets.append(targets_i) + return ret_imgs, ret_targets + +class FixedMotRandomCrop(object): + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, imgs: list, targets: list): + ret_imgs = [] + ret_targets = [] + w = random.randint(self.min_size, min(imgs[0].width, self.max_size)) + h = random.randint(self.min_size, min(imgs[0].height, self.max_size)) + region = T.RandomCrop.get_params(imgs[0], [h, w]) + for img_i, targets_i in zip(imgs, targets): + img_i, targets_i = crop_mot(img_i, targets_i, region) + ret_imgs.append(img_i) + ret_targets.append(targets_i) + return ret_imgs, ret_targets + +class MotRandomShift(object): + def __init__(self, bs=1): + self.bs = bs + + def __call__(self, imgs: list, targets: list): + ret_imgs = copy.deepcopy(imgs) + ret_targets = copy.deepcopy(targets) + + n_frames = len(imgs) + select_i = random.choice(list(range(n_frames))) + w, h = imgs[select_i].size + + xshift = (100 * torch.rand(self.bs)).int() + xshift *= (torch.randn(self.bs) > 0.0).int() * 2 - 1 + yshift = (100 * torch.rand(self.bs)).int() + yshift *= (torch.randn(self.bs) > 0.0).int() * 2 - 1 + ymin = max(0, -yshift[0]) + ymax = min(h, h - yshift[0]) + xmin = max(0, -xshift[0]) + xmax = min(w, w - xshift[0]) + + region = (int(ymin), int(xmin), int(ymax-ymin), int(xmax-xmin)) + ret_imgs[select_i], ret_targets[select_i] = random_shift(imgs[select_i], targets[select_i], region, (h,w)) + + return ret_imgs, ret_targets + + +class FixedMotRandomShift(object): + def __init__(self, bs=1, padding=50): + self.bs = bs + self.padding = padding + + def __call__(self, imgs: list, targets: list): + ret_imgs = [] + ret_targets = [] + + n_frames = self.bs + w, h = imgs[0].size + xshift = (self.padding * torch.rand(self.bs)).int() + 1 + xshift *= (torch.randn(self.bs) > 0.0).int() * 2 - 1 + yshift = (self.padding * torch.rand(self.bs)).int() + 1 + yshift *= (torch.randn(self.bs) > 0.0).int() * 2 - 1 + ret_imgs.append(imgs[0]) + ret_targets.append(targets[0]) + for i in range(1, n_frames): + ymin = max(0, -yshift[0]) + ymax = min(h, h - yshift[0]) + xmin = max(0, -xshift[0]) + xmax = min(w, w - xshift[0]) + prev_img = ret_imgs[i-1].copy() + prev_target = copy.deepcopy(ret_targets[i-1]) + region = (int(ymin), int(xmin), int(ymax - ymin), int(xmax - xmin)) + img_i, target_i = random_shift(prev_img, prev_target, region, (h, w)) + ret_imgs.append(img_i) + ret_targets.append(target_i) + + return ret_imgs, ret_targets + + +class RandomSizeCrop(object): + def __init__(self, min_size: int, max_size: int): + self.min_size = min_size + self.max_size = max_size + + def __call__(self, img: PIL.Image.Image, target: dict): + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) + + +class MotRandomSizeCrop(RandomSizeCrop): + def __call__(self, imgs, targets): + w = random.randint(self.min_size, min(imgs[0].width, self.max_size)) + h = random.randint(self.min_size, min(imgs[0].height, self.max_size)) + region = T.RandomCrop.get_params(imgs[0], [h, w]) + ret_imgs = [] + ret_targets = [] + for img_i, targets_i in zip(imgs, targets): + img_i, targets_i = crop(img_i, targets_i, region) + ret_imgs.append(img_i) + ret_targets.append(targets_i) + return ret_imgs, ret_targets + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, img, target): + image_width, image_height = img.size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) + + +class MotCenterCrop(CenterCrop): + def __call__(self, imgs, targets): + image_width, image_height = imgs[0].size + crop_height, crop_width = self.size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + ret_imgs = [] + ret_targets = [] + for img_i, targets_i in zip(imgs, targets): + img_i, targets_i = crop(img_i, targets_i, (crop_top, crop_left, crop_height, crop_width)) + ret_imgs.append(img_i) + ret_targets.append(targets_i) + return ret_imgs, ret_targets + + +class RandomHorizontalFlip(object): + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return hflip(img, target) + return img, target + + +class MotRandomHorizontalFlip(RandomHorizontalFlip): + def __call__(self, imgs, targets): + if random.random() < self.p: + ret_imgs = [] + ret_targets = [] + for img_i, targets_i in zip(imgs, targets): + img_i, targets_i = hflip(img_i, targets_i) + ret_imgs.append(img_i) + ret_targets.append(targets_i) + return ret_imgs, ret_targets + return imgs, targets + + +class RandomResize(object): + def __init__(self, sizes, max_size=None): + assert isinstance(sizes, (list, tuple)) + self.sizes = sizes + self.max_size = max_size + + def __call__(self, img, target=None): + size = random.choice(self.sizes) + return resize(img, target, size, self.max_size) + + +class MotRandomResize(RandomResize): + def __call__(self, imgs, targets): + size = random.choice(self.sizes) + ret_imgs = [] + ret_targets = [] + for img_i, targets_i in zip(imgs, targets): + img_i, targets_i = resize(img_i, targets_i, size, self.max_size) + ret_imgs.append(img_i) + ret_targets.append(targets_i) + return ret_imgs, ret_targets + + +class RandomPad(object): + def __init__(self, max_pad): + self.max_pad = max_pad + + def __call__(self, img, target): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + return pad(img, target, (pad_x, pad_y)) + + +class MotRandomPad(RandomPad): + def __call__(self, imgs, targets): + pad_x = random.randint(0, self.max_pad) + pad_y = random.randint(0, self.max_pad) + ret_imgs = [] + ret_targets = [] + for img_i, targets_i in zip(imgs, targets): + img_i, target_i = pad(img_i, targets_i, (pad_x, pad_y)) + ret_imgs.append(img_i) + ret_targets.append(targets_i) + return ret_imgs, ret_targets + + +class RandomSelect(object): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + def __init__(self, transforms1, transforms2, p=0.5): + self.transforms1 = transforms1 + self.transforms2 = transforms2 + self.p = p + + def __call__(self, img, target): + if random.random() < self.p: + return self.transforms1(img, target) + return self.transforms2(img, target) + + +class MotRandomSelect(RandomSelect): + """ + Randomly selects between transforms1 and transforms2, + with probability p for transforms1 and (1 - p) for transforms2 + """ + def __call__(self, imgs, targets): + if random.random() < self.p: + return self.transforms1(imgs, targets) + return self.transforms2(imgs, targets) + + +class ToTensor(object): + def __call__(self, img, target): + return F.to_tensor(img), target + + +class MotToTensor(ToTensor): + def __call__(self, imgs, targets): + ret_imgs = [] + for img in imgs: + ret_imgs.append(F.to_tensor(img)) + return ret_imgs, targets + + +class RandomErasing(object): + + def __init__(self, *args, **kwargs): + self.eraser = T.RandomErasing(*args, **kwargs) + + def __call__(self, img, target): + return self.eraser(img), target + + +class MotRandomErasing(RandomErasing): + def __call__(self, imgs, targets): + # TODO: Rewrite this part to ensure the data augmentation is same to each image. + ret_imgs = [] + for img_i, targets_i in zip(imgs, targets): + ret_imgs.append(self.eraser(img_i)) + return ret_imgs, targets + + +class MoTColorJitter(T.ColorJitter): + def __call__(self, imgs, targets): + transform = self.get_params(self.brightness, self.contrast, + self.saturation, self.hue) + ret_imgs = [] + for img_i, targets_i in zip(imgs, targets): + ret_imgs.append(transform(img_i)) + return ret_imgs, targets + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target=None): + if target is not None: + target['ori_img'] = image.clone() + image = F.normalize(image, mean=self.mean, std=self.std) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + if "boxes" in target: + boxes = target["boxes"] + boxes = box_xywh_to_cxcywh(boxes) + boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) + target["boxes"] = boxes + return image, target + + +class MotNormalize(Normalize): + def __call__(self, imgs, targets=None): + ret_imgs = [] + ret_targets = [] + for i in range(len(imgs)): + img_i = imgs[i] + targets_i = targets[i] if targets is not None else None + img_i, targets_i = super().__call__(img_i, targets_i) + ret_imgs.append(img_i) + ret_targets.append(targets_i) + return ret_imgs, ret_targets + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" + return format_string + + +class MotCompose(Compose): + def __call__(self, imgs, targets): + for t in self.transforms: + imgs, targets = t(imgs, targets) + return imgs, targets diff --git a/projects/co_mot/evaluation/__init__.py b/projects/co_mot/evaluation/__init__.py new file mode 100644 index 00000000..6c4299aa --- /dev/null +++ b/projects/co_mot/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .dancetrack_evaluation import DancetrackEvaluator + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/projects/co_mot/evaluation/dancetrack_evaluation.py b/projects/co_mot/evaluation/dancetrack_evaluation.py new file mode 100644 index 00000000..06f8e71d --- /dev/null +++ b/projects/co_mot/evaluation/dancetrack_evaluation.py @@ -0,0 +1,270 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import contextlib +import copy +import io +import itertools +import json +import logging +import numpy as np +import os +import pickle +from collections import OrderedDict +import pycocotools.mask as mask_util +import torch +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval +from tabulate import tabulate + +import detectron2.utils.comm as comm +from detectron2.config import CfgNode +from detectron2.data import MetadataCatalog +from detectron2.data.datasets.coco import convert_to_coco_json +from detectron2.structures import Boxes, BoxMode, pairwise_iou +from detectron2.utils.file_io import PathManager +from detectron2.utils.logger import create_small_table + +from detectron2.evaluation.evaluator import DatasetEvaluator + +try: + from detectron2.evaluation.fast_eval_api import COCOeval_opt +except ImportError: + COCOeval_opt = COCOeval + +try: + from multiprocessing import freeze_support + import trackeval +except ImportError: + # trackeval = None + raise ImportError( + 'Please run ' + 'pip install git+https://github.com/JonathonLuiten/TrackEval.git' # noqa + ' to manually install trackeval') +# from thirdparty.TrackEval.scripts import run_mot_challenge + + +class DancetrackEvaluator(DatasetEvaluator): + """ + Evaluate AR for object proposals, AP for instance detection/segmentation, AP + for keypoint detection outputs using COCO's metrics. + See http://cocodataset.org/#detection-eval and + http://cocodataset.org/#keypoints-eval to understand its metrics. + The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means + the metric cannot be computed (e.g. due to no predictions made). + + In addition to COCO, this evaluator is able to support any bounding box detection, + instance segmentation, or keypoint detection dataset. + """ + + def __init__( + self, + dataset_name, + tasks=None, + distributed=True, + output_dir=None, + *, + max_dets_per_image=None, + use_fast_impl=True, + kpt_oks_sigmas=(), + allow_cached_coco=True, + ): + """ + Args: + dataset_name (str): name of the dataset to be evaluated. + It must have either the following corresponding metadata: + + "json_file": the path to the COCO format annotation + + Or it must be in detectron2's standard dataset format + so it can be converted to COCO format automatically. + tasks (tuple[str]): tasks that can be evaluated under the given + configuration. A task is one of "bbox", "segm", "keypoints". + By default, will infer this automatically from predictions. + distributed (True): if True, will collect results from all ranks and run evaluation + in the main process. + Otherwise, will only evaluate the results in the current process. + output_dir (str): optional, an output directory to dump all + results predicted on the dataset. The dump contains two files: + + 1. "instances_predictions.pth" a file that can be loaded with `torch.load` and + contains all the results in the format they are produced by the model. + 2. "coco_instances_results.json" a json file in COCO's result format. + max_dets_per_image (int): limit on the maximum number of detections per image. + By default in COCO, this limit is to 100, but this can be customized + to be greater, as is needed in evaluation metrics AP fixed and AP pool + (see https://arxiv.org/pdf/2102.01066.pdf) + This doesn't affect keypoint evaluation. + use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP. + Although the results should be very close to the official implementation in COCO + API, it is still recommended to compute results with the official API for use in + papers. The faster implementation also uses more RAM. + kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS. + See http://cocodataset.org/#keypoints-eval + When empty, it will use the defaults in COCO. + Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS. + allow_cached_coco (bool): Whether to use cached coco json from previous validation + runs. You should set this to False if you need to use different validation data. + Defaults to True. + """ + self._logger = logging.getLogger(__name__) + self._distributed = distributed + self._output_dir = output_dir + + if use_fast_impl and (COCOeval_opt is COCOeval): + self._logger.info("Fast COCO eval is not built. Falling back to official COCO eval.") + use_fast_impl = False + self._use_fast_impl = use_fast_impl + + # COCOeval requires the limit on the number of detections per image (maxDets) to be a list + # with at least 3 elements. The default maxDets in COCOeval is [1, 10, 100], in which the + # 3rd element (100) is used as the limit on the number of detections per image when + # evaluating AP. COCOEvaluator expects an integer for max_dets_per_image, so for COCOeval, + # we reformat max_dets_per_image into [1, 10, max_dets_per_image], based on the defaults. + if max_dets_per_image is None: + max_dets_per_image = [1, 10, 100] + else: + max_dets_per_image = [1, 10, max_dets_per_image] + self._max_dets_per_image = max_dets_per_image + + if tasks is not None and isinstance(tasks, CfgNode): + kpt_oks_sigmas = ( + tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas + ) + self._logger.warn( + "COCO Evaluator instantiated using config, this is deprecated behavior." + " Please pass in explicit arguments instead." + ) + self._tasks = None # Infering it from predictions should be better + else: + self._tasks = tasks + + self._cpu_device = torch.device("cpu") + + self._metadata = MetadataCatalog.get(dataset_name) + if not hasattr(self._metadata, "json_file"): + if output_dir is None: + raise ValueError( + "output_dir must be provided to COCOEvaluator " + "for datasets not in COCO format." + ) + self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...") + + cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json") + self._metadata.json_file = cache_path + convert_to_coco_json(dataset_name, cache_path, allow_cached=allow_cached_coco) + + json_file = PathManager.get_local_path(self._metadata.json_file) + with contextlib.redirect_stdout(io.StringIO()): + self._coco_api = COCO(json_file) + + # Test set json files do not contain annotations (evaluation must be + # performed using the COCO evaluation server). + self._do_evaluation = "annotations" in self._coco_api.dataset + if self._do_evaluation: + self._kpt_oks_sigmas = kpt_oks_sigmas + + self.image_ids = [] + + def reset(self): + self._predictions = [] + + def process(self, inputs, outputs): + """ + Args: + inputs: the inputs to a COCO model (e.g., GeneralizedRCNN). + It is a list of dict. Each dict corresponds to an image and + contains keys like "height", "width", "file_name", "image_id". + outputs: the outputs of a COCO model. It is a list of dicts with key + "instances" that contains :class:`Instances`. + """ + + for input, output in zip(inputs, outputs): + lines = [] + total_dts, total_occlusion_dts = 0, 0 + dataset_dict = input['dataset_dict'] + seq_num = os.path.basename(dataset_dict[0]['video_name']) + for i_frame, o_frame in zip(dataset_dict, output): + frame_ith = int(i_frame['frame_id']) + + dt_instances = o_frame + bbox_xyxy = dt_instances.boxes.tolist() + identities = dt_instances.obj_idxes.tolist() + labels = dt_instances.labels.tolist() + + total_dts += len(dt_instances) + + save_format = '{frame},{id},{x1:.2f},{y1:.2f},{w:.2f},{h:.2f},1,-1,-1,-1\n' + for xyxy, track_id in zip(bbox_xyxy, identities): + if track_id < 0 or track_id is None: + continue + x1, y1, x2, y2 = xyxy + w, h = x2 - x1, y2 - y1 + lines.append(save_format.format(frame=frame_ith, id=track_id, x1=x1, y1=y1, w=w, h=h)) + + os.makedirs(os.path.join(self._output_dir, 'results'), exist_ok=True) + with open(os.path.join(self._output_dir, 'results', f'{seq_num}.txt'), 'w') as f: + f.writelines(lines) + print("{}: totally {} dts {} occlusion dts".format(seq_num, total_dts, total_occlusion_dts)) + + def evaluate(self, img_ids=None): + """ + Args: + img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset + """ + if self._distributed: + comm.synchronize() + if not comm.is_main_process(): + return {} + + res_eval = _run_mot_challenge(SPLIT_TO_EVAL="val", + METRICS=['HOTA', 'CLEAR', 'Identity'], + GT_FOLDER=self._metadata.image_root, + SEQMAP_FILE=self._metadata.seqmap_txt, + SKIP_SPLIT_FOL=True, + TRACKERS_TO_EVAL=[''], + TRACKER_SUB_FOLDER='', + USE_PARALLEL=True, + NUM_PARALLEL_CORES=8, + PLOT_CURVES=False, + TRACKERS_FOLDER=os.path.join(self._output_dir, 'results') + ) + self._results = OrderedDict() + self._results = res_eval + return res_eval[0]['MotChallenge2DBox']['']['COMBINED_SEQ']['pedestrian'] + + +def _run_mot_challenge(**argc_dict): + freeze_support() + + # Command line interface: + default_eval_config = trackeval.Evaluator.get_default_eval_config() + default_eval_config['DISPLAY_LESS_PROGRESS'] = False + default_dataset_config = trackeval.datasets.MotChallenge2DBox.get_default_dataset_config() + default_metrics_config = {'METRICS': ['HOTA', 'CLEAR', 'Identity'], 'THRESHOLD': 0.5} + config = {**default_eval_config, **default_dataset_config, **default_metrics_config} # Merge default configs + for setting in config.keys(): + if setting in argc_dict: + if type(config[setting]) == type(True): + x = argc_dict[setting] + elif type(config[setting]) == type(1): + x = int(argc_dict[setting]) + elif type(argc_dict[setting]) == type(None): + x = None + elif setting == 'SEQ_INFO': + x = dict(zip(argc_dict[setting], [None]*len(argc_dict[setting]))) + else: + x = argc_dict[setting] + config[setting] = x + eval_config = {k: v for k, v in config.items() if k in default_eval_config.keys()} + dataset_config = {k: v for k, v in config.items() if k in default_dataset_config.keys()} + metrics_config = {k: v for k, v in config.items() if k in default_metrics_config.keys()} + + # Run code + evaluator = trackeval.Evaluator(eval_config) + dataset_list = [trackeval.datasets.MotChallenge2DBox(dataset_config)] + metrics_list = [] + for metric in [trackeval.metrics.HOTA, trackeval.metrics.CLEAR, trackeval.metrics.Identity, trackeval.metrics.VACE]: + if metric.get_name() in metrics_config['METRICS']: + metrics_list.append(metric(metrics_config)) + if len(metrics_list) == 0: + raise Exception('No metrics selected for evaluation') + return evaluator.evaluate(dataset_list, metrics_list) diff --git a/projects/co_mot/modeling/__init__.py b/projects/co_mot/modeling/__init__.py new file mode 100644 index 00000000..c1338f1b --- /dev/null +++ b/projects/co_mot/modeling/__init__.py @@ -0,0 +1,37 @@ +''' +Author: 颜峰 && bphengyan@163.com +Date: 2023-05-26 10:06:20 +LastEditors: 颜峰 && bphengyan@163.com +LastEditTime: 2023-05-30 16:03:02 +FilePath: /detrex/projects/co_mot/modeling/__init__.py +Description: + +Copyright (c) 2023 by ${git_name_email}, All Rights Reserved. +''' +# coding=utf-8 +# Copyright 2022 The IDEA Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .mot import MOT +from .mot import ClipMatcher as MOTClipMatcher +from .mot import TrackerPostProcess as MOTTrackerPostProcess +from .mot import RuntimeTrackerBase as MOTRuntimeTrackerBase + +from .mot_transformer import DeformableTransformer as MOTDeformableTransformer + +from .qim import QueryInteractionModuleGroup as MOTQueryInteractionModuleGroup + +from .matcher import HungarianMatcherGroup as MOTHungarianMatcherGroup + diff --git a/projects/co_mot/modeling/matcher.py b/projects/co_mot/modeling/matcher.py new file mode 100755 index 00000000..283a9a93 --- /dev/null +++ b/projects/co_mot/modeling/matcher.py @@ -0,0 +1,128 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-research. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + + +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import numpy as np +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from detrex.layers import box_cxcywh_to_xyxy, generalized_box_iou +from detectron2.structures import Instances + + +class HungarianMatcherGroup(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, + cost_class: float = 1, + cost_bbox: float = 1, + cost_giou: float = 1): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" + + def forward(self, outputs, targets, use_focal=True, g_size=1): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + with torch.no_grad(): + bs, num_queries = outputs["pred_logits"].shape[:2] + # We flatten to compute the cost matrices in a batch + if use_focal: + out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() + else: + out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + if isinstance(targets[0], Instances): + tgt_ids = torch.cat([gt_per_img.labels for gt_per_img in targets]) + tgt_bbox = torch.cat([gt_per_img.boxes for gt_per_img in targets]) + else: + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. + if use_focal: + alpha = 0.25 + gamma = 2.0 + neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) + pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) + cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] + else: + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -out_prob[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), + box_cxcywh_to_xyxy(tgt_bbox)) + + # Final cost matrix + C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + C2 = C.view(bs, num_queries//g_size, g_size, -1).cpu() + if C2.shape[-1] > 0: + # Cmin, _ = C2.min(dim=2) + Cmin, _ = C2.max(dim=2) + # Cmin = C2.mean(dim=2) + + if isinstance(targets[0], Instances): + sizes = [len(gt_per_img.boxes) for gt_per_img in targets] + else: + sizes = [len(v["boxes"]) for v in targets] + + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(Cmin.split(sizes, -1))] + + gindices = [] + for ind in indices: + Cindx = np.arange(num_queries).reshape(num_queries//g_size, g_size) + gindices.append((Cindx[ind[0]].reshape(-1), ind[1].repeat(g_size))) + else: + gindices = [(np.zeros(0, dtype=np.int32), np.zeros(0, dtype=np.int32))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in gindices] + diff --git a/projects/co_mot/modeling/mot.py b/projects/co_mot/modeling/mot.py new file mode 100644 index 00000000..7bc420b9 --- /dev/null +++ b/projects/co_mot/modeling/mot.py @@ -0,0 +1,868 @@ +''' +Author: 颜峰 && bphengyan@163.com +Date: 2023-05-24 13:53:39 +LastEditors: 颜峰 && bphengyan@163.com +LastEditTime: 2023-05-30 19:02:41 +FilePath: /detrex/projects/co_mot/modeling/mot.py +Description: + +Copyright (c) 2023 by ${git_name_email}, All Rights Reserved. +''' +# coding=utf-8 + +import math +import numpy as np +from typing import List, Optional +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import copy +from copy import deepcopy +from collections import defaultdict + +from detrex.layers import MLP, box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, generalized_box_iou +from detrex.utils import inverse_sigmoid +from detrex.modeling import SetCriterion +from detrex.modeling.criterion.criterion import sigmoid_focal_loss + +from detectron2.structures import Boxes, ImageList, Instances +from detectron2.structures.boxes import matched_pairwise_iou +from detrex.utils import get_world_size, is_dist_avail_and_initialized + +from projects.co_mot.util import checkpoint +from projects.co_mot.util.misc import (NestedTensor, nested_tensor_from_tensor_list, accuracy) + + +class MOT(nn.Module): + """ Implement CO-MOT: Bridging the Gap Between End-to-end and Non-End-to-end Multi-Object Tracking + Args: + backbone (nn.Module): backbone module + position_embedding (nn.Module): position embedding module + neck (nn.Module): neck module to handle the intermediate outputs features + transformer (nn.Module): transformer module + embed_dim (int): dimension of embedding + num_classes (int): Number of total categories. + num_queries (int): Number of proposal dynamic anchor boxes in Transformer + criterion (nn.Module): Criterion for calculating the total losses. + pixel_mean (List[float]): Pixel mean value for image normalization. + Default: [123.675, 116.280, 103.530]. + pixel_std (List[float]): Pixel std value for image normalization. + Default: [58.395, 57.120, 57.375]. + aux_loss (bool): Whether to calculate auxiliary loss in criterion. Default: True. + select_box_nums_for_evaluation (int): the number of topk candidates + slected at postprocess for evaluation. Default: 300. + device (str): Training device. Default: "cuda". + """ + + def __init__( + self, + backbone: nn.Module, + position_embedding: nn.Module, + neck: nn.Module, + transformer: nn.Module, + embed_dim: int, + num_classes: int, + num_queries: int, + criterion: nn.Module, + track_embed: nn.Module, + track_base: nn.Module, + post_process: nn.Module, + aux_loss: bool = True, + device="cuda", + g_size=1, + ): + super().__init__() + # define backbone and position embedding module + self.backbone = backbone + self.position_embedding = position_embedding + + # define neck module + self.neck = neck + + # number of dynamic anchor boxes and embedding dimension + self.num_queries = num_queries + self.embed_dim = embed_dim + + # define transformer module + self.transformer = transformer + + # define classification head and box head + self.class_embed = nn.Linear(embed_dim, num_classes) + self.bbox_embed = MLP(embed_dim, embed_dim, 4, 3) + self.num_classes = num_classes + + # where to calculate auxiliary loss in criterion + self.aux_loss = aux_loss + self.criterion = criterion + + self.device = device + + # initialize weights + prior_prob = 0.01 + bias_value = -math.log((1 - prior_prob) / prior_prob) + self.class_embed.bias.data = torch.ones(num_classes) * bias_value + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + for _, neck_layer in self.neck.named_modules(): + if isinstance(neck_layer, nn.Conv2d): + nn.init.xavier_uniform_(neck_layer.weight, gain=1) + nn.init.constant_(neck_layer.bias, 0) + + # if two-stage, the last class_embed and bbox_embed is for region proposal generation + num_pred = transformer.decoder.num_layers + self.class_embed = nn.ModuleList([deepcopy(self.class_embed) for i in range(num_pred)]) + self.bbox_embed = nn.ModuleList([deepcopy(self.bbox_embed) for i in range(num_pred)]) + nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) + + # two-stage + self.transformer.decoder.class_embed = self.class_embed + self.transformer.decoder.bbox_embed = self.bbox_embed + + # hack implementation for two-stage + for bbox_embed_layer in self.bbox_embed: + nn.init.constant_(bbox_embed_layer.layers[-1].bias.data[2:], 0.0) + + + # for Track + self.track_embed = track_embed + self.post_process = post_process # TrackerPostProcess(g_size=g_size) + self.track_base = track_base + + # for shadow + self.g_size = g_size + + # for init of query + self.position = nn.Embedding(num_queries, 4) + self.position_offset = nn.Embedding(num_queries*g_size, 4) + self.query_embed = nn.Embedding(num_queries, embed_dim) + self.query_embed_offset = nn.Embedding(num_queries*g_size, embed_dim) + + nn.init.uniform_(self.position.weight.data, 0, 1) + nn.init.normal_(self.position_offset.weight.data, 0, 10e-6) # 默认为10e-6 + nn.init.normal_(self.query_embed_offset.weight.data, 0, 10e-6) # 默认为10e-6 + + def _generate_empty_tracks(self, g_size=1, batch_size=1): + track_instances = Instances((1, 1)) + num_queries, d_model = self.query_embed.weight.shape # (300, 512) + device = self.query_embed.weight.device + track_instances.ref_pts = self.position.weight.view(-1, 1, 4).repeat(1, g_size, 1).view(-1, 4) + self.position_offset.weight + track_instances.query_pos = self.query_embed.weight.view(-1, 1, d_model).repeat(1, g_size, 1).view(-1, d_model) + self.query_embed_offset.weight + track_instances.ref_pts = track_instances.ref_pts.view(-1, 1, 4).repeat(1, batch_size, 1) + track_instances.query_pos = track_instances.query_pos.view(-1, 1, d_model).repeat(1, batch_size, 1) + + track_instances.output_embedding = torch.zeros((len(track_instances), batch_size, d_model), device=device) # motr decode输出的feature,把这个输入qim中可以获得track的query,某个目标跟踪过程中不再使用query_pos + track_instances.obj_idxes = torch.full((len(track_instances), batch_size), -1, dtype=torch.long, device=device) # ID + track_instances.matched_gt_idxes = torch.full((len(track_instances), batch_size), -1, dtype=torch.long, device=device) # 与匹配到的gt在该图片中的索引 + track_instances.disappear_time = torch.zeros((len(track_instances), batch_size), dtype=torch.long, device=device) # 消失时间,假如目标跟踪多久后删除该目标 + track_instances.iou = torch.zeros((len(track_instances), batch_size), dtype=torch.float, device=device) # 与对应GT的IOU + track_instances.scores = torch.zeros((len(track_instances), batch_size), dtype=torch.float, device=device) # 实际是当前帧检测输出的置信度 + track_instances.track_scores = torch.zeros((len(track_instances), batch_size), dtype=torch.float, device=device) + track_instances.pred_boxes = torch.zeros((len(track_instances), batch_size, 4), dtype=torch.float, device=device) # 检测或跟踪query输出的box框 + track_instances.pred_logits = torch.zeros((len(track_instances), batch_size, self.num_classes), dtype=torch.float, device=device) # 检测或跟踪query输出的置信度argsigmod + track_instances.group_ids = torch.arange(g_size, dtype=torch.long, device=device).repeat(num_queries).view(-1, 1).repeat(1, batch_size) + track_instances.labels = torch.full((len(track_instances), batch_size), -1, dtype=torch.long, device=device) + + return track_instances.to(self.query_embed.weight.device) + + def clear(self): + self.track_base.clear() + + def _forward_single_image(self, samples, track_instances, gtboxes=None): + """Forward function of `MOT`. + """ + + # original features + features = self.backbone(samples.tensors) # output feature dict + img_masks = samples.mask + + # project backbone features to the reuired dimension of transformer + # we use multi-scale features in DINO + multi_level_feats = self.neck(features) + multi_level_masks = [] + multi_level_position_embeddings = [] + for feat in multi_level_feats: + multi_level_masks.append( + F.interpolate(img_masks[None].float(), size=feat.shape[-2:]).to(torch.bool).squeeze(0) + ) + multi_level_position_embeddings.append(self.position_embedding(multi_level_masks[-1])) + + # prepare label query embedding + input_query_label = track_instances.query_pos + input_query_bbox = track_instances.ref_pts + attn_mask = None + + # feed into transformer 包括encode + decode + ( + inter_states, + init_reference, + inter_references, + enc_state, + enc_reference, # [0..1] + ) = self.transformer( + multi_level_feats, + multi_level_masks, + multi_level_position_embeddings, + input_query_label, + ref_pts=input_query_bbox, + attn_mask=attn_mask, + ) + + # Calculate output coordinates and classes. + outputs_classes = [] + outputs_coords = [] + for lvl in range(inter_states.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.class_embed[lvl](inter_states[lvl]) + tmp = self.bbox_embed[lvl](inter_states[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + outputs_class = torch.stack(outputs_classes) + # tensor shape: [num_decoder_layers, bs, num_query, num_classes] + outputs_coord = torch.stack(outputs_coords) + # tensor shape: [num_decoder_layers, bs, num_query, 4] + + # prepare for loss computation + output = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} + if self.aux_loss: + output["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) + output['hs'] = inter_states[-1] + return output + + def _post_process_single_image(self, frame_res, track_instances, is_last): + with torch.no_grad(): + if self.training: + track_scores = frame_res['pred_logits'].sigmoid().max(dim=-1).values + else: + track_scores = frame_res['pred_logits'].sigmoid().max(dim=-1).values + + track_instances.scores = track_scores.transpose(0, 1) + track_instances.pred_logits = frame_res['pred_logits'].transpose(0, 1) + track_instances.pred_boxes = frame_res['pred_boxes'].transpose(0, 1) + track_instances.output_embedding = frame_res['hs'].transpose(0, 1) + + if self.training: + # the track id will be assigned by the mather. + frame_res['track_instances'] = track_instances + track_instances = self.criterion.match_for_single_frame(frame_res) # 找匹配(跟踪query+检测query)分配GT的ID+ 算loss + else: + # each track will be assigned an unique global id by the track base. + self.track_base.update(track_instances, g_size=self.g_size) # 为存在的目标分配ID,并删除长时间消失的目标ID + + tmp = {} + tmp['track_instances'] = track_instances + if not is_last: # 经过这步后将仅保留有ID的目标,且更新了track的query和pos + out_track_instances = self.track_embed(tmp, g_size=self.g_size) # 更新跟踪的query,用于下一帧(检测query为学习的,跟踪query则为上一帧输出经过qim变换的特征) + frame_res['track_instances'] = out_track_instances + else: + frame_res['track_instances'] = None + + # print('post:', t1-t0, t2-t1) + return frame_res + + # 获取当前帧的跟踪框 + @torch.no_grad() + def inference_single_image(self, img, ori_img_size, track_instances=None): + if not isinstance(img, NestedTensor): + img = nested_tensor_from_tensor_list(img) # 补pad并或者pad的mask + if track_instances is None: + track_instances = self._generate_empty_tracks(g_size=self.g_size) # 初始化decode的输入或者说目标query,包括(query+pose) + else: + track_instances = Instances.cat([self._generate_empty_tracks(g_size=self.g_size), track_instances]) + + res = self._forward_single_image(img, track_instances=track_instances) # backbone+encode+decode,获得decode的输出和中间迭代过程输出的box和最后输出的feat(hs,作为经过QIM后可作为下一帧的query) + res = self._post_process_single_image(res, track_instances, False) # train是算loss,test时过滤有效跟踪框+为下一帧更新query/pos + + track_instances = res['track_instances'] + track_instances = self.post_process(track_instances, ori_img_size) # 把box框换算到图像大小 + ret = {'track_instances': track_instances} + if 'ref_pts' in res: # 把参考点也换算到图像大小 + ref_pts = res['ref_pts'] + img_h, img_w = ori_img_size + scale_fct = torch.Tensor([img_w, img_h]).to(ref_pts) + ref_pts = ref_pts * scale_fct[None] + ret['ref_pts'] = ref_pts + return ret + + def forward(self, data: dict): + # 准备 + def fn(frame, gtboxes, track_instances): + frame = nested_tensor_from_tensor_list(frame) + frame_res = self._forward_single_image(frame, track_instances, gtboxes) + return frame_res + + track_instances = None + if self.training: + self.criterion.initialize_for_single_clip(data['gt_instances']) + frames = data['imgs'] # list of Tensor. + outputs = { + 'pred_logits': [], + 'pred_boxes': [], + } + for frame_index, (frame, gt) in enumerate(zip(frames, data['gt_instances'])): + for f in frame: + f.requires_grad = False + is_last = frame_index == len(frames) - 1 + nbatch = len(frame) + gtboxes = None + + if track_instances is None: + track_instances = self._generate_empty_tracks(g_size=self.g_size, batch_size=nbatch) + else: + track_instances = Instances.cat([ + self._generate_empty_tracks(g_size=self.g_size, batch_size=nbatch), + track_instances]) + + if frame_index < len(frames) - 1: + args = [frame, gtboxes, track_instances] + params = tuple((p for p in self.parameters() if p.requires_grad)) + frame_res = checkpoint.CheckpointFunction.apply(fn, len(args), *args, *params) + else: + frame = nested_tensor_from_tensor_list(frame) + frame_res = self._forward_single_image(frame, track_instances, gtboxes) + + frame_res = self._post_process_single_image(frame_res, track_instances, is_last) + + track_instances = frame_res['track_instances'] + outputs['pred_logits'].append(frame_res['pred_logits']) + outputs['pred_boxes'].append(frame_res['pred_boxes']) + + # compute loss + outputs['losses_dict'] = self.criterion.losses_dict + loss_dict = self.criterion(outputs, data) + weight_dict = self.criterion.weight_dict + for k in loss_dict.keys(): + if k in weight_dict: + loss_dict[k] *= weight_dict[k] + return loss_dict + else: + assert len(data) == 1 + device = self.device + outputs = [] + for i, data_ in enumerate(data[0]['data_loader']): # tqdm(loader)): + cur_img, ori_img, proposals, f_path = [d[0] for d in data_] + cur_img = cur_img.to(device) + if track_instances is not None: + track_instances.remove('boxes') + # track_instances.remove('labels') + seq_h, seq_w, _ = ori_img.shape + + # 内部包含backboe+encode+decode+跟踪匹配关系+跟踪目标过滤(从query中过滤) + try: + res = self.inference_single_image(cur_img, (seq_h, seq_w), track_instances) + except: + res = self.inference_single_image(cur_img, (seq_h, seq_w), track_instances) + track_instances = res['track_instances'] + + predictions = deepcopy(res) + if len(predictions['track_instances']): + scores = predictions['track_instances'].scores.reshape(-1, self.g_size) + keep_idxs = torch.arange(len(predictions['track_instances']), device=scores.device).reshape(-1, self.g_size) + keep_idxs = keep_idxs.gather(1, scores.max(-1)[1].reshape(-1, 1)).reshape(-1) + predictions['track_instances'] = predictions['track_instances'][keep_idxs] + + predictions = _filter_predictions_with_confidence(predictions, 0.5) + predictions = _filter_predictions_with_area(predictions) + outputs.append(predictions['track_instances'].to('cpu')) + + return [outputs] + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [ + {"pred_logits": a, "pred_boxes": b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + ] + +def _filter_predictions_with_area(predictions, area_threshold=100): + if "track_instances" in predictions: + preds = predictions["track_instances"] + wh = preds.boxes[:, 2:4] - preds.boxes[:, 0:2] + areas = wh[:, 0] * wh[:, 1] + keep_idxs = areas > area_threshold + predictions = copy(predictions) # don't modify the original + predictions["track_instances"] = preds[keep_idxs] + return predictions + +def _filter_predictions_with_confidence(predictions, confidence_threshold=0.5): + if "track_instances" in predictions: + preds = predictions["track_instances"] + keep_idxs = preds.scores > confidence_threshold + predictions = copy(predictions) # don't modify the original + predictions["track_instances"] = preds[keep_idxs] + return predictions + + +class ClipMatcher(SetCriterion): + """This class computes the loss for Conditional DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__( + self, + num_classes, + matcher, + weight_dict, + losses: List[str] = ["class", "boxes"], + eos_coef: float = 0.1, + loss_class_type: str = "focal_loss", + alpha: float = 0.25, + gamma: float = 2.0, + g_size=1 + ): + """Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + focal_alpha: alpha in Focal Loss + """ + super().__init__(num_classes, matcher, weight_dict, losses, eos_coef, loss_class_type, alpha, gamma) + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.focal_loss = True + self.losses_dict = defaultdict(float) + self._current_frame_idx = 0 + self.g_size = g_size + + def initialize_for_single_clip(self, gt_instances: List[Instances]): # 训练过程中每个视频段之前调用,传入GT值 + self.gt_instances = gt_instances + self.num_samples = 0 + self.sample_device = None + self._current_frame_idx = 0 + self.losses_dict = defaultdict(float) + + def _step(self): + self._current_frame_idx += 1 + + def calc_loss_for_track_scores(self, track_instances: Instances): + gt_instances_i = self.gt_instances[self._current_frame_idx] + outputs = { + 'pred_logits': track_instances.track_scores[None], + } + device = track_instances.track_scores.device + + num_tracks = len(track_instances) + src_idx = torch.arange(num_tracks, dtype=torch.long, device=device) + tgt_idx = track_instances.matched_gt_idxes # -1 for FP tracks and disappeared tracks + + track_losses = self.get_loss('labels', + outputs=outputs, + gt_instances=[gt_instances_i], + indices=[(src_idx, tgt_idx)], + num_boxes=1) + self.losses_dict.update( + {'frame_{}_track_{}'.format(self._current_frame_idx, key): value for key, value in + track_losses.items()}) + + def get_num_boxes(self, num_samples): + num_boxes = torch.as_tensor(num_samples, dtype=torch.float, device=self.sample_device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + return num_boxes + + def loss_labels(self, outputs, gt_instances: List[Instances], indices, num_boxes, log=False): + """Classification loss (Binary focal loss) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert "pred_logits" in outputs + src_logits = outputs["pred_logits"] + + idx = self._get_src_permutation_idx(indices) + target_classes = torch.full( + src_logits.shape[:2], + self.num_classes, + dtype=torch.int64, + device=src_logits.device, + ) + # The matched gt for disappear track query is set -1. + labels = [] + for gt_per_img, (_, J) in zip(gt_instances, indices): + labels_per_img = torch.ones_like(J) * self.num_classes + # set labels of track-appear slots to 0. + if len(gt_per_img) > 0: + labels_per_img[J != -1] = gt_per_img.labels[J[J != -1]] + labels.append(labels_per_img) + target_classes_o = torch.cat(labels) + target_classes[idx] = target_classes_o + + # Computation classification loss + if self.loss_class_type == "ce_loss": + # loss_class = F.cross_entropy( + # src_logits.transpose(1, 2), target_classes, self.empty_weight + # ) + loss_class = F.cross_entropy( + src_logits.transpose(1, 2), target_classes, self.empty_weight + ) + elif self.loss_class_type == "focal_loss": + # src_logits: (b, num_queries, num_classes) = (2, 300, 80) + # target_classes_one_hot = (2, 300, 80) + gt_labels_target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[:, :, :-1] # no loss for the last (background) class + gt_labels_target = gt_labels_target.to(src_logits) + loss_class = sigmoid_focal_loss( + src_logits.flatten(1), + gt_labels_target.flatten(1), + num_boxes=num_boxes, + alpha=self.alpha, + gamma=self.gamma + ) + + losses = {"loss_ce": loss_class} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + + return losses + + + def loss_boxes(self, outputs, gt_instances: List[Instances], indices: List[tuple], num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert "pred_boxes" in outputs + # We ignore the regression loss of the track-disappear slots. + #TODO: Make this filter process more elegant. + filtered_idx = [] + for src_per_img, tgt_per_img in indices: + keep = tgt_per_img != -1 + filtered_idx.append((src_per_img[keep], tgt_per_img[keep])) + indices = filtered_idx + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs["pred_boxes"][idx] + target_boxes = torch.cat([gt_per_img.boxes[i] for gt_per_img, (_, i) in zip(gt_instances, indices)], dim=0) + + # for pad target, don't calculate regression loss, judged by whether obj_id=-1 + target_obj_ids = torch.cat([gt_per_img.obj_ids[i] for gt_per_img, (_, i) in zip(gt_instances, indices)], dim=0) # size(16) + mask = (target_obj_ids != -1) + + loss_bbox = F.l1_loss(src_boxes[mask], target_boxes[mask], reduction="none") + + losses = {} + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + loss_giou = 1 - torch.diag( + generalized_box_iou( + box_cxcywh_to_xyxy(src_boxes[mask]), + box_cxcywh_to_xyxy(target_boxes[mask]) + ) + ) + losses["loss_giou"] = loss_giou.sum() / num_boxes + + return losses + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + "boxes": self.loss_boxes, + } + assert loss in loss_map, f"do you really want to compute {loss} loss?" + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def match_for_single_frame(self, outputs: dict): + """This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + + return_indices: used for vis. if True, the layer0-5 indices will be returned as well. + + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} + + gt_instances_i = self.gt_instances[self._current_frame_idx] # gt instances of i-th image. + track_instances: Instances = outputs_without_aux['track_instances'] + pred_logits_i = track_instances.pred_logits # predicted logits of i-th image. + pred_boxes_i = track_instances.pred_boxes # predicted boxes of i-th image. + + if not (track_instances.obj_idxes !=-1).any(): # 没有跟踪 + outputs_i = { + 'pred_logits': pred_logits_i.transpose(0,1), + 'pred_boxes': pred_boxes_i.transpose(0,1), + } + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_i, gt_instances_i, g_size=self.g_size) + indices = [(ind[0].to(pred_logits_i.device), ind[1].to(pred_logits_i.device)) for ind in indices] + + + track_instances.matched_gt_idxes[...] = -1 + for i, ind in enumerate(indices): + track_instances.matched_gt_idxes[ind[0], i] = ind[1] + track_instances.obj_idxes[ind[0], i] = gt_instances_i[i].obj_ids[ind[1]].long() + + active_idxes = (track_instances.obj_idxes[:, i] >= 0) & (track_instances.matched_gt_idxes[:, i] >= 0) # 当前帧能够匹配到的目标 + active_track_boxes = track_instances.pred_boxes[active_idxes, i] + if len(active_track_boxes) > 0: + gt_boxes = gt_instances_i[i].boxes[track_instances.matched_gt_idxes[active_idxes, i]] + active_track_boxes = box_cxcywh_to_xyxy(active_track_boxes) + gt_boxes = box_cxcywh_to_xyxy(gt_boxes) + track_instances.iou[active_idxes, i] = matched_pairwise_iou(Boxes(active_track_boxes), Boxes(gt_boxes)) + + self.num_samples += sum(len(t.boxes) for t in gt_instances_i)*self.g_size + self.sample_device = pred_logits_i.device + for loss in self.losses: + new_track_loss = self.get_loss(loss, + outputs=outputs_i, + gt_instances=gt_instances_i, + indices=indices, + num_boxes=1) + self.losses_dict.update( + {'frame_{}_{}'.format(self._current_frame_idx, key): value for key, value in new_track_loss.items()}) + + if 'aux_outputs' in outputs: # 此处匹配时对新生儿时重新算对应关系的,不直接使用最后一层box输出的对应关系 + for i, aux_outputs in enumerate(outputs['aux_outputs']): + unmatched_outputs_layer = { + 'pred_logits': aux_outputs['pred_logits'], + 'pred_boxes': aux_outputs['pred_boxes'], + } + + matched_indices_layer = self.matcher(unmatched_outputs_layer, gt_instances_i, g_size=self.g_size) + matched_indices_layer = [(ind[0].to(pred_logits_i.device), ind[1].to(pred_logits_i.device)) for ind in matched_indices_layer] + + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + l_dict = self.get_loss(loss, + aux_outputs, + gt_instances=gt_instances_i, + indices=matched_indices_layer, + num_boxes=1, ) + self.losses_dict.update( + {'frame_{}_aux{}_{}'.format(self._current_frame_idx, i, key): value for key, value in + l_dict.items()}) + else: + track_instances.matched_gt_idxes[...] = -1 + def match_for_single_decoder_layer(unmatched_outputs, matcher, untracked_gt_instances, unmatched_track_idxes, untracked_tgt_indexes): + new_track_indices = matcher(unmatched_outputs, + [untracked_gt_instances], g_size=self.g_size) # list[tuple(src_idx, tgt_idx)] + + src_idx = new_track_indices[0][0] + tgt_idx = new_track_indices[0][1] + # concat src and tgt. + new_matched_indices = torch.stack([unmatched_track_idxes[src_idx], untracked_tgt_indexes[tgt_idx]], + dim=1).to(pred_logits_i.device) + return new_matched_indices + for ibn, gt_ins in enumerate(gt_instances_i): + # step1. inherit and update the previous tracks. + obj_idxes = gt_ins.obj_ids + i, j = torch.where(track_instances.obj_idxes[:, ibn:ibn+1] == obj_idxes) # 获取跟踪query与之相同ID的对应索引 + track_instances.matched_gt_idxes[i, ibn] = j + + full_track_idxes = torch.arange(len(track_instances), dtype=torch.long, device=pred_logits_i.device) + matched_track_idxes = (track_instances.obj_idxes[:, ibn] >= 0) # occu >=0表明该query为跟踪query + prev_matched_indices = torch.stack( + [full_track_idxes[matched_track_idxes], track_instances.matched_gt_idxes[matched_track_idxes, ibn]], dim=1) # 检测或跟踪与gt的对应关系 + + # step2. select the unmatched slots. + # note that the FP tracks whose obj_idxes are -2 will not be selected here. + unmatched_track_idxes = full_track_idxes[track_instances.obj_idxes[:, ibn] == -1] # 获取检测query + + # step3. select the untracked gt instances (new tracks). + tgt_indexes = track_instances.matched_gt_idxes[:, ibn] + tgt_indexes = tgt_indexes[tgt_indexes != -1] # 获取跟踪query匹配GT,非新生儿(除了这些之外便是新生儿) + + tgt_state = torch.zeros(len(gt_ins), device=pred_logits_i.device) + tgt_state[tgt_indexes] = 1 # 新生儿为0,跟踪对应的GT为1 + full_tgt_idxes = torch.arange(len(gt_ins), device=pred_logits_i.device) + untracked_tgt_indexes = full_tgt_idxes[tgt_state == 0] + untracked_gt_instances = gt_ins[untracked_tgt_indexes] # 新生儿的索引 + + # step4. do matching between the unmatched slots and GTs.该过程就是DET匈牙利匹配过程 + unmatched_outputs = { + 'pred_logits': track_instances.pred_logits[unmatched_track_idxes, ibn].unsqueeze(0), + 'pred_boxes': track_instances.pred_boxes[unmatched_track_idxes, ibn].unsqueeze(0), + } + # new_matched_indices = match_for_single_decoder_layer(unmatched_outputs, self.matcher, untracked_gt_instances, unmatched_track_idxes, untracked_tgt_indexes) + new_matched_indices = match_for_single_decoder_layer(unmatched_outputs, self.matcher, untracked_gt_instances, unmatched_track_idxes, untracked_tgt_indexes) + + # step5. update obj_idxes according to the new matching result. 分配GT的ID给track和GT所在的索引 + track_instances.obj_idxes[new_matched_indices[:, 0], ibn] = gt_ins.obj_ids[new_matched_indices[:, 1]].long() + track_instances.matched_gt_idxes[new_matched_indices[:, 0], ibn] = new_matched_indices[:, 1] + + # step6. calculate iou. + active_idxes = (track_instances.obj_idxes[:, ibn] >= 0) & (track_instances.matched_gt_idxes[:, ibn] >= 0) # 当前帧能够匹配到的目标 + active_track_boxes = track_instances.pred_boxes[active_idxes, ibn] + if len(active_track_boxes) > 0: + gt_boxes = gt_ins.boxes[track_instances.matched_gt_idxes[active_idxes, ibn]] + active_track_boxes = box_cxcywh_to_xyxy(active_track_boxes) + gt_boxes = box_cxcywh_to_xyxy(gt_boxes) + track_instances.iou[active_idxes, ibn] = matched_pairwise_iou(Boxes(active_track_boxes), Boxes(gt_boxes)) + + # step7. merge the unmatched pairs and the matched pairs. + matched_indices = torch.cat([new_matched_indices, prev_matched_indices], dim=0) + + # step8. calculate losses. + self.num_samples += len(gt_ins)*self.g_size + self.sample_device = pred_logits_i.device + outputs_i = { + 'pred_logits': pred_logits_i[:, ibn].unsqueeze(0), + 'pred_boxes': pred_boxes_i[:, ibn].unsqueeze(0), + } + for loss in self.losses: + new_track_loss = self.get_loss(loss, + outputs=outputs_i, + gt_instances=[gt_ins], + indices=[(matched_indices[:, 0], matched_indices[:, 1])], + num_boxes=1) + for key, value in new_track_loss.items(): + self.losses_dict['frame_{}_{}'.format(self._current_frame_idx, key)] += value + + if 'aux_outputs' in outputs: # 此处匹配时对新生儿时重新算对应关系的,不直接使用最后一层box输出的对应关系 + for i, aux_outputs in enumerate(outputs['aux_outputs']): + unmatched_outputs_layer = { + 'pred_logits': aux_outputs['pred_logits'][ibn, unmatched_track_idxes].unsqueeze(0), + 'pred_boxes': aux_outputs['pred_boxes'][ibn, unmatched_track_idxes].unsqueeze(0), + } + new_matched_indices_layer = match_for_single_decoder_layer(unmatched_outputs_layer, self.matcher, gt_ins[full_tgt_idxes], unmatched_track_idxes, full_tgt_idxes) + matched_indices_layer = torch.cat([new_matched_indices_layer, prev_matched_indices], dim=0) + outputs_layer = { + 'pred_logits': aux_outputs['pred_logits'][ibn].unsqueeze(0), + 'pred_boxes': aux_outputs['pred_boxes'][ibn].unsqueeze(0), + } + for loss in self.losses: + if loss == 'masks': + # Intermediate masks losses are too costly to compute, we ignore them. + continue + l_dict = self.get_loss(loss, + outputs_layer, + gt_instances=[gt_ins], + indices=[(matched_indices_layer[:, 0], matched_indices_layer[:, 1])], + num_boxes=1, ) + for key, value in l_dict.items(): + self.losses_dict['frame_{}_aux{}_{}'.format(self._current_frame_idx, i, key)] += value + + self._step() + return track_instances + + def forward(self, outputs, input_data: dict): + # losses of each frame are calculated during the model's forwarding and are outputted by the model as outputs['losses_dict]. + losses = outputs.pop("losses_dict") + num_samples = self.get_num_boxes(self.num_samples) + for loss_name, loss in losses.items(): + losses[loss_name] /= num_samples + return losses + + +class RuntimeTrackerBase(object): # 实际为一个跟踪ID分配器 + def __init__(self, score_thresh=0.6, filter_score_thresh=0.5, miss_tolerance=10): + self.score_thresh = score_thresh + self.filter_score_thresh = filter_score_thresh + self.miss_tolerance = miss_tolerance + self.max_obj_id = 0 + + def clear(self): + self.max_obj_id = 0 + + def update(self, track_instances: Instances, g_size=1): + assert track_instances.obj_idxes.shape[1] == 1 + + device = track_instances.obj_idxes.device + + num_queries = len(track_instances) + Cindx = torch.arange(num_queries, device=device).reshape(num_queries//g_size, g_size) + active_idxes = torch.full((num_queries,), False, dtype=torch.bool, device=device) + # active_idxes[Cindx[track_instances.scores.reshape(-1, g_size).max(-1)[0] >= self.score_thresh].view(-1)] = True + active_idxes[Cindx[track_instances.scores.reshape(-1, g_size).min(-1)[0] >= self.score_thresh].view(-1)] = True + # active_idxes[Cindx[track_instances.scores.reshape(-1, g_size).mean(-1) >= self.score_thresh].view(-1)] = True + track_instances.disappear_time[active_idxes] = 0 # 假如当前帧检测到目标,则disappear_time=0 + + active_debug = track_instances.scores.reshape(-1, g_size) >= self.score_thresh + if not (active_debug == active_debug[:,0:1]).any(): + print(track_instances.scores) + + new_obj = (track_instances.obj_idxes.reshape(-1) == -1) & (active_idxes) # 挑选新生儿,obj_idxes=-1表示为检测query + disappeared_obj = (track_instances.obj_idxes.reshape(-1) >= 0) & (~active_idxes) # 跟踪中假如置信度偏低则 disappear_time++ + num_new_objs = new_obj.sum().item() // g_size + + track_instances.obj_idxes[new_obj, 0] = (self.max_obj_id + torch.arange(num_new_objs, device=device)).view(-1, 1).repeat(1, g_size).view(-1) # 分配ID + self.max_obj_id += num_new_objs # max_obj_id为已有多人ID + + track_instances.disappear_time[disappeared_obj] += 1 + to_del = disappeared_obj & (track_instances.disappear_time[:, 0] >= self.miss_tolerance) # 假如当前帧检测不到,且消失很长时间,则把ID删掉, + track_instances.obj_idxes[to_del, 0] = -1 + + +class TrackerPostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + def __init__(self, g_size=1): + super().__init__() + + self.g_size = g_size + + @torch.no_grad() + def forward(self, track_instances: Instances, target_size) -> Instances: + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits = track_instances.pred_logits + out_bbox = track_instances.pred_boxes + + # scores = out_logits[..., 0].sigmoid() + prob = out_logits.sigmoid() + if len(prob): + num_query, bn, cls_num = prob.shape + scores, labels = prob.reshape(-1, self.g_size, bn, cls_num).max(1)[0].reshape(-1, 1, bn, cls_num).repeat(1, self.g_size, 1, 1).reshape(-1, bn, cls_num).max(-1) + else: + scores = out_logits[..., 0].sigmoid() + labels = torch.full_like(scores, 0, dtype=torch.long) + + # convert to [x0, y0, x1, y1] format + boxes = box_cxcywh_to_xyxy(out_bbox) + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_size + scale_fct = torch.Tensor([img_w, img_h, img_w, img_h]).to(boxes) + boxes = boxes * scale_fct[None, :] + + track_instances.boxes = boxes + track_instances.scores = scores + mask = (track_instances.labels[:, 0] == -1) | (track_instances.labels[:, 0] == labels[:, 0]) + track_instances.labels = labels # torch.full_like(scores, 0) + # track_instances.remove('pred_logits') + # track_instances.remove('pred_boxes') + # if len(track_instances) != len(track_instances[mask]): + # print(track_instances) + track_instances = track_instances[mask] + + return track_instances + + +def img(): + image = np.ascontiguousarray(((img.tensors[0].permute(1,2,0).cpu()*torch.tensor([0.229, 0.224, 0.225])+torch.tensor([0.485, 0.456, 0.406]))*255).numpy().astype(np.uint8)) + img_h, img_w, _ = image.shape + bboxes = track_instances.ref_pts.cpu().numpy().reshape(-1, 2, 2) + bboxes[..., 0] *= img_w + bboxes[..., 1] *= img_h + bboxes[:, 0] -= bboxes[:, 1]/2 + bboxes[:, 1] += bboxes[:, 0] + import cv2 + for i in range(68): + image_copy = image.copy() + for box in bboxes[5*i:5*(i+1)]: + cv2.rectangle(image_copy, pt1 = (int(box[0, 0]), int(box[0, 1])), pt2 =(int(box[1, 0]), int(box[1, 1])), color = (0, 0, 255), thickness = 2) + cv2.imwrite('tmp2/%d.jpg'%i, image_copy) \ No newline at end of file diff --git a/projects/co_mot/modeling/mot_transformer.py b/projects/co_mot/modeling/mot_transformer.py new file mode 100755 index 00000000..d3785f11 --- /dev/null +++ b/projects/co_mot/modeling/mot_transformer.py @@ -0,0 +1,574 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-research. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + + +import copy +from typing import Optional, List +import math + +import torch +import torch.nn.functional as F +from torch import nn, Tensor +from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ + +from detrex.utils import inverse_sigmoid + +from detrex.layers import MultiScaleDeformableAttention +# from .ops.modules import MSDeformAttn + +iter_for_debug = 0 + +class DeformableTransformer(nn.Module): + def __init__(self, d_model=256, nhead=8, + num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, + activation="relu", return_intermediate_dec=False, + num_feature_levels=4, dec_n_points=4, enc_n_points=4, + two_stage=False, two_stage_num_proposals=300, decoder_self_cross=True, sigmoid_attn=False, + extra_track_attn=False, memory_bank=False, im2col_step=64): + super().__init__() + + self.new_frame_adaptor = None + self.d_model = d_model + self.nhead = nhead + self.two_stage = two_stage + self.two_stage_num_proposals = two_stage_num_proposals + + encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, enc_n_points, + sigmoid_attn=sigmoid_attn) + self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) + + decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, + dropout, activation, + num_feature_levels, nhead, dec_n_points, decoder_self_cross, + sigmoid_attn=sigmoid_attn, extra_track_attn=extra_track_attn, + memory_bank=memory_bank, im2col_step=im2col_step) + self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec) + + self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) + + if two_stage: + self.enc_output = nn.Linear(d_model, d_model) + self.enc_output_norm = nn.LayerNorm(d_model) + self.pos_trans = nn.Linear(d_model * 2, d_model * 2) + self.pos_trans_norm = nn.LayerNorm(d_model * 2) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + normal_(self.level_embed) + + def get_proposal_pos_embed(self, proposals): + num_pos_feats = 128 + temperature = 10000 + scale = 2 * math.pi + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats + ) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) + return pos + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): + N_, S_, C_ = memory.shape + base_scale = 4.0 + proposals = [] + _cur = 0 + for lvl, (H_, W_) in enumerate(spatial_shapes): + mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) + valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) + valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) + + grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), + torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) + grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) + + scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) + grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale + wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) + proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) + proposals.append(proposal) + _cur += (H_ * W_) + output_proposals = torch.cat(proposals, 1) + output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) + output_proposals = torch.log(output_proposals / (1 - output_proposals)) + output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) + output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) + + output_memory = memory + output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) + output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) + output_memory = self.enc_output_norm(self.enc_output(output_memory)) + return output_memory, output_proposals + + def get_valid_ratio(self, mask): + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def forward(self, srcs, masks, pos_embeds, query_embed=None, ref_pts=None, mem_bank=None, mem_bank_pad_mask=None, attn_mask=None): + # assert self.two_stage or query_embed is not None + + # prepare input for encoder + src_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): # 把feat+mask+pos摊平,方便encode输入 + bs, c, h, w = src.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + src = src.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + src_flatten.append(src) + mask_flatten.append(mask) + src_flatten = torch.cat(src_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) + + # encoder, 输入 feat+mask+pos, 经过encode输出h*w*256的特征 + memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) + # prepare input for decoder + bs, _, c = memory.shape + if self.two_stage: + output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes) + + # hack implementation for two-stage Deformable DETR + enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory) + enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points_two = topk_coords_unact.sigmoid() + pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + _, tgt_two = torch.split(pos_trans_out, c, dim=2) + tgt = torch.cat([tgt_two, query_embed.unsqueeze(0).expand(bs, -1, -1)[:, topk:]], axis=1) + reference_points = torch.cat([reference_points_two, ref_pts.unsqueeze(0).expand(bs, -1, -1)[:, topk:]], axis=1) + init_reference_out = reference_points + else: + tgt = query_embed.transpose(0,1) + reference_points = ref_pts.transpose(0,1) + init_reference_out = reference_points + # decoder, 输入 query+query_pos(即reference_points)+encode的输出,输出N*256的特征 + hs, inter_references = self.decoder(tgt, reference_points, memory, + spatial_shapes, level_start_index, + valid_ratios, mask_flatten, + mem_bank, mem_bank_pad_mask, attn_mask) + + inter_references_out = inter_references + if self.two_stage: + return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact + return hs, init_reference_out, inter_references_out, None, None + + +class DeformableTransformerEncoderLayer(nn.Module): + def __init__(self, + d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4, sigmoid_attn=False): + super().__init__() + + # self attention + self.self_attn = MultiScaleDeformableAttention( + embed_dim=d_model, + num_heads=n_heads, + num_levels=n_levels, + dropout=False, + batch_first=True, + num_points=n_points, + img2col_step=64, + ) + # self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points, sigmoid_attn=sigmoid_attn) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout_relu = ReLUDropout(dropout, True) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout3 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, src): + src2 = self.linear2(self.dropout_relu(self.linear1(src))) + src = src + self.dropout3(src2) + src = self.norm2(src) + return src + + def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): + # self attention + identity = torch.zeros_like(src) + src2 = self.self_attn(src, query_pos=pos, identity=identity, reference_points=reference_points, value=src, spatial_shapes=spatial_shapes, level_start_index=level_start_index, key_padding_mask=padding_mask) + src = src + self.dropout1(src2) + src = self.norm1(src) + + # ffn + src = self.forward_ffn(src) + return src + + +class DeformableTransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + + @staticmethod + def get_reference_points(spatial_shapes, valid_ratios, device): + reference_points_list = [] + for lvl, (H_, W_) in enumerate(spatial_shapes): + + ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), + torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) + ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): + output = src + reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) # 由于这里使用的transformer-detr,需要先确定参考点,实际grid + for _, layer in enumerate(self.layers): + output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) + + return output + + +class ReLUDropout(torch.nn.Dropout): + def forward(self, input): + return relu_dropout(input, p=self.p, training=self.training, inplace=self.inplace) + +def relu_dropout(x, p=0, inplace=False, training=False): + if not training or p == 0: + return x.clamp_(min=0) if inplace else x.clamp(min=0) + + mask = (x < 0) | (torch.rand_like(x) > 1 - p) + return x.masked_fill_(mask, 0).div_(1 - p) if inplace else x.masked_fill(mask, 0).div(1 - p) + + +class DeformableTransformerDecoderLayer(nn.Module): + def __init__(self, d_model=256, d_ffn=1024, + dropout=0.1, activation="relu", + n_levels=4, n_heads=8, n_points=4, self_cross=True, sigmoid_attn=False, + extra_track_attn=False, memory_bank=False, im2col_step=64): + super().__init__() + + self.self_cross = self_cross + self.num_head = n_heads + self.memory_bank = memory_bank + + # cross attention + self.cross_attn = MultiScaleDeformableAttention( + embed_dim=d_model, + num_heads=n_heads, + num_levels=n_levels, + dropout=False, + batch_first=True, + num_points=n_points, + img2col_step=im2col_step, + ) + # self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points, sigmoid_attn=sigmoid_attn, im2col_step=im2col_step) + self.dropout1 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(d_model) + + self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) # , add_zero_attn=True + # self.self_attn = MultiheadAttention(d_model, n_heads, dropout=dropout) # , add_zero_attn=True + self.dropout2 = nn.Dropout(dropout) + self.norm2 = nn.LayerNorm(d_model) + + # ffn + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = _get_activation_fn(activation) + self.dropout_relu = ReLUDropout(dropout, True) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout4 = nn.Dropout(dropout) + self.norm3 = nn.LayerNorm(d_model) + + # memory bank + if self.memory_bank: + self.temporal_attn = nn.MultiheadAttention(d_model, 8, dropout=0) + self.temporal_fc1 = nn.Linear(d_model, d_ffn) + self.temporal_fc2 = nn.Linear(d_ffn, d_model) + self.temporal_norm1 = nn.LayerNorm(d_model) + self.temporal_norm2 = nn.LayerNorm(d_model) + + position = torch.arange(5).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(5, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + # update track query_embed + self.extra_track_attn = extra_track_attn + if self.extra_track_attn: + print('Training with Extra Self Attention in Every Decoder.', flush=True) + self.update_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) + self.dropout5 = nn.Dropout(dropout) + self.norm4 = nn.LayerNorm(d_model) + + if self_cross: + print('Training with Self-Cross Attention.') + else: + print('Training with Cross-Self Attention.') + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward_ffn(self, tgt): + tgt2 = self.linear2(self.dropout_relu(self.linear1(tgt))) + tgt = tgt + self.dropout4(tgt2) + tgt = self.norm3(tgt) + return tgt + + def _forward_self_attn(self, tgt, query_pos, attn_mask=None): + q = k = self.with_pos_embed(tgt, query_pos) + if attn_mask is not None: + tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), + attn_mask=attn_mask)[0].transpose(0, 1) + else: + tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) + tgt = tgt + self.dropout2(tgt2) + return self.norm2(tgt) + + def _forward_self_attn_output_weights(self, tgt, query_pos, attn_mask=None, lid=None): + q = k = self.with_pos_embed(tgt, query_pos) + if attn_mask is not None: + tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1), + attn_mask=attn_mask)[0].transpose(0, 1) + else: + tgt2, weigths = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1)) + tgt2 = tgt2.transpose(0, 1) + weigths = weigths[0].mean(0) + import numpy as np + np.save('tmp/weight_%08d_%d.txt'%(iter_for_debug, lid), weigths.cpu().numpy()) + tgt = tgt + self.dropout2(tgt2) + return self.norm2(tgt) + + def _forward_self_cross(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, + src_padding_mask=None, attn_mask=None): + + # self attention + # if attn_mask is not None: + # len_n_dt = sum(attn_mask[0]==False) + # tgt = torch.cat([self._forward_self_attn(tgt[:, :len_n_dt], query_pos[:, :len_n_dt]), self._forward_self_attn(tgt[:, len_n_dt:], query_pos[:, len_n_dt:])], dim=1) + # else: + tgt = self._forward_self_attn(tgt, query_pos, attn_mask) + # cross attention + tgt2 = self.cross_attn(tgt, + query_pos=query_pos, + identity=torch.zeros_like(tgt), + reference_points=reference_points, + value=src, + spatial_shapes=src_spatial_shapes, + level_start_index=level_start_index, + key_padding_mask=src_padding_mask + ) + # tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), + # reference_points, + # src, src_spatial_shapes, level_start_index, src_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + def _forward_self_cross_output_weight(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, + src_padding_mask=None, attn_mask=None, lid=None): + + # self attention + # if attn_mask is not None: + # len_n_dt = sum(attn_mask[0]==False) + # tgt = torch.cat([self._forward_self_attn(tgt[:, :len_n_dt], query_pos[:, :len_n_dt]), self._forward_self_attn(tgt[:, len_n_dt:], query_pos[:, len_n_dt:])], dim=1) + # else: + tgt = self._forward_self_attn_output_weights(tgt, query_pos, attn_mask, lid=lid) + # cross attention + tgt2 = self.cross_attn(tgt, + query_pos=query_pos, + identity=torch.zeros_like(tgt), + reference_points=reference_points, + value=src, + spatial_shapes=src_spatial_shapes, + level_start_index=level_start_index, + key_padding_mask=src_padding_mask + ) + # tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), + # reference_points, + # src, src_spatial_shapes, level_start_index, src_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + def _forward_cross_self(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, + src_padding_mask=None, attn_mask=None): + # cross attention + tgt2 = self.cross_attn(tgt, + query_pos=query_pos, + identity=torch.zeros_like(tgt), + reference_points=reference_points, + value=src, + spatial_shapes=src_spatial_shapes, + level_start_index=level_start_index, + key_padding_mask=src_padding_mask + ) + # tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), + # reference_points, + # src, src_spatial_shapes, level_start_index, src_padding_mask) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + # self attention + tgt = self._forward_self_attn(tgt, query_pos, attn_mask) + # ffn + tgt = self.forward_ffn(tgt) + + return tgt + + def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None, mem_bank=None, mem_bank_pad_mask=None, attn_mask=None): + if self.self_cross: + return self._forward_self_cross(tgt, query_pos, reference_points, src, src_spatial_shapes, + level_start_index, src_padding_mask, attn_mask) + return self._forward_cross_self(tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, + src_padding_mask, attn_mask) + + def forward_output_weight(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None, mem_bank=None, mem_bank_pad_mask=None, attn_mask=None, lid=None): + if self.self_cross: + return self._forward_self_cross_output_weight(tgt, query_pos, reference_points, src, src_spatial_shapes, + level_start_index, src_padding_mask, attn_mask, lid=lid) + return self._forward_cross_self(tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, + src_padding_mask, attn_mask) + + +def pos2posemb(pos, num_pos_feats=64, temperature=10000): + scale = 2 * math.pi + pos = pos * scale + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) + dim_t = temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats + ) + + posemb = pos[..., None] / dim_t + posemb = torch.stack((posemb[..., 0::2].sin(), posemb[..., 1::2].cos()), dim=-1).flatten(-3) + return posemb + + +class DeformableTransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.return_intermediate = return_intermediate + # hack implementation for iterative bounding box refinement and two-stage Deformable DETR + self.bbox_embed = None + self.class_embed = None + self.obj_embed = None + + def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios, + src_padding_mask=None, mem_bank=None, mem_bank_pad_mask=None, attn_mask=None): + output = tgt + + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: # 参考点先缩放到图像大小维度 + reference_points_input = reference_points[:, :, None] \ + * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] + query_pos = pos2posemb(reference_points) # 把参考点转为query-position(参看DAB-DETR) + # if lid == 4: + # output = layer.forward_output_weight(output, query_pos, reference_points_input, src, src_spatial_shapes, + # src_level_start_index, src_padding_mask, mem_bank, mem_bank_pad_mask, attn_mask, lid=lid) + # else: + output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, + src_level_start_index, src_padding_mask, mem_bank, mem_bank_pad_mask, attn_mask) + + # hack implementation for iterative bounding box refinement + if self.bbox_embed is not None: # 迭代优化,获取下一次迭代的参考点 + tmp = self.bbox_embed[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + # 注意要删掉 + # global iter_for_debug + # iter_for_debug += 1 + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack(intermediate_reference_points) + + return output, reference_points + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return nn.ReLU(True) + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + diff --git a/projects/co_mot/modeling/qim.py b/projects/co_mot/modeling/qim.py new file mode 100644 index 00000000..acb4e00f --- /dev/null +++ b/projects/co_mot/modeling/qim.py @@ -0,0 +1,209 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-research. All Rights Reserved. +# ------------------------------------------------------------------------ + +import math +import torch +from torch import nn + +from detrex.layers import box_cxcywh_to_xyxy +from detectron2.structures import Boxes, Instances, pairwise_iou + + +def random_drop_tracks(track_instances: Instances, drop_probability: float) -> Instances: + if drop_probability > 0 and len(track_instances) > 0: + keep_idxes = torch.rand_like(track_instances.scores) > drop_probability + track_instances = track_instances[keep_idxes] + return track_instances + + +class QueryInteractionBase(nn.Module): + def __init__(self, args, dim_in, hidden_dim, dim_out): + super().__init__() + self.args = args + self._build_layers(args, dim_in, hidden_dim, dim_out) + self._reset_parameters() + + def _build_layers(self, args, dim_in, hidden_dim, dim_out): + raise NotImplementedError() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def _select_active_tracks(self, data: dict) -> Instances: + raise NotImplementedError() + + def _update_track_embedding(self, track_instances): + raise NotImplementedError() + + +class FFN(nn.Module): + def __init__(self, d_model, d_ffn, dropout=0): + super().__init__() + self.linear1 = nn.Linear(d_model, d_ffn) + self.activation = nn.ReLU(True) + self.dropout1 = nn.Dropout(dropout) + self.linear2 = nn.Linear(d_ffn, d_model) + self.dropout2 = nn.Dropout(dropout) + self.norm = nn.LayerNorm(d_model) + + def forward(self, tgt): + tgt2 = self.linear2(self.dropout1(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm(tgt) + return tgt + +class QueryInteractionModuleGroup(QueryInteractionBase): + def __init__(self, args, dim_in, hidden_dim, dim_out): + super().__init__(args, dim_in, hidden_dim, dim_out) + self.random_drop = args.random_drop + self.fp_ratio = args.fp_ratio + self.update_query_pos = args.update_query_pos + self.score_thr = 0.5 + + def _build_layers(self, args, dim_in, hidden_dim, dim_out): + dropout = args.merger_dropout + + self.self_attn = nn.MultiheadAttention(dim_in, 8, dropout) + self.linear1 = nn.Linear(dim_in, hidden_dim) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(hidden_dim, dim_in) + + if args.update_query_pos: + self.linear_pos1 = nn.Linear(dim_in, hidden_dim) + self.linear_pos2 = nn.Linear(hidden_dim, dim_in) + self.dropout_pos1 = nn.Dropout(dropout) + self.dropout_pos2 = nn.Dropout(dropout) + self.norm_pos = nn.LayerNorm(dim_in) + + self.linear_feat1 = nn.Linear(dim_in, hidden_dim) + self.linear_feat2 = nn.Linear(hidden_dim, dim_in) + self.dropout_feat1 = nn.Dropout(dropout) + self.dropout_feat2 = nn.Dropout(dropout) + self.norm_feat = nn.LayerNorm(dim_in) + + self.norm1 = nn.LayerNorm(dim_in) + self.norm2 = nn.LayerNorm(dim_in) + if args.update_query_pos: + self.norm3 = nn.LayerNorm(dim_in) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + if args.update_query_pos: + self.dropout3 = nn.Dropout(dropout) + self.dropout4 = nn.Dropout(dropout) + + self.activation = nn.ReLU(True) + + def _random_drop_tracks(self, track_instances: Instances) -> Instances: # 随机删掉track + return random_drop_tracks(track_instances, self.random_drop) + + def _add_fp_tracks(self, track_instances: Instances, active_track_instances: Instances) -> Instances: # 随机添加track(选择与跟踪框最大iou),表示消失儿 + inactive_instances = track_instances[track_instances.obj_idxes < 0] + + # add fp for each active track in a specific probability. + fp_prob = torch.ones_like(active_track_instances.scores) * self.fp_ratio + selected_active_track_instances = active_track_instances[torch.bernoulli(fp_prob).bool()] # torch.bernoulli提取二进制随机数 + + if len(inactive_instances) > 0 and len(selected_active_track_instances) > 0: + num_fp = len(selected_active_track_instances) # 添加的个数 + if num_fp >= len(inactive_instances): + fp_track_instances = inactive_instances + else: + inactive_boxes = Boxes(box_cxcywh_to_xyxy(inactive_instances.pred_boxes)) + selected_active_boxes = Boxes(box_cxcywh_to_xyxy(selected_active_track_instances.pred_boxes)) + ious = pairwise_iou(inactive_boxes, selected_active_boxes) + # select the fp with the largest IoU for each active track. + fp_indexes = ious.max(dim=0).indices + + # remove duplicate fp. + fp_indexes = torch.unique(fp_indexes) + fp_track_instances = inactive_instances[fp_indexes] + + merged_track_instances = Instances.cat([active_track_instances, fp_track_instances]) + return merged_track_instances + + return active_track_instances + + def _select_active_tracks(self, data: dict, g_size=1) -> Instances: + track_instances: Instances = data['track_instances'] + if self.training: + num_queries, bs = track_instances.obj_idxes.shape[:2] + min_prev_target_ind = min(sum((track_instances.obj_idxes.reshape(-1, g_size, bs)>=0).any(1) | (track_instances.scores.reshape(-1, g_size, bs)> 0.5).any(1) )) + + active_track_instances = [] + for i in range(bs): + topk_proposals = torch.topk(track_instances.scores.reshape(-1, g_size, bs).min(1)[0][..., i], min_prev_target_ind)[1] + index_all = torch.full((num_queries//g_size, g_size), False, dtype=torch.bool, device=topk_proposals.device) + index_all[topk_proposals] = True + index_all = index_all.reshape(-1) + active_track_instances.append(track_instances[index_all]) + + active_track_instances = Instances.merge(active_track_instances) + + # active_idxes = (track_instances.obj_idxes >= 0) | (track_instances.scores > 0.5) + # active_idxes = active_idxes.reshape(-1, g_size).any(dim=1).view(-1, 1).repeat(1, g_size).view(-1) + # active_track_instances = track_instances[active_idxes] + del_idxes = active_track_instances.iou <= 0.5 + del_idxes = del_idxes.reshape(-1, g_size, bs).any(dim=1).view(-1, 1, bs).repeat(1, g_size, 1).view(-1, bs) + active_track_instances.obj_idxes[del_idxes] = -1 + else: + assert track_instances.obj_idxes.shape[1] == 1 + active_idxes = track_instances.obj_idxes[:, 0] >= 0 + active_idxes = active_idxes.reshape(-1, g_size).any(dim=1).view(-1, 1).repeat(1, g_size).view(-1) + # active_idxes = active_idxes.reshape(-1, g_size).all(dim=1).view(-1, 1).repeat(1, g_size).view(-1) + active_track_instances = track_instances[active_idxes] + + return active_track_instances + + def _update_track_embedding(self, track_instances: Instances) -> Instances: + is_pos = track_instances.scores > self.score_thr + track_instances.ref_pts[is_pos] = track_instances.pred_boxes.detach().clone()[is_pos] + + out_embed = track_instances.output_embedding + query_feat = track_instances.query_pos + query_pos = pos2posemb(track_instances.ref_pts) + q = k = query_pos + out_embed + + tgt = out_embed + tgt2 = self.self_attn(q, k, value=tgt)[0] # bacht + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + + # if self.update_query_pos: + # query_pos2 = self.linear_pos2(self.dropout_pos1(self.activation(self.linear_pos1(tgt)))) + # query_pos = query_pos + self.dropout_pos2(query_pos2) + # query_pos = self.norm_pos(query_pos) + # track_instances.query_pos = query_pos + + query_feat2 = self.linear_feat2(self.dropout_feat1(self.activation(self.linear_feat1(tgt)))) + query_feat = query_feat + self.dropout_feat2(query_feat2) + query_feat = self.norm_feat(query_feat) + track_instances.query_pos[is_pos] = query_feat[is_pos] + + return track_instances + + def forward(self, data, g_size=1) -> Instances: + active_track_instances = self._select_active_tracks(data, g_size) # 选择活的(即有ID的目标,因为之前已经经过score的判断为活的目标分配了ID) + active_track_instances = self._update_track_embedding(active_track_instances) # 根据update_query_pos的不同(仅对当前帧置信度高的目标更新embedding,有ID的目标可能有当前帧消失,但前几帧存在的目标) + return active_track_instances + + +def pos2posemb(pos, num_pos_feats=64, temperature=10000): + scale = 2 * math.pi + pos = pos * scale + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) + dim_t = temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats + ) + + posemb = pos[..., None] / dim_t + posemb = torch.stack((posemb[..., 0::2].sin(), posemb[..., 1::2].cos()), dim=-1).flatten(-3) + return posemb + diff --git a/projects/co_mot/train_net.py b/projects/co_mot/train_net.py new file mode 100644 index 00000000..08521f4f --- /dev/null +++ b/projects/co_mot/train_net.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Training script using the new "LazyConfig" python config files. + +This scripts reads a given python config file and runs the training or evaluation. +It can be used to train any models or dataset as long as they can be +instantiated by the recursive construction defined in the given config file. + +Besides lazy construction of models, dataloader, etc., this scripts expects a +few common configuration parameters currently defined in "configs/common/train.py". +To add more complicated training logic, you can easily add other configs +in the config file and implement a new train_net.py to handle them. +""" +import logging +import os +import sys +import time +import torch +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from detectron2.checkpoint import DetectionCheckpointer +from detectron2.config import LazyConfig, instantiate +from detectron2.engine import ( + SimpleTrainer, + default_argument_parser, + default_setup, + default_writers, + hooks, + launch, +) +from detectron2.engine.defaults import create_ddp_model +from detectron2.evaluation import inference_on_dataset, print_csv_format +from detectron2.utils import comm + +from projects.co_mot.util.misc import data_dict_to_cuda + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) + +logger = logging.getLogger("detrex") + + +def match_name_keywords(n, name_keywords): + out = False + for b in name_keywords: + if b in n: + out = True + break + return out + + +class Trainer(SimpleTrainer): + """ + We've combine Simple and AMP Trainer together. + """ + + def __init__( + self, + model, + dataloader, + optimizer, + amp=False, + clip_grad_params=None, + grad_scaler=None, + ): + super().__init__(model=model, data_loader=dataloader, optimizer=optimizer) + + unsupported = "AMPTrainer does not support single-process multi-device training!" + if isinstance(model, DistributedDataParallel): + assert not (model.device_ids and len(model.device_ids) > 1), unsupported + assert not isinstance(model, DataParallel), unsupported + + if amp: + if grad_scaler is None: + from torch.cuda.amp import GradScaler + + grad_scaler = GradScaler() + self.grad_scaler = grad_scaler + + # set True to use amp training + self.amp = amp + + # gradient clip hyper-params + self.clip_grad_params = clip_grad_params + + self.device = self.model.device + + def run_step(self): + """ + Implement the standard training logic described above. + """ + assert self.model.training, "[Trainer] model was changed to eval mode!" + assert torch.cuda.is_available(), "[Trainer] CUDA is required for AMP training!" + from torch.cuda.amp import autocast + + start = time.perf_counter() + """ + If you want to do something with the data, you can wrap the dataloader. + """ + data = next(self._data_loader_iter) + data_time = time.perf_counter() - start + + """ + If you want to do something with the losses, you can wrap the model. + """ + data = data_dict_to_cuda(data, self.device) + loss_dict = self.model(data) + with autocast(enabled=self.amp): + if isinstance(loss_dict, torch.Tensor): + losses = loss_dict + loss_dict = {"total_loss": loss_dict} + else: + losses = sum(loss_dict.values()) + + """ + If you need to accumulate gradients or do something similar, you can + wrap the optimizer with your custom `zero_grad()` method. + """ + self.optimizer.zero_grad() + + if self.amp: + self.grad_scaler.scale(losses).backward() + if self.clip_grad_params is not None: + self.grad_scaler.unscale_(self.optimizer) + self.clip_grads(self.model.parameters()) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + losses.backward() + if self.clip_grad_params is not None: + self.clip_grads(self.model.parameters()) + self.optimizer.step() + + self._write_metrics(loss_dict, data_time) + + def clip_grads(self, params): + params = list(filter(lambda p: p.requires_grad and p.grad is not None, params)) + if len(params) > 0: + return torch.nn.utils.clip_grad_norm_( + parameters=params, + **self.clip_grad_params, + ) + + +def do_test(cfg, model): + if "evaluator" in cfg.dataloader: + ret = inference_on_dataset( + model, instantiate(cfg.dataloader.test), instantiate(cfg.dataloader.evaluator) + ) + print_csv_format(ret) + return ret + + +def do_train(args, cfg): + """ + Args: + cfg: an object with the following attributes: + model: instantiate to a module + dataloader.{train,test}: instantiate to dataloaders + dataloader.evaluator: instantiate to evaluator for test set + optimizer: instantaite to an optimizer + lr_multiplier: instantiate to a fvcore scheduler + train: other misc config defined in `configs/common/train.py`, including: + output_dir (str) + init_checkpoint (str) + amp.enabled (bool) + max_iter (int) + eval_period, log_period (int) + device (str) + checkpointer (dict) + ddp (dict) + """ + model = instantiate(cfg.model) + logger = logging.getLogger("detectron2") + logger.info("Model:\n{}".format(model)) + model.to(cfg.train.device) + + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info('number of params: {}'.format(n_parameters)) + + # this is an hack of train_net + param_dicts = [ + { + "params": [ + p + for n, p in model.named_parameters() + if not match_name_keywords(n, cfg.train.lr_backbone_names) + and not match_name_keywords(n, cfg.train.lr_linear_proj_names) + and p.requires_grad + ], + "lr": cfg.optimizer.lr, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if match_name_keywords(n, cfg.train.lr_backbone_names) and p.requires_grad + ], + "lr": cfg.optimizer.lr_backbone, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if match_name_keywords(n, cfg.train.lr_linear_proj_names) + and p.requires_grad + ], + "lr": cfg.optimizer.lr * cfg.optimizer.lr_linear_proj_mult, + }, + ] + if cfg.optimizer.sgd: + optim = torch.optim.SGD(param_dicts, lr=cfg.optimizer.lr, momentum=0.9, + weight_decay=cfg.optimizer.weight_decay) + else: + optim = torch.optim.AdamW(param_dicts, lr=cfg.optimizer.lr, + weight_decay=cfg.optimizer.weight_decay) + + train_loader = instantiate(cfg.dataloader.train) + + model = create_ddp_model(model, **cfg.train.ddp) + + trainer = Trainer( + model=model, + dataloader=train_loader, + optimizer=optim, + amp=cfg.train.amp.enabled, # default False + clip_grad_params=cfg.train.clip_grad.params if cfg.train.clip_grad.enabled else None, # default False + ) + + checkpointer = DetectionCheckpointer( + model, + cfg.train.output_dir, + trainer=trainer, + ) + + trainer.register_hooks( + [ + hooks.IterationTimer(), + hooks.LRScheduler(scheduler=instantiate(cfg.lr_multiplier)), + hooks.PeriodicCheckpointer(checkpointer, **cfg.train.checkpointer) if comm.is_main_process() else None, + hooks.EvalHook(cfg.train.eval_period, lambda: do_test(cfg, model)), + hooks.PeriodicWriter( + default_writers(cfg.train.output_dir, cfg.train.max_iter), + period=cfg.train.log_period, + ) if comm.is_main_process() else None, + ] + ) + + checkpointer.resume_or_load(cfg.train.init_checkpoint, resume=args.resume) + if args.resume and checkpointer.has_checkpoint(): + # The checkpoint stores the training iteration that just finished, thus we start + # at the next iteration + start_iter = trainer.iter + 1 + else: + start_iter = 0 + trainer.train(start_iter, cfg.train.max_iter) + + +def main(args): + cfg = LazyConfig.load(args.config_file) + cfg = LazyConfig.apply_overrides(cfg, args.opts) + default_setup(cfg, args) + + if args.eval_only: + model = instantiate(cfg.model) + model.to(cfg.train.device) + model = create_ddp_model(model) + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + print(do_test(cfg, model)) + else: + do_train(args, cfg) + + +if __name__ == "__main__": + args = default_argument_parser().parse_args() + launch( + main, + args.num_gpus, + num_machines=args.num_machines, + machine_rank=args.machine_rank, + dist_url=args.dist_url, + args=(args,), + ) diff --git a/projects/co_mot/util/__init__.py b/projects/co_mot/util/__init__.py new file mode 100755 index 00000000..11752457 --- /dev/null +++ b/projects/co_mot/util/__init__.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-research. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + diff --git a/projects/co_mot/util/checkpoint.py b/projects/co_mot/util/checkpoint.py new file mode 100644 index 00000000..7b166ad9 --- /dev/null +++ b/projects/co_mot/util/checkpoint.py @@ -0,0 +1,40 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-research. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from pytorch-checkpoint (https://github.com/csrhddlam/pytorch-checkpoint) +# ------------------------------------------------------------------------ + +import torch + + +def check_require_grad(t): + return isinstance(t, torch.Tensor) and t.requires_grad + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + for i in range(len(ctx.input_tensors)): + temp = ctx.input_tensors[i] + if check_require_grad(temp): + ctx.input_tensors[i] = temp.detach() + ctx.input_tensors[i].requires_grad = temp.requires_grad + with torch.enable_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + to_autograd = list(filter(check_require_grad, ctx.input_tensors)) + output_tensors, output_grads = zip(*filter(lambda t: t[0].requires_grad, zip(output_tensors, output_grads))) + input_grads = torch.autograd.grad(output_tensors, to_autograd + ctx.input_params, output_grads, allow_unused=True) + input_grads = list(input_grads) + for i in range(len(ctx.input_tensors)): + if not check_require_grad(ctx.input_tensors[i]): + input_grads.insert(i, None) + return (None, None) + tuple(input_grads) diff --git a/projects/co_mot/util/misc.py b/projects/co_mot/util/misc.py new file mode 100755 index 00000000..7f088a6b --- /dev/null +++ b/projects/co_mot/util/misc.py @@ -0,0 +1,164 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-research. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) +# Copyright (c) 2020 SenseTime. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# ------------------------------------------------------------------------ + + +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import OrderedDict, defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.nn as nn +import torch.distributed as dist +from torch import Tensor +from functools import partial +from detectron2.structures import Instances + + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision + + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor], size_divisibility: int = 0): + # TODO make this more general + if tensor_list[0].ndim == 3: + # TODO make it support different-sized images + + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + if size_divisibility > 0: + stride = size_divisibility + # the last two dims are H,W, both subject to divisibility requirement + max_size[-1] = (max_size[-1] + (stride - 1)) // stride * stride + max_size[-2] = (max_size[-2] + (stride - 1)) // stride * stride + + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device, non_blocking=False): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device, non_blocking=non_blocking) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device, non_blocking=non_blocking) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def record_stream(self, *args, **kwargs): + self.tensors.record_stream(*args, **kwargs) + if self.mask is not None: + self.mask.record_stream(*args, **kwargs) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def to_cuda(samples, targets, device): + samples = samples.to(device, non_blocking=True) + targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets] + return samples, targets + + +def tensor_to_cuda(tensor: torch.Tensor, device): + return tensor.to(device) + + +def is_tensor_or_instances(data): + return isinstance(data, torch.Tensor) or isinstance(data, Instances) + + +def data_apply(data, check_func, apply_func): + if isinstance(data, dict): + for k in data.keys(): + if check_func(data[k]): + data[k] = apply_func(data[k]) + elif isinstance(data[k], dict) or isinstance(data[k], list): + data_apply(data[k], check_func, apply_func) + elif isinstance(data[k], int): + pass + else: + raise ValueError() + elif isinstance(data, list): + for i in range(len(data)): + if check_func(data[i]): + data[i] = apply_func(data[i]) + elif isinstance(data[i], dict) or isinstance(data[i], list): + data_apply(data[i], check_func, apply_func) + elif isinstance(data[i], int): + pass + else: + raise ValueError("invalid type {}".format(type(data[i]))) + elif isinstance(data, int): + pass + else: + raise ValueError("invalid type {}".format(type(data))) + return data + + +def data_dict_to_cuda(data_dict, device): + return data_apply(data_dict, is_tensor_or_instances, partial(tensor_to_cuda, device=device)) diff --git a/requirements.txt b/requirements.txt index 01292a42..b3198171 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ scipy==1.7.3 psutil opencv-python wandb -submitit \ No newline at end of file +submitit +git+https://github.com/JonathonLuiten/TrackEval.git \ No newline at end of file