-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
315 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
from typing import Union | ||
|
||
import depthai as dai | ||
|
||
from depthai_nodes.ml.messages import ImgDetectionExtended, ImgDetectionsExtended | ||
|
||
from .host_spatials_calc import HostSpatialsCalc | ||
|
||
|
||
class DepthMerger(dai.node.HostNode):
    """DepthMerger is a custom host node for merging 2D detections with depth
    information to produce spatial detections.

    Attributes
    ----------
    output : dai.Node.Output
        The output of the DepthMerger node containing dai.SpatialImgDetections.

    Usage
    -----
    depth_merger = pipeline.create(DepthMerger).build(
        output_2d=nn.out,
        output_depth=stereo.depth
    )
    """

    def __init__(self) -> None:
        super().__init__()

        # Single output emitting dai.SpatialImgDetections messages.
        self.output = self.createOutput(
            possibleDatatypes=[
                dai.Node.DatatypeHierarchy(dai.DatatypeEnum.SpatialImgDetections, True)
            ]
        )

        # Fraction by which each bounding box is shrunk per side before depth
        # sampling; 0 disables shrinking. Overwritten in build().
        self.shrinking_factor = 0

    def build(
        self,
        output_2d: dai.Node.Output,
        output_depth: dai.Node.Output,
        calib_data: dai.CalibrationHandler,
        depth_alignment_socket: dai.CameraBoardSocket = dai.CameraBoardSocket.CAM_A,
        shrinking_factor: float = 0,
    ) -> "DepthMerger":
        """Links the 2D-detection and depth outputs to this node.

        @param output_2d: Output producing 2D detection messages.
        @param output_depth: Output producing depth dai.ImgFrame messages.
        @param calib_data: Device calibration used to deproject depth to 3D.
        @param depth_alignment_socket: Camera socket the depth map is aligned to.
        @param shrinking_factor: Fraction of the box width/height to trim from
            each side before sampling depth (reduces background pixels near the
            box edges). 0 keeps the original box.
        @return: self, for call chaining.
        """
        self.link_args(output_2d, output_depth)
        self.shrinking_factor = shrinking_factor
        self.host_spatials_calc = HostSpatialsCalc(calib_data, depth_alignment_socket)
        return self

    def process(self, message_2d: dai.Buffer, depth: dai.ImgFrame) -> None:
        """Converts one 2D detection message + depth frame into spatial
        detections and sends them on self.output."""
        spatial_dets = self._transform(message_2d, depth)
        self.output.send(spatial_dets)

    def _transform(
        self, message_2d: dai.Buffer, depth: dai.ImgFrame
    ) -> Union[dai.SpatialImgDetections, dai.SpatialImgDetection]:
        """Transforms 2D detections into spatial detections based on the depth frame."""
        if isinstance(message_2d, (dai.ImgDetection, ImgDetectionExtended)):
            return self._detection_to_spatial(message_2d, depth)
        if isinstance(message_2d, (dai.ImgDetections, ImgDetectionsExtended)):
            return self._detections_to_spatial(message_2d, depth)
        raise ValueError(f"Unknown message type: {type(message_2d)}")

    def _detection_to_spatial(
        self,
        detection: Union[dai.ImgDetection, ImgDetectionExtended],
        depth: dai.ImgFrame,
    ) -> dai.SpatialImgDetection:
        """Converts a single 2D detection into a spatial detection using the depth
        frame."""
        depth_frame = depth.getCvFrame()
        y_len = depth_frame.shape[0]
        x_len = depth_frame.shape[1]

        # Hoist the outer-rect lookup: the original called getOuterRect() once
        # per coordinate.
        if isinstance(detection, ImgDetectionExtended):
            xmin, ymin, xmax, ymax = detection.rotated_rect.getOuterRect()
        else:
            xmin = detection.xmin
            ymin = detection.ymin
            xmax = detection.xmax
            ymax = detection.ymax

        # BUGFIX: compute width/height ONCE so both sides shrink by the same
        # amount. The previous code updated xmin/ymin first and then reused
        # the shifted values, so the max side shrank less than the min side.
        width = xmax - xmin
        height = ymax - ymin
        xmin += width * self.shrinking_factor
        ymin += height * self.shrinking_factor
        xmax -= width * self.shrinking_factor
        ymax -= height * self.shrinking_factor

        roi = [
            self._get_index(xmin, x_len),
            self._get_index(ymin, y_len),
            self._get_index(xmax, x_len),
            self._get_index(ymax, y_len),
        ]
        spatials = self.host_spatials_calc.calc_spatials(depth, roi)

        spatial_img_detection = dai.SpatialImgDetection()
        spatial_img_detection.xmin = xmin
        spatial_img_detection.ymin = ymin
        spatial_img_detection.xmax = xmax
        spatial_img_detection.ymax = ymax
        spatial_img_detection.spatialCoordinates = dai.Point3f(
            spatials["x"], spatials["y"], spatials["z"]
        )
        spatial_img_detection.confidence = detection.confidence
        # SpatialImgDetection labels must be non-negative: map the "no label"
        # sentinel (-1) to 0.
        spatial_img_detection.label = 0 if detection.label == -1 else detection.label
        return spatial_img_detection

    def _detections_to_spatial(
        self,
        detections: Union[dai.ImgDetections, ImgDetectionsExtended],
        depth: dai.ImgFrame,
    ) -> dai.SpatialImgDetections:
        """Converts multiple 2D detections into spatial detections using the depth
        frame."""
        new_dets = dai.SpatialImgDetections()
        new_dets.detections = [
            self._detection_to_spatial(d, depth) for d in detections.detections
        ]
        # Preserve message metadata so downstream consumers can synchronize.
        new_dets.setSequenceNum(detections.getSequenceNum())
        new_dets.setTimestamp(detections.getTimestamp())
        return new_dets

    def _get_index(self, relative_coord: float, dimension_len: int) -> int:
        """Converts a relative coordinate in [0, 1] to an absolute pixel index,
        clamped to [0, dimension_len - 1]."""
        bounded_coord = min(1, relative_coord)
        return max(0, int(bounded_coord * dimension_len) - 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
# HostSpatialsCalc implementation taken from here: | ||
# https://github.com/luxonis/depthai-experiments/blob/d10736715bef1663d984196f8528610a614e4b75/gen2-calc-spatials-on-host/calc.py | ||
|
||
from typing import Dict, List | ||
|
||
import depthai as dai | ||
import numpy as np | ||
|
||
|
||
class HostSpatialsCalc:
    """HostSpatialsCalc is a helper class for calculating spatial coordinates from depth
    data.

    Attributes
    ----------
    calibData : dai.CalibrationHandler
        Calibration data handler for the device.
    depth_alignment_socket : dai.CameraBoardSocket
        The camera socket used for depth alignment.
    DELTA : int
        The delta value for ROI calculation.
    THRESH_LOW : int
        The lower threshold for depth values.
    THRESH_HIGH : int
        The upper threshold for depth values.

    setLowerThreshold(threshold_low): Sets the lower threshold for depth values.
    setUpperThreshold(threshold_high): Sets the upper threshold for depth values.
    setDeltaRoi(delta): Sets the delta value for ROI calculation.
    _check_input(roi, frame): Checks if the input is ROI or point and converts point to ROI if necessary.
    calc_spatials(depthData, roi, averaging_method): Calculates spatial coordinates from depth data within the specified ROI.
    """

    # We need device object to get calibration data
    def __init__(
        self,
        calib_data: dai.CalibrationHandler,
        depth_alignment_socket: dai.CameraBoardSocket = dai.CameraBoardSocket.CAM_A,
    ):
        self.calibData = calib_data
        self.depth_alignment_socket = depth_alignment_socket

        # Values
        self.DELTA = 5  # Take 10x10 depth pixels around point for depth averaging
        self.THRESH_LOW = 200  # 20cm
        self.THRESH_HIGH = 30000  # 30m

    @staticmethod
    def _coerce_int(value, label: str) -> int:
        """Coerces *value* to int, accepting int or float; raises TypeError
        otherwise. Shared by the three setters below (the original duplicated
        this check three times)."""
        if isinstance(value, int):
            return value
        if isinstance(value, float):
            return int(value)
        raise TypeError(
            "{} has to be an integer or float! Got {}".format(label, type(value))
        )

    def setLowerThreshold(self, threshold_low: int) -> None:
        """Sets the lower threshold for depth values.

        @param threshold_low: The lower threshold for depth values.
        @type threshold_low: int
        """
        self.THRESH_LOW = self._coerce_int(threshold_low, "Threshold")

    def setUpperThreshold(self, threshold_high: int) -> None:
        """Sets the upper threshold for depth values.

        @param threshold_high: The upper threshold for depth values.
        @type threshold_high: int
        """
        self.THRESH_HIGH = self._coerce_int(threshold_high, "Threshold")

    def setDeltaRoi(self, delta: int) -> None:
        """Sets the delta value for ROI calculation.

        @param delta: The delta value for ROI calculation.
        @type delta: int
        """
        self.DELTA = self._coerce_int(delta, "Delta")

    def _check_input(self, roi: List[int], frame: np.ndarray) -> List[int]:
        """Checks if the input is ROI or point and converts point to ROI if necessary.

        @param roi: The region of interest (ROI) or point.
        @type roi: List[int]
        @param frame: The depth frame.
        @type frame: np.ndarray
        @return: The region of interest (ROI).
        @rtype: List[int]
        """
        if len(roi) == 4:
            return roi
        if len(roi) != 2:
            raise ValueError(
                "You have to pass either ROI (4 values) or point (2 values)!"
            )
        # Limit the point so ROI won't be outside the frame
        x = min(max(roi[0], self.DELTA), frame.shape[1] - self.DELTA)
        y = min(max(roi[1], self.DELTA), frame.shape[0] - self.DELTA)
        # Return a list to match the declared return type (was a tuple).
        return [x - self.DELTA, y - self.DELTA, x + self.DELTA, y + self.DELTA]

    # roi has to be list of ints
    def calc_spatials(
        self,
        depthData: dai.ImgFrame,
        roi: List[int],
        averaging_method: callable = np.mean,
    ) -> Dict[str, float]:
        """Calculates spatial coordinates from depth data within the specified ROI.

        @param depthData: The depth data.
        @type depthData: dai.ImgFrame
        @param roi: The region of interest (ROI) or point.
        @type roi: List[int]
        @param averaging_method: The method for averaging the depth values.
        @type averaging_method: callable
        @return: The spatial coordinates.
        @rtype: Dict[str, float]
        """
        depthFrame = depthData.getFrame()

        roi = self._check_input(
            roi, depthFrame
        )  # If point was passed, convert it to ROI
        xmin, ymin, xmax, ymax = roi

        # Calculate the average depth in the ROI, ignoring values outside the
        # [THRESH_LOW, THRESH_HIGH] range.
        depthROI = depthFrame[ymin:ymax, xmin:xmax]
        inRange = (self.THRESH_LOW <= depthROI) & (depthROI <= self.THRESH_HIGH)
        validDepths = depthROI[inRange]

        # Guard the empty case explicitly: np.mean on an empty slice returns
        # nan but also emits a RuntimeWarning. The result (nan coordinates)
        # is unchanged.
        if validDepths.size == 0:
            averageDepth = float("nan")
        else:
            averageDepth = averaging_method(validDepths)

        centroid = np.array(  # Get centroid of the ROI
            [
                int((xmax + xmin) / 2),
                int((ymax + ymin) / 2),
            ]
        )

        # Deproject the centroid through the inverse intrinsics scaled to the
        # depth frame resolution: X = depth * K^-1 * [u, v, 1]^T.
        K = self.calibData.getCameraIntrinsics(
            cameraId=self.depth_alignment_socket,
            resizeWidth=depthFrame.shape[1],
            resizeHeight=depthFrame.shape[0],
        )
        K = np.array(K)
        K_inv = np.linalg.inv(K)
        homogenous_coords = np.array([centroid[0], centroid[1], 1])
        spatial_coords = averageDepth * K_inv.dot(homogenous_coords)

        spatials = {
            "x": spatial_coords[0],
            "y": spatial_coords[1],
            "z": spatial_coords[2],
        }
        return spatials