From b163653fbfed1df3ea890f276e6e2103557da369 Mon Sep 17 00:00:00 2001 From: Paul Gesel Date: Wed, 4 Dec 2024 13:42:37 -0700 Subject: [PATCH] Add example SAM 2 behavior and objective (#23) * Add SAM2 behavior and ONNX models --------- Signed-off-by: Paul Gesel Co-authored-by: Griswald Brooks --- .gitattributes | 1 + src/example_behaviors/CMakeLists.txt | 11 +- .../example_behaviors/sam2_segmentation.hpp | 67 +++++++++ src/example_behaviors/models/decoder.onnx | 3 + .../models/sam2_hiera_large_encoder.onnx | 3 + .../src/register_behaviors.cpp | 12 +- .../src/sam2_segmentation.cpp | 142 ++++++++++++++++++ src/lab_sim/objectives/run_sam2_onnx.xml | 41 +++++ 8 files changed, 271 insertions(+), 9 deletions(-) create mode 100644 .gitattributes create mode 100644 src/example_behaviors/include/example_behaviors/sam2_segmentation.hpp create mode 100644 src/example_behaviors/models/decoder.onnx create mode 100644 src/example_behaviors/models/sam2_hiera_large_encoder.onnx create mode 100644 src/example_behaviors/src/sam2_segmentation.cpp create mode 100644 src/lab_sim/objectives/run_sam2_onnx.xml diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..0bb75f73 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.onnx filter=lfs diff=lfs merge=lfs -text diff --git a/src/example_behaviors/CMakeLists.txt b/src/example_behaviors/CMakeLists.txt index 21be9500..98129a10 100644 --- a/src/example_behaviors/CMakeLists.txt +++ b/src/example_behaviors/CMakeLists.txt @@ -15,6 +15,7 @@ example_interfaces) foreach(package IN ITEMS ${THIS_PACKAGE_INCLUDE_DEPENDS}) find_package(${package} REQUIRED) endforeach() +find_package(moveit_pro_ml REQUIRED) add_library( example_behaviors @@ -22,15 +23,16 @@ add_library( src/add_two_ints_service_client.cpp src/convert_mtc_solution_to_joint_trajectory.cpp src/delayed_message.cpp - src/get_string_from_topic.cpp src/fibonacci_action_client.cpp + src/get_string_from_topic.cpp src/hello_world.cpp + src/ndt_registration.cpp src/publish_color_rgba.cpp + src/ransac_registration.cpp + src/sam2_segmentation.cpp src/setup_mtc_pick_from_pose.cpp src/setup_mtc_place_from_pose.cpp src/setup_mtc_wave_hand.cpp - src/ndt_registration.cpp - src/ransac_registration.cpp src/register_behaviors.cpp) target_include_directories( example_behaviors @@ -39,6 +41,7 @@ target_include_directories( PRIVATE ${PCL_INCLUDE_DIRS}) ament_target_dependencies(example_behaviors ${THIS_PACKAGE_INCLUDE_DEPENDS}) +target_link_libraries(example_behaviors onnx_sam2) # Install Libraries install( @@ -50,7 +53,7 @@ install( INCLUDES DESTINATION include) -install(DIRECTORY config DESTINATION share/${PROJECT_NAME}) +install(DIRECTORY config models DESTINATION share/${PROJECT_NAME}) if(BUILD_TESTING) moveit_pro_behavior_test(example_behaviors) diff --git a/src/example_behaviors/include/example_behaviors/sam2_segmentation.hpp b/src/example_behaviors/include/example_behaviors/sam2_segmentation.hpp new file mode 100644 index 00000000..06fed8be --- /dev/null +++ b/src/example_behaviors/include/example_behaviors/sam2_segmentation.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + + +namespace example_behaviors +{ +/** + * @brief Segment an image using the SAM 2 model + */ +class SAM2Segmentation : public moveit_studio::behaviors::AsyncBehaviorBase +{ +public: +/** + * @brief Constructor for the SAM2Segmentation behavior. + * @param name The name of a particular instance of this Behavior. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree. + * @param config This contains runtime configuration info for this Behavior, such as the mapping between the Behavior's data ports on the behavior tree's blackboard. This will be set by the behavior tree factory when this Behavior is created within a new behavior tree. + * @details An important limitation is that the members of the base Behavior class are not instantiated until after the initialize() function is called, so these classes should not be used within the constructor. + */ + SAM2Segmentation(const std::string& name, const BT::NodeConfiguration& config, + const std::shared_ptr& shared_resources); + + /** + * @brief Implementation of the required providedPorts() function for the Behavior. + * @details The BehaviorTree.CPP library requires that Behaviors must implement a static function named providedPorts() which defines their input and output ports. If the Behavior does not use any ports, this function must return an empty BT::PortsList. + * This function returns a list of ports with their names and port info, which is used internally by the behavior tree. + * @return List of ports for the behavior. + */ + static BT::PortsList providedPorts(); + + /** + * @brief Implementation of the metadata() function for displaying metadata, such as Behavior description and + * subcategory, in the MoveIt Studio Developer Tool. + * @return A BT::KeyValueVector containing the Behavior metadata. + */ + static BT::KeyValueVector metadata(); + +protected: + tl::expected doWork() override; + + +private: + std::unique_ptr sam2_; + moveit_pro_ml::ONNXImage onnx_image_; + sensor_msgs::msg::Image mask_image_msg_; + moveit_studio_vision_msgs::msg::Mask2D mask_msg_; + + /** @brief Classes derived from AsyncBehaviorBase must implement getFuture() so that it returns a shared_future class member */ + std::shared_future>& getFuture() override + { + return future_; + } + + /** @brief Classes derived from AsyncBehaviorBase must have this shared_future as a class member */ + std::shared_future> future_; + +}; +} // namespace sam2_segmentation diff --git a/src/example_behaviors/models/decoder.onnx b/src/example_behaviors/models/decoder.onnx new file mode 100644 index 00000000..fc4bc327 --- /dev/null +++ b/src/example_behaviors/models/decoder.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f448cdb479e6ec14e61c4756138eb4081ce7f8a11ca43a0a24856d5e8b61b6f +size 20665365 diff --git a/src/example_behaviors/models/sam2_hiera_large_encoder.onnx b/src/example_behaviors/models/sam2_hiera_large_encoder.onnx new file mode 100644 index 00000000..cace9759 --- /dev/null +++ b/src/example_behaviors/models/sam2_hiera_large_encoder.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c99ab89a38385753aff7ea9155f0808ad5535bc55ea2a49320254e39e4011630 +size 889364590 diff --git a/src/example_behaviors/src/register_behaviors.cpp b/src/example_behaviors/src/register_behaviors.cpp index c3d26677..9755dc04 100644 --- a/src/example_behaviors/src/register_behaviors.cpp +++ b/src/example_behaviors/src/register_behaviors.cpp @@ -2,18 +2,19 @@ #include #include -#include +#include #include #include -#include -#include #include #include +#include +#include #include +#include +#include #include #include -#include -#include +#include #include @@ -35,6 +36,7 @@ class ExampleBehaviorsLoader : public moveit_studio::behaviors::SharedResourcesN moveit_studio::behaviors::registerBehavior(factory, "FibonacciActionClient", shared_resources); moveit_studio::behaviors::registerBehavior(factory, "PublishColorRGBA", shared_resources); + moveit_studio::behaviors::registerBehavior(factory, "SAM2Segmentation", shared_resources); moveit_studio::behaviors::registerBehavior(factory, "SetupMtcPickFromPose", shared_resources); moveit_studio::behaviors::registerBehavior(factory, "SetupMtcPlaceFromPose", shared_resources); diff --git a/src/example_behaviors/src/sam2_segmentation.cpp b/src/example_behaviors/src/sam2_segmentation.cpp new file mode 100644 index 00000000..a7e27f6d --- /dev/null +++ b/src/example_behaviors/src/sam2_segmentation.cpp @@ -0,0 +1,142 @@ +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace +{ + constexpr auto kPortImage = "image"; + constexpr auto kPortImageDefault = "{image}"; + constexpr auto kPortPoint = "pixel_coords"; + constexpr auto kPortPointDefault = "{pixel_coords}"; + constexpr auto kPortMasks = "masks2d"; + constexpr auto kPortMasksDefault = "{masks2d}"; + + constexpr auto kImageInferenceWidth = 1024; + constexpr auto kImageInferenceHeight = 1024; +} // namespace + +namespace example_behaviors +{ + // Convert a ROS image message to the ONNX image format used by the SAM 2 model + void set_onnx_image_from_ros_image(const sensor_msgs::msg::Image& image_msg, + moveit_pro_ml::ONNXImage& onnx_image) + { + onnx_image.shape = {1, image_msg.height, image_msg.width, 3}; + onnx_image.data.resize(image_msg.height * image_msg.width * 3); + const int stride = image_msg.encoding != "rgb8" ? 3: 4; + for (size_t i = 0; i < onnx_image.data.size(); i+=stride) + { + onnx_image.data[i] = static_cast(image_msg.data[i]) / 255.0f; + onnx_image.data[i+1] = static_cast(image_msg.data[i+1]) / 255.0f; + onnx_image.data[i+2] = static_cast(image_msg.data[i+2]) / 255.0f; + } + } + + // Converts a single channel ONNX image mask to a ROS mask message. + void set_ros_mask_from_onnx_mask(const moveit_pro_ml::ONNXImage& onnx_image, sensor_msgs::msg::Image& mask_image_msg, moveit_studio_vision_msgs::msg::Mask2D& mask_msg) + { + mask_image_msg.height = static_cast(onnx_image.shape[0]); + mask_image_msg.width = static_cast(onnx_image.shape[1]); + mask_image_msg.encoding = "mono8"; + mask_image_msg.data.resize(mask_image_msg.height * mask_image_msg.width); + mask_image_msg.step = mask_image_msg.width; + for (size_t i = 0; i < onnx_image.data.size(); ++i) + { + mask_image_msg.data[i] = onnx_image.data[i] > 0.5 ? 255: 0; + } + mask_msg.pixels = mask_image_msg; + mask_msg.x = 0; + mask_msg.y = 0; + } + + SAM2Segmentation::SAM2Segmentation(const std::string& name, const BT::NodeConfiguration& config, + const std::shared_ptr& shared_resources) + : moveit_studio::behaviors::AsyncBehaviorBase(name, config, shared_resources) + { + + const std::filesystem::path package_path = ament_index_cpp::get_package_share_directory("example_behaviors"); + const std::filesystem::path encoder_onnx_file = package_path / "models" / "sam2_hiera_large_encoder.onnx"; + const std::filesystem::path decoder_onnx_file = package_path / "models" / "decoder.onnx"; + sam2_ = std::make_unique(encoder_onnx_file, decoder_onnx_file); + } + + BT::PortsList SAM2Segmentation::providedPorts() + { + return { + BT::InputPort(kPortImage, kPortImageDefault, + "The Image to run segmentation on."), + BT::InputPort>(kPortPoint, kPortPointDefault, + "The input points, as a vector of geometry_msgs/PointStamped messages to be used for segmentation."), + + BT::OutputPort>(kPortMasks, kPortMasksDefault, + "The masks contained in a vector of moveit_studio_vision_msgs::msg::Mask2D messages.") + }; + } + + tl::expected SAM2Segmentation::doWork() + { + const auto ports = moveit_studio::behaviors::getRequiredInputs(getInput(kPortImage), + getInput>(kPortPoint)); + + // Check that all required input data ports were set. + if (!ports.has_value()) + { + auto error_message = fmt::format("Failed to get required values from input data ports:\n{}", ports.error()); + return tl::make_unexpected(error_message); + } + const auto& [image_msg, points_2d] = ports.value(); + + if (image_msg.encoding != "rgb8" && image_msg.encoding != "rgba8") + { + auto error_message = fmt::format("Invalid image message format. Expected `(rgb8, rgba8)` got :\n{}", image_msg.encoding); + return tl::make_unexpected(error_message); + } + + // Create ONNX formatted image tensor from ROS image + set_onnx_image_from_ros_image(image_msg, onnx_image_); + + std::vector point_prompts; + for (auto const& point : points_2d) + { + // Assume all points are the same label + point_prompts.push_back({{kImageInferenceWidth*static_cast(point.point.x), kImageInferenceHeight*static_cast(point.point.y)}, {1.0f}}); + } + + try + { + const auto masks = sam2_->predict(onnx_image_, point_prompts); + + mask_image_msg_.header = image_msg.header; + set_ros_mask_from_onnx_mask(masks, mask_image_msg_, mask_msg_); + + setOutput>(kPortMasks, {mask_msg_}); + } + catch (const std::invalid_argument& e) + { + return tl::make_unexpected(fmt::format("Invalid argument: {}", e.what())); + } + + return true; + } + + BT::KeyValueVector SAM2Segmentation::metadata() + { + return { + { + "description", + "Segments a ROS image message using the provided points represented as a vector of geometry_msgs/PointStamped messages." + } + }; + } +} // namespace sam2_segmentation diff --git a/src/lab_sim/objectives/run_sam2_onnx.xml b/src/lab_sim/objectives/run_sam2_onnx.xml new file mode 100644 index 00000000..2661486b --- /dev/null +++ b/src/lab_sim/objectives/run_sam2_onnx.xml @@ -0,0 +1,41 @@ + + + + + + + + + + + + + + + + + + + + + + + +