Skip to content

Commit

Permalink
Add DepthMerger host node.
Browse files Browse the repository at this point in the history
  • Loading branch information
kkeroo committed Jan 24, 2025
1 parent b013402 commit b04f1bd
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 0 deletions.
Empty file.
144 changes: 144 additions & 0 deletions depthai_nodes/ml/helpers/depth/depth_merger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from typing import Union

import depthai as dai

from depthai_nodes.ml.messages import ImgDetectionExtended, ImgDetectionsExtended

from .host_spatials_calc import HostSpatialsCalc


class DepthMerger(dai.node.HostNode):
"""DepthMerger is a custom host node for merging 2D detections with depth
information to produce spatial detections.
Attributes
----------
output : dai.Node.Output
The output of the DepthMerger node containing dai.SpatialImgDetections.
Usage
-----
depth_merger = pipeline.create(DepthMerger).build(
output_2d=nn.out,
output_depth=stereo.depth
)
"""

def __init__(self) -> None:
super().__init__()

self.output = self.createOutput(
possibleDatatypes=[
dai.Node.DatatypeHierarchy(dai.DatatypeEnum.SpatialImgDetections, True)
]
)

self.shrinking_factor = 0

def build(
self,
output_2d: dai.Node.Output,
output_depth: dai.Node.Output,
calib_data: dai.CalibrationHandler,
depth_alignment_socket: dai.CameraBoardSocket = dai.CameraBoardSocket.CAM_A,
shrinking_factor: float = 0,
) -> "DepthMerger":
self.link_args(output_2d, output_depth)
self.shrinking_factor = shrinking_factor
self.host_spatials_calc = HostSpatialsCalc(calib_data, depth_alignment_socket)
return self

def process(self, message_2d: dai.Buffer, depth: dai.ImgFrame) -> None:
spatial_dets = self._transform(message_2d, depth)
self.output.send(spatial_dets)

def _transform(
self, message_2d: dai.Buffer, depth: dai.ImgFrame
) -> Union[dai.SpatialImgDetections, dai.SpatialImgDetection]:
"""Transforms 2D detections into spatial detections based on the depth frame."""
if isinstance(message_2d, dai.ImgDetection):
return self._detection_to_spatial(message_2d, depth)
elif isinstance(message_2d, dai.ImgDetections):
return self._detections_to_spatial(message_2d, depth)
elif isinstance(message_2d, ImgDetectionExtended):
return self._detection_to_spatial(message_2d, depth)
elif isinstance(message_2d, ImgDetectionsExtended):
return self._detections_to_spatial(message_2d, depth)
else:
raise ValueError(f"Unknown message type: {type(message_2d)}")

def _detection_to_spatial(
self,
detection: Union[dai.ImgDetection, ImgDetectionExtended],
depth: dai.ImgFrame,
) -> dai.SpatialImgDetection:
"""Converts a single 2D detection into a spatial detection using the depth
frame."""
depth_frame = depth.getCvFrame()
x_len = depth_frame.shape[1]
y_len = depth_frame.shape[0]
xmin = (
detection.rotated_rect.getOuterRect()[0]
if isinstance(detection, ImgDetectionExtended)
else detection.xmin
)
ymin = (
detection.rotated_rect.getOuterRect()[1]
if isinstance(detection, ImgDetectionExtended)
else detection.ymin
)
xmax = (
detection.rotated_rect.getOuterRect()[2]
if isinstance(detection, ImgDetectionExtended)
else detection.xmax
)
ymax = (
detection.rotated_rect.getOuterRect()[3]
if isinstance(detection, ImgDetectionExtended)
else detection.ymax
)
xmin += (xmax - xmin) * self.shrinking_factor
ymin += (ymax - ymin) * self.shrinking_factor
xmax -= (xmax - xmin) * self.shrinking_factor
ymax -= (ymax - ymin) * self.shrinking_factor
roi = [
self._get_index(xmin, x_len),
self._get_index(ymin, y_len),
self._get_index(xmax, x_len),
self._get_index(ymax, y_len),
]
spatials = self.host_spatials_calc.calc_spatials(depth, roi)

spatial_img_detection = dai.SpatialImgDetection()
spatial_img_detection.xmin = xmin
spatial_img_detection.ymin = ymin
spatial_img_detection.xmax = xmax
spatial_img_detection.ymax = ymax
spatial_img_detection.spatialCoordinates = dai.Point3f(
spatials["x"], spatials["y"], spatials["z"]
)

spatial_img_detection.confidence = detection.confidence
spatial_img_detection.label = 0 if detection.label == -1 else detection.label
return spatial_img_detection

def _detections_to_spatial(
self,
detections: Union[dai.ImgDetections, ImgDetectionsExtended],
depth: dai.ImgFrame,
) -> dai.SpatialImgDetections:
"""Converts multiple 2D detections into spatial detections using the depth
frame."""
new_dets = dai.SpatialImgDetections()
new_dets.detections = [
self._detection_to_spatial(d, depth) for d in detections.detections
]
new_dets.setSequenceNum(detections.getSequenceNum())
new_dets.setTimestamp(detections.getTimestamp())
return new_dets

def _get_index(self, relative_coord: float, dimension_len: int) -> int:
"""Converts a relative coordinate to an absolute index within the given
dimension length."""
bounded_coord = min(1, relative_coord)
return max(0, int(bounded_coord * dimension_len) - 1)
171 changes: 171 additions & 0 deletions depthai_nodes/ml/helpers/depth/host_spatials_calc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# HostSpatialsCalc implementation taken from here:
# https://github.com/luxonis/depthai-experiments/blob/d10736715bef1663d984196f8528610a614e4b75/gen2-calc-spatials-on-host/calc.py

from typing import Dict, List

import depthai as dai
import numpy as np


class HostSpatialsCalc:
"""HostSpatialsCalc is a helper class for calculating spatial coordinates from depth
data.
Attributes
----------
calibData : dai.CalibrationHandler
Calibration data handler for the device.
depth_alignment_socket : dai.CameraBoardSocket
The camera socket used for depth alignment.
DELTA : int
The delta value for ROI calculation.
THRESH_LOW : int
The lower threshold for depth values.
THRESH_HIGH : int
The upper threshold for depth values.
setLowerThreshold(threshold_low): Sets the lower threshold for depth values.
setUpperThreshold(threshold_high): Sets the upper threshold for depth values.
setDeltaRoi(delta): Sets the delta value for ROI calculation.
_check_input(roi, frame): Checks if the input is ROI or point and converts point to ROI if necessary.
calc_spatials(depthData, roi, averaging_method): Calculates spatial coordinates from depth data within the specified ROI.
"""

# We need device object to get calibration data
def __init__(
self,
calib_data: dai.CalibrationHandler,
depth_alignment_socket: dai.CameraBoardSocket = dai.CameraBoardSocket.CAM_A,
):
self.calibData = calib_data
self.depth_alignment_socket = depth_alignment_socket

# Values
self.DELTA = 5 # Take 10x10 depth pixels around point for depth averaging
self.THRESH_LOW = 200 # 20cm
self.THRESH_HIGH = 30000 # 30m

def setLowerThreshold(self, threshold_low: int) -> None:
"""Sets the lower threshold for depth values.
@param threshold_low: The lower threshold for depth values.
@type threshold_low: int
"""
if not isinstance(threshold_low, int):
if isinstance(threshold_low, float):
threshold_low = int(threshold_low)
else:
raise TypeError(
"Threshold has to be an integer or float! Got {}".format(
type(threshold_low)
)
)
self.THRESH_LOW = threshold_low

def setUpperThreshold(self, threshold_high: int) -> None:
"""Sets the upper threshold for depth values.
@param threshold_high: The upper threshold for depth values.
@type threshold_high: int
"""
if not isinstance(threshold_high, int):
if isinstance(threshold_high, float):
threshold_high = int(threshold_high)
else:
raise TypeError(
"Threshold has to be an integer or float! Got {}".format(
type(threshold_high)
)
)
self.THRESH_HIGH = threshold_high

def setDeltaRoi(self, delta: int) -> None:
"""Sets the delta value for ROI calculation.
@param delta: The delta value for ROI calculation.
@type delta: int
"""
if not isinstance(delta, int):
if isinstance(delta, float):
delta = int(delta)
else:
raise TypeError(
"Delta has to be an integer or float! Got {}".format(type(delta))
)
self.DELTA = delta

def _check_input(self, roi: List[int], frame: np.ndarray) -> List[int]:
"""Checks if the input is ROI or point and converts point to ROI if necessary.
@param roi: The region of interest (ROI) or point.
@type roi: List[int]
@param frame: The depth frame.
@type frame: np.ndarray
@return: The region of interest (ROI).
@rtype: List[int]
"""
if len(roi) == 4:
return roi
if len(roi) != 2:
raise ValueError(
"You have to pass either ROI (4 values) or point (2 values)!"
)
# Limit the point so ROI won't be outside the frame
x = min(max(roi[0], self.DELTA), frame.shape[1] - self.DELTA)
y = min(max(roi[1], self.DELTA), frame.shape[0] - self.DELTA)
return (x - self.DELTA, y - self.DELTA, x + self.DELTA, y + self.DELTA)

# roi has to be list of ints
def calc_spatials(
self,
depthData: dai.ImgFrame,
roi: List[int],
averaging_method: callable = np.mean,
) -> Dict[str, float]:
"""Calculates spatial coordinates from depth data within the specified ROI.
@param depthData: The depth data.
@type depthData: dai.ImgFrame
@param roi: The region of interest (ROI) or point.
@type roi: List[int]
@param averaging_method: The method for averaging the depth values.
@type averaging_method: callable
@return: The spatial coordinates.
@rtype: Dict[str, float]
"""
depthFrame = depthData.getFrame()

roi = self._check_input(
roi, depthFrame
) # If point was passed, convert it to ROI
xmin, ymin, xmax, ymax = roi

# Calculate the average depth in the ROI.
depthROI = depthFrame[ymin:ymax, xmin:xmax]
inRange = (self.THRESH_LOW <= depthROI) & (depthROI <= self.THRESH_HIGH)

averageDepth = averaging_method(depthROI[inRange])

centroid = np.array( # Get centroid of the ROI
[
int((xmax + xmin) / 2),
int((ymax + ymin) / 2),
]
)

K = self.calibData.getCameraIntrinsics(
cameraId=self.depth_alignment_socket,
resizeWidth=depthFrame.shape[1],
resizeHeight=depthFrame.shape[0],
)
K = np.array(K)
K_inv = np.linalg.inv(K)
homogenous_coords = np.array([centroid[0], centroid[1], 1])
spatial_coords = averageDepth * K_inv.dot(homogenous_coords)

spatials = {
"x": spatial_coords[0],
"y": spatial_coords[1],
"z": spatial_coords[2],
}
return spatials

0 comments on commit b04f1bd

Please sign in to comment.