Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Humble enjoy mppi critic pub and critics goal distance fix #41

Open
wants to merge 13 commits into
base: humble
Choose a base branch
from
14 changes: 14 additions & 0 deletions nav2_mppi_controller/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(XTENSOR_USE_XSIMD 1)
find_package(ament_cmake REQUIRED)
find_package(xtensor REQUIRED)
find_package(xsimd REQUIRED)
find_package(rosidl_default_generators REQUIRED)

include_directories(
include
Expand All @@ -33,6 +34,7 @@ set(dependencies_pkgs
tf2_geometry_msgs
tf2_eigen
tf2_ros
std_msgs
)

foreach(pkg IN LISTS dependencies_pkgs)
Expand All @@ -41,6 +43,12 @@ endforeach()

nav2_package()

rosidl_generate_interfaces(${PROJECT_NAME}
"msg/CriticScore.msg"
"msg/CriticScores.msg"
DEPENDENCIES std_msgs
)

include(CheckCXXCompilerFlag)

check_cxx_compiler_flag("-mno-avx512f" COMPILER_SUPPORTS_AVX512)
Expand Down Expand Up @@ -120,8 +128,14 @@ if(BUILD_TESTING)
# add_subdirectory(benchmark)
endif()

rosidl_get_typesupport_target(cpp_typesupport_target
${PROJECT_NAME} rosidl_typesupport_cpp)

target_link_libraries(mppi_controller "${cpp_typesupport_target}")

ament_export_libraries(${libraries})
ament_export_dependencies(${dependencies_pkgs})
ament_export_dependencies(rosidl_default_runtime)
ament_export_include_directories(include)
pluginlib_export_plugin_description_file(nav2_core mppic.xml)
pluginlib_export_plugin_description_file(nav2_mppi_controller critics.xml)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "nav2_mppi_controller/tools/trajectory_visualizer.hpp"
#include "nav2_mppi_controller/models/constraints.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"
#include "nav2_mppi_controller/msg/critic_score.hpp"
#include "nav2_mppi_controller/msg/critic_scores.hpp"

#include "nav2_core/controller.hpp"
#include "nav2_core/goal_checker.hpp"
Expand Down Expand Up @@ -121,10 +123,14 @@ class MPPIController : public nav2_core::Controller
TrajectoryVisualizer trajectory_visualizer_;

bool visualize_;
bool publish_critics_;

double reset_period_;
// Last time computeVelocityCommands was called
rclcpp::Time last_time_called_;

std::shared_ptr<rclcpp_lifecycle::LifecyclePublisher<nav2_mppi_controller::msg::CriticScores>>
critics_publisher_;
};

} // namespace nav2_mppi_controller
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,15 @@ namespace mppi

/**
* @struct mppi::CriticData
* @brief Data to pass to critics for scoring, including state, trajectories, path, costs, and
* important parameters to share
* @brief Data to pass to critics for scoring, including state, trajectories,
* pruned path, global goal, costs, and important parameters to share
*/
struct CriticData
{
const models::State & state;
const models::Trajectories & trajectories;
const models::Path & path;
const geometry_msgs::msg::Pose & goal;

xt::xtensor<float, 1> & costs;
float & model_dt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ class CriticManager
* @brief Constructor for mppi::CriticManager
*/
CriticManager() = default;

/**
* @brief Virtual Destructor for mppi::CriticManager
*/
virtual ~CriticManager() = default;

/**
* @brief Configure critic manager on bringup and load plugins
* @param parent WeakPtr to node
Expand All @@ -69,6 +69,10 @@ class CriticManager
*/
void evalTrajectoriesScores(CriticData & data) const;

xt::xtensor<float, 1> evalTrajectory(CriticData & data) const;

std::vector<std::string> getCriticNames() const;

protected:
/**
* @brief Get parameters (critics to load)
Expand Down
20 changes: 16 additions & 4 deletions nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class Optimizer
geometry_msgs::msg::TwistStamped evalControl(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed, const nav_msgs::msg::Path & plan,
nav2_core::GoalChecker * goal_checker);
const geometry_msgs::msg::Pose & goal, nav2_core::GoalChecker * goal_checker);

/**
* @brief Get the trajectories generated in a cycle for visualization
Expand All @@ -104,6 +104,15 @@ class Optimizer
*/
xt::xtensor<float, 2> getOptimizedTrajectory();


/**
* @brief Get the critic costs for given trajectory
* @return Names and costs of the critics
*/
xt::xtensor<float, 1> getOptimizationResults();

std::vector<std::string> getCriticNames() const;

/**
* @brief Set the maximum speed based on the speed limits callback
* @param speed_limit Limit of the speed for use
Expand Down Expand Up @@ -132,7 +141,8 @@ class Optimizer
void prepare(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker);
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal, nav2_core::GoalChecker * goal_checker);

/**
* @brief Obtain the main controller's parameters
Expand Down Expand Up @@ -250,10 +260,12 @@ class Optimizer
std::array<mppi::models::Control, 4> control_history_;
models::Trajectories generated_trajectories_;
models::Path path_;
geometry_msgs::msg::Pose goal_;
xt::xtensor<float, 1> costs_;

CriticData critics_data_ =
{state_, generated_trajectories_, path_, costs_, settings_.model_dt, false, nullptr, nullptr,
CriticData critics_data_ = {
state_, generated_trajectories_, path_, goal_,
costs_, settings_.model_dt, false, nullptr, nullptr,
std::nullopt, std::nullopt}; /// Caution, keep references

rclcpp::Logger logger_{rclcpp::get_logger("MPPIController")};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ class PathHandler
*/
nav_msgs::msg::Path transformPath(const geometry_msgs::msg::PoseStamped & robot_pose);

/**
* @brief Get the global goal pose transformed to the local frame
* @return Transformed goal pose
*/
geometry_msgs::msg::PoseStamped getTransformedGoal();

protected:
/**
* @brief Transform a pose to another frame
Expand Down
29 changes: 10 additions & 19 deletions nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,27 +195,23 @@ inline models::Path toTensor(const nav_msgs::msg::Path & path)
* @brief Check if the robot pose is within the Goal Checker's tolerances to goal
* @param global_checker Pointer to the goal checker
* @param robot Pose of robot
* @param path Path to retreive goal pose from
* @param goal Goal pose
* @return bool If robot is within goal checker tolerances to the goal
*/
inline bool withinPositionGoalTolerance(
nav2_core::GoalChecker * goal_checker,
const geometry_msgs::msg::Pose & robot,
const models::Path & path)
const geometry_msgs::msg::Pose & goal)
{
const auto goal_idx = path.x.shape(0) - 1;
const auto goal_x = path.x(goal_idx);
const auto goal_y = path.y(goal_idx);

if (goal_checker) {
geometry_msgs::msg::Pose pose_tolerance;
geometry_msgs::msg::Twist velocity_tolerance;
goal_checker->getTolerances(pose_tolerance, velocity_tolerance);

const auto pose_tolerance_sq = pose_tolerance.position.x * pose_tolerance.position.x;

auto dx = robot.position.x - goal_x;
auto dy = robot.position.y - goal_y;
auto dx = robot.position.x - goal.position.x;
auto dy = robot.position.y - goal.position.y;

auto dist_sq = dx * dx + dy * dy;

Expand All @@ -231,24 +227,19 @@ inline bool withinPositionGoalTolerance(
* @brief Check if the robot pose is within tolerance to the goal
* @param pose_tolerance Pose tolerance to use
* @param robot Pose of robot
* @param path Path to retreive goal pose from
* @param goal Goal pose
* @return bool If robot is within tolerance to the goal
*/
inline bool withinPositionGoalTolerance(
float pose_tolerance,
const geometry_msgs::msg::Pose & robot,
const models::Path & path)
const geometry_msgs::msg::Pose & goal)
{
const auto goal_idx = path.x.shape(0) - 1;
const auto goal_x = path.x(goal_idx);
const auto goal_y = path.y(goal_idx);

const auto pose_tolerance_sq = pose_tolerance * pose_tolerance;

auto dx = robot.position.x - goal_x;
auto dy = robot.position.y - goal_y;
const double & dist_sq =
std::pow(goal.position.x - robot.position.x, 2) +
std::pow(goal.position.y - robot.position.y, 2);

auto dist_sq = dx * dx + dy * dy;
const float pose_tolerance_sq = pose_tolerance * pose_tolerance;

if (dist_sq < pose_tolerance_sq) {
return true;
Expand Down
2 changes: 2 additions & 0 deletions nav2_mppi_controller/msg/CriticScore.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
std_msgs/String name
std_msgs/Float32 score
2 changes: 2 additions & 0 deletions nav2_mppi_controller/msg/CriticScores.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
std_msgs/Header header # when msg was sent
CriticScore[] critic_scores
5 changes: 5 additions & 0 deletions nav2_mppi_controller/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_ros</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>

<exec_depend>rosidl_default_runtime</exec_depend>

<depend>rclcpp</depend>
<depend>nav2_common</depend>
Expand All @@ -33,6 +36,8 @@
<test_depend>ament_lint_auto</test_depend>
<test_depend>ament_lint_common</test_depend>
<test_depend>ament_cmake_gtest</test_depend>

<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
<nav2_core plugin="${prefix}/mppic.xml" />
Expand Down
46 changes: 45 additions & 1 deletion nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void MPPIController::configure(
// Get high-level controller parameters
auto getParam = parameters_handler_->getParamGetter(name_);
getParam(visualize_, "visualize", false);
getParam(publish_critics_, "publish_critics", false);
getParam(reset_period_, "reset_period", 1.0);

// Configure composed objects
Expand All @@ -48,6 +49,11 @@ void MPPIController::configure(
parent_, name_,
costmap_ros_->getGlobalFrameID(), parameters_handler_.get());

if (publish_critics_) {
critics_publisher_ = node->create_publisher<nav2_mppi_controller::msg::CriticScores>(
"/mppi_critic_scores", 1);
}

RCLCPP_INFO(logger_, "Configured MPPI Controller: %s", name_.c_str());
}

Expand All @@ -61,13 +67,19 @@ void MPPIController::cleanup()

void MPPIController::activate()
{
if (publish_critics_) {
critics_publisher_->on_activate();
}
trajectory_visualizer_.on_activate();
parameters_handler_->start();
RCLCPP_INFO(logger_, "Activated MPPI Controller: %s", name_.c_str());
}

void MPPIController::deactivate()
{
if (publish_critics_) {
critics_publisher_->on_deactivate();
}
trajectory_visualizer_.on_deactivate();
RCLCPP_INFO(logger_, "Deactivated MPPI Controller: %s", name_.c_str());
}
Expand All @@ -92,13 +104,15 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(
last_time_called_ = clock_->now();

std::lock_guard<std::mutex> param_lock(*parameters_handler_->getLock());
geometry_msgs::msg::Pose goal = path_handler_.getTransformedGoal().pose;

nav_msgs::msg::Path transformed_plan = path_handler_.transformPath(robot_pose);

nav2_costmap_2d::Costmap2D * costmap = costmap_ros_->getCostmap();
std::unique_lock<nav2_costmap_2d::Costmap2D::mutex_t> costmap_lock(*(costmap->getMutex()));

geometry_msgs::msg::TwistStamped cmd =
optimizer_.evalControl(robot_pose, robot_speed, transformed_plan, goal_checker);
optimizer_.evalControl(robot_pose, robot_speed, transformed_plan, goal, goal_checker);

#ifdef BENCHMARK_TESTING
auto end = std::chrono::system_clock::now();
Expand All @@ -110,6 +124,36 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(
visualize(std::move(transformed_plan));
}

if (publish_critics_) {
std::vector<std::string> critic_names = optimizer_.getCriticNames();
xt::xtensor<float, 1> critic_costs = optimizer_.getOptimizationResults();

// log critic names and costs
for (size_t i = 0; i < critic_names.size(); i++) {
RCLCPP_DEBUG(logger_, "Critic: %s, Cost: %f", critic_names[i].c_str(), critic_costs[i]);
}

// make msg
auto critic_scores_ = std::make_unique<nav2_mppi_controller::msg::CriticScores>();
if (critic_names.size() != critic_costs.size()) {
RCLCPP_ERROR(
logger_,
"Critic names %ld and costs %ld size mismatch!",
critic_names.size(), critic_costs.size());
return cmd;
}

for (size_t i = 0; i < critic_names.size(); i++) {
nav2_mppi_controller::msg::CriticScore critic_score;
critic_score.name.data = critic_names[i];
critic_score.score.data = critic_costs[i];
critic_scores_->critic_scores.push_back(critic_score);
}

critic_scores_->header.stamp = clock_->now();
critics_publisher_->publish(std::move(critic_scores_));
}

return cmd;
}

Expand Down
Loading