Skip to content

Commit

Permalink
Mppi goal to critic (backport #4822) (#4853)
Browse files Browse the repository at this point in the history
* Mppi goal to critic (#4822)

* Add goal pose to CriticData

Signed-off-by: redvinaa <[email protected]>

* Pass goal pose directly to withinPositionGoalTolerance

Signed-off-by: redvinaa <[email protected]>

* Fix condition not

Signed-off-by: redvinaa <[email protected]>

* Add goal positions to tests

Signed-off-by: redvinaa <[email protected]>

* Use plan stamp

Signed-off-by: redvinaa <[email protected]>

* Use float instead of auto

Signed-off-by: redvinaa <[email protected]>

* Throw nav2_core exceptions

Signed-off-by: redvinaa <[email protected]>

* Set pose frame id in test

Signed-off-by: redvinaa <[email protected]>

* Fix frame id in test vol 2

Signed-off-by: redvinaa <[email protected]>

---------

Signed-off-by: redvinaa <[email protected]>
(cherry picked from commit d11de56)

# Conflicts:
#	nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp

* Update utils.hpp

Signed-off-by: Steve Macenski <[email protected]>

---------

Signed-off-by: Steve Macenski <[email protected]>
Co-authored-by: Vince Reda <[email protected]>
Co-authored-by: Steve Macenski <[email protected]>
  • Loading branch information
3 people authored Jan 14, 2025
1 parent b705c13 commit a6a4c26
Show file tree
Hide file tree
Showing 22 changed files with 160 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ namespace mppi

/**
* @struct mppi::CriticData
* @brief Data to pass to critics for scoring, including state, trajectories, path, costs, and
* important parameters to share
* @brief Data to pass to critics for scoring, including state, trajectories,
* pruned path, global goal, costs, and important parameters to share
*/
struct CriticData
{
const models::State & state;
const models::Trajectories & trajectories;
const models::Path & path;
const geometry_msgs::msg::Pose & goal;

xt::xtensor<float, 1> & costs;
float & model_dt;
Expand Down
11 changes: 7 additions & 4 deletions nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class Optimizer
geometry_msgs::msg::TwistStamped evalControl(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed, const nav_msgs::msg::Path & plan,
nav2_core::GoalChecker * goal_checker);
const geometry_msgs::msg::Pose & goal, nav2_core::GoalChecker * goal_checker);

/**
* @brief Get the trajectories generated in a cycle for visualization
Expand Down Expand Up @@ -138,7 +138,8 @@ class Optimizer
void prepare(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker);
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal, nav2_core::GoalChecker * goal_checker);

/**
* @brief Obtain the main controller's parameters
Expand Down Expand Up @@ -256,10 +257,12 @@ class Optimizer
std::array<mppi::models::Control, 4> control_history_;
models::Trajectories generated_trajectories_;
models::Path path_;
geometry_msgs::msg::Pose goal_;
xt::xtensor<float, 1> costs_;

CriticData critics_data_ =
{state_, generated_trajectories_, path_, costs_, settings_.model_dt, false, nullptr, nullptr,
CriticData critics_data_ = {
state_, generated_trajectories_, path_, goal_,
costs_, settings_.model_dt, false, nullptr, nullptr,
std::nullopt, std::nullopt}; /// Caution, keep references

rclcpp::Logger logger_{rclcpp::get_logger("MPPIController")};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ class PathHandler
*/
nav_msgs::msg::Path transformPath(const geometry_msgs::msg::PoseStamped & robot_pose);

/**
* @brief Get the global goal pose transformed to the local frame
* @return Transformed goal pose
*/
geometry_msgs::msg::PoseStamped getTransformedGoal();

protected:
/**
* @brief Transform a pose to another frame
Expand Down
27 changes: 9 additions & 18 deletions nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,23 @@ inline models::Path toTensor(const nav_msgs::msg::Path & path)
* @brief Check if the robot pose is within the Goal Checker's tolerances to goal
* @param global_checker Pointer to the goal checker
* @param robot Pose of robot
* @param path Path to retreive goal pose from
* @param goal Goal pose
* @return bool If robot is within goal checker tolerances to the goal
*/
inline bool withinPositionGoalTolerance(
nav2_core::GoalChecker * goal_checker,
const geometry_msgs::msg::Pose & robot,
const models::Path & path)
const geometry_msgs::msg::Pose & goal)
{
const auto goal_idx = path.x.shape(0) - 1;
const auto goal_x = path.x(goal_idx);
const auto goal_y = path.y(goal_idx);

if (goal_checker) {
geometry_msgs::msg::Pose pose_tolerance;
geometry_msgs::msg::Twist velocity_tolerance;
goal_checker->getTolerances(pose_tolerance, velocity_tolerance);

const auto pose_tolerance_sq = pose_tolerance.position.x * pose_tolerance.position.x;

auto dx = robot.position.x - goal_x;
auto dy = robot.position.y - goal_y;
auto dx = robot.position.x - goal.position.x;
auto dy = robot.position.y - goal.position.y;

auto dist_sq = dx * dx + dy * dy;

Expand All @@ -240,25 +236,20 @@ inline bool withinPositionGoalTolerance(
* @brief Check if the robot pose is within tolerance to the goal
* @param pose_tolerance Pose tolerance to use
* @param robot Pose of robot
* @param path Path to retreive goal pose from
* @param goal Goal pose
* @return bool If robot is within tolerance to the goal
*/
inline bool withinPositionGoalTolerance(
float pose_tolerance,
const geometry_msgs::msg::Pose & robot,
const models::Path & path)
const geometry_msgs::msg::Pose & goal)
{
const auto goal_idx = path.x.shape(0) - 1;
const float goal_x = path.x(goal_idx);
const float goal_y = path.y(goal_idx);
const double & dist_sq =
std::pow(goal.position.x - robot.position.x, 2) +
std::pow(goal.position.y - robot.position.y, 2);

const float pose_tolerance_sq = pose_tolerance * pose_tolerance;

const float dx = static_cast<float>(robot.position.x) - goal_x;
const float dy = static_cast<float>(robot.position.y) - goal_y;

float dist_sq = dx * dx + dy * dy;

if (dist_sq < pose_tolerance_sq) {
return true;
}
Expand Down
4 changes: 3 additions & 1 deletion nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(
#endif

std::lock_guard<std::mutex> param_lock(*parameters_handler_->getLock());
geometry_msgs::msg::Pose goal = path_handler_.getTransformedGoal().pose;

nav_msgs::msg::Path transformed_plan = path_handler_.transformPath(robot_pose);

nav2_costmap_2d::Costmap2D * costmap = costmap_ros_->getCostmap();
std::unique_lock<nav2_costmap_2d::Costmap2D::mutex_t> costmap_lock(*(costmap->getMutex()));

geometry_msgs::msg::TwistStamped cmd =
optimizer_.evalControl(robot_pose, robot_speed, transformed_plan, goal_checker);
optimizer_.evalControl(robot_pose, robot_speed, transformed_plan, goal, goal_checker);

#ifdef BENCHMARK_TESTING
auto end = std::chrono::system_clock::now();
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/cost_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ void CostCritic::score(CriticData & data)

// If near the goal, don't apply the preferential term since the goal is near obstacles
bool near_goal = false;
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.path)) {
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.goal)) {
near_goal = true;
}

Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/goal_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ void GoalAngleCritic::initialize()
void GoalAngleCritic::score(CriticData & data)
{
if (!enabled_ || !utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.path))
threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
8 changes: 3 additions & 5 deletions nav2_mppi_controller/src/critics/goal_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ void GoalCritic::initialize()
void GoalCritic::score(CriticData & data)
{
if (!enabled_ || !utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.path))
threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}

const auto goal_idx = data.path.x.shape(0) - 1;

const auto goal_x = data.path.x(goal_idx);
const auto goal_y = data.path.y(goal_idx);
const auto & goal_x = data.goal.position.x;
const auto & goal_y = data.goal.position.y;

const auto traj_x = xt::view(data.trajectories.x, xt::all(), xt::all());
const auto traj_y = xt::view(data.trajectories.y, xt::all(), xt::all());
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/obstacles_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void ObstaclesCritic::score(CriticData & data)

// If near the goal, don't apply the preferential term since the goal is near obstacles
bool near_goal = false;
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.path)) {
if (utils::withinPositionGoalTolerance(near_goal_distance_, data.state.pose.pose, data.goal)) {
near_goal = true;
}

Expand Down
4 changes: 2 additions & 2 deletions nav2_mppi_controller/src/critics/path_align_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ void PathAlignCritic::initialize()
void PathAlignCritic::score(CriticData & data)
{
// Don't apply close to goal, let the goal critics take over
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/path_angle_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void PathAngleCritic::initialize()
void PathAngleCritic::score(CriticData & data)
{
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/path_follow_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void PathFollowCritic::initialize()
void PathFollowCritic::score(CriticData & data)
{
if (!enabled_ || data.path.x.shape(0) < 2 ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
4 changes: 2 additions & 2 deletions nav2_mppi_controller/src/critics/prefer_forward_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ void PreferForwardCritic::initialize()
void PreferForwardCritic::score(CriticData & data)
{
using xt::evaluation_strategy::immediate;
if (!enabled_ ||
utils::withinPositionGoalTolerance(threshold_to_consider_, data.state.pose.pose, data.path))
if (!enabled_ || utils::withinPositionGoalTolerance(
threshold_to_consider_, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
2 changes: 1 addition & 1 deletion nav2_mppi_controller/src/critics/twirling_critic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void TwirlingCritic::score(CriticData & data)
{
using xt::evaluation_strategy::immediate;
if (!enabled_ ||
utils::withinPositionGoalTolerance(data.goal_checker, data.state.pose.pose, data.path))
utils::withinPositionGoalTolerance(data.goal_checker, data.state.pose.pose, data.goal))
{
return;
}
Expand Down
11 changes: 8 additions & 3 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,11 @@ bool Optimizer::isHolonomic() const
geometry_msgs::msg::TwistStamped Optimizer::evalControl(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker)
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal,
nav2_core::GoalChecker * goal_checker)
{
prepare(robot_pose, robot_speed, plan, goal_checker);
prepare(robot_pose, robot_speed, plan, goal, goal_checker);

do {
optimize();
Expand Down Expand Up @@ -201,12 +203,15 @@ bool Optimizer::fallback(bool fail)
void Optimizer::prepare(
const geometry_msgs::msg::PoseStamped & robot_pose,
const geometry_msgs::msg::Twist & robot_speed,
const nav_msgs::msg::Path & plan, nav2_core::GoalChecker * goal_checker)
const nav_msgs::msg::Path & plan,
const geometry_msgs::msg::Pose & goal,
nav2_core::GoalChecker * goal_checker)
{
state_.pose = robot_pose;
state_.speed = robot_speed;
path_ = utils::toTensor(plan);
costs_.fill(0.0f);
goal_ = goal;

critics_data_.fail_flag = false;
critics_data_.goal_checker = goal_checker;
Expand Down
14 changes: 14 additions & 0 deletions nav2_mppi_controller/src/path_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,20 @@ void PathHandler::prunePlan(nav_msgs::msg::Path & plan, const PathIterator end)
plan.poses.erase(plan.poses.begin(), end);
}

geometry_msgs::msg::PoseStamped PathHandler::getTransformedGoal()
{
auto goal = global_plan_.poses.back();
goal.header = global_plan_.header;
if (goal.header.frame_id.empty()) {
throw nav2_core::ControllerTFError("Goal pose has an empty frame_id");
}
geometry_msgs::msg::PoseStamped transformed_goal;
if (!transformPose(costmap_->getGlobalFrameID(), goal, transformed_goal)) {
throw nav2_core::ControllerTFError("Unable to transform goal pose into costmap frame");
}
return transformed_goal;
}

bool PathHandler::isWithinInversionTolerances(const geometry_msgs::msg::PoseStamped & robot_pose)
{
// Keep full path if we are within tolerance of the inversion pose
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ TEST(ControllerStateTransitionTest, ControllerNotFail)
auto pose = getDummyPointStamped(node, start_pose);
auto velocity = getDummyTwist();
auto path = getIncrementalDummyPath(node, path_settings);
path.header.frame_id = costmap_ros->getGlobalFrameID();
pose.header.frame_id = costmap_ros->getGlobalFrameID();

controller->setPlan(path);

Expand Down
3 changes: 2 additions & 1 deletion nav2_mppi_controller/test/critic_manager_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,11 @@ TEST(CriticManagerTests, BasicCriticOperations)
models::ControlSequence control_sequence;
models::Trajectories generated_trajectories;
models::Path path;
geometry_msgs::msg::Pose goal;
xt::xtensor<float, 1> costs;
float model_dt = 0.1;
CriticData data =
{state, generated_trajectories, path, costs, model_dt, false, nullptr, nullptr,
{state, generated_trajectories, path, goal, costs, model_dt, false, nullptr, nullptr,
std::nullopt, std::nullopt};

data.fail_flag = true;
Expand Down
Loading

0 comments on commit a6a4c26

Please sign in to comment.