From 2c392d941d3d14126df840e1186955ed4ae0ec9c Mon Sep 17 00:00:00 2001 From: jkbmrz Date: Tue, 30 Jul 2024 17:16:51 +0200 Subject: [PATCH] feat: add support for HRNet model --- depthai_nodes/ml/parsers/__init__.py | 2 + depthai_nodes/ml/parsers/hrnet.py | 62 ++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) create mode 100644 depthai_nodes/ml/parsers/hrnet.py diff --git a/depthai_nodes/ml/parsers/__init__.py b/depthai_nodes/ml/parsers/__init__.py index 77795e77..54ea9185 100644 --- a/depthai_nodes/ml/parsers/__init__.py +++ b/depthai_nodes/ml/parsers/__init__.py @@ -11,6 +11,7 @@ from .xfeat import XFeatParser from .yunet import YuNetParser from .age_gender import AgeGenderParser +from .hrnet import HRNetParser __all__ = [ "ImageOutputParser", @@ -26,4 +27,5 @@ "XFeatParser", "ThermalImageParser", "AgeGenderParser", + "HRNetParser", ] diff --git a/depthai_nodes/ml/parsers/hrnet.py b/depthai_nodes/ml/parsers/hrnet.py new file mode 100644 index 00000000..081b2dd0 --- /dev/null +++ b/depthai_nodes/ml/parsers/hrnet.py @@ -0,0 +1,62 @@ +import depthai as dai +import numpy as np +import cv2 + +from ..messages.creators import create_keypoints_message + + +class HRNetParser(dai.node.ThreadedHostNode): + def __init__(self, score_threshold=0.5, input_size=[256, 256], heatmap_size=[64, 64]): + dai.node.ThreadedHostNode.__init__(self) + self.input = dai.Node.Input(self) + self.out = dai.Node.Output(self) + self.input_size = input_size + self.heatmap_size = heatmap_size + self.score_threshold = score_threshold + + def setScoreThreshold(self, threshold): + self.score_threshold = threshold + + def run(self): + """Postprocessing logic for HRNet pose estimation model. The code is inspired by https://github.com/ibaiGorordo/ONNX-HRNET-Human-Pose-Estimation + + Returns: + ... + """ + + while self.isRunning(): + try: + output: dai.NNData = self.input.get() + except dai.MessageQueue.QueueException: + break # Pipeline was stopped + + img_width, img_height = self.input_size + + heatmaps = output.getTensor("heatmaps", dequantize=True) + + if len(heatmaps.shape) == 4: # add new axis for batch size + heatmaps = heatmaps[0] + + if heatmaps.shape[2] == 16: # HW_ instead of _HW + heatmaps = heatmaps.transpose(2, 0, 1) + + _, map_h, map_w = heatmaps.shape + + # Find the maximum value in each of the heatmaps and its location + max_vals = np.array([np.max(heatmap) for heatmap in heatmaps]) + keypoints = np.array([np.unravel_index(heatmap.argmax(), heatmap.shape) + for heatmap in heatmaps]) + keypoints = keypoints.astype(np.float32) + keypoints[max_vals < self.score_threshold] = np.array([np.nan, np.nan]) + + # Scale keypoints to the image size + # TODO: remove and have relative keypoint values? e.g. * np.array([64 / map_w, 64 / map_h]) to get relative values? + keypoints = keypoints[:, ::-1] * np.array([img_width / map_w, img_height / map_h]) + + keypoints_msg = create_keypoints_message( + keypoints=keypoints, + #scores=max_vals, # TODO: add scores + #confidence_threshold=self.confidence_threshold # TODO: add confidence threshold + ) + + self.out.send(keypoints_msg)