Skip to content

Commit

Permalink
Fixes Action.*_async futures never complete
Browse files Browse the repository at this point in the history
Per rclpy:1123 If two seperate client server actions are running in seperate executors the future given to the ActionClient will never complete due to a race condition
This fixes  the calls to rcl handles potentially leading to deadlock scenarios by adding locks to there references
Co-authored-by: Aditya Agarwal <[email protected]>
Co-authored-by: Jonathan Blixt <[email protected]>
Signed-off-by: Jonathan Blixt <[email protected]>
  • Loading branch information
jmblixt3 committed Nov 27, 2024
1 parent 78f5e14 commit fc563df
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 62 deletions.
125 changes: 68 additions & 57 deletions rclpy/rclpy/action/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def __init__(
self._node.add_waitable(self)
self._logger = self._node.get_logger().get_child('action_client')

self._lock = threading.Lock()

def _generate_random_uuid(self):
return UUID(uuid=list(uuid.uuid4().bytes))

Expand Down Expand Up @@ -251,39 +253,44 @@ def take_data(self) -> ClientGoalHandleDict:
"""Take stuff from lower level so the wait set doesn't immediately wake again."""
data: ClientGoalHandleDict = {}
if self._is_goal_response_ready:
taken_data = self._client_handle.take_goal_response(
self._action_type.Impl.SendGoalService.Response)
# If take fails, then we get (None, None)
if all(taken_data):
data['goal'] = taken_data
with self._lock:
taken_data = self._client_handle.take_goal_response(
self._action_type.Impl.SendGoalService.Response)
# If take fails, then we get (None, None)
if all(taken_data):
data['goal'] = taken_data

if self._is_cancel_response_ready:
taken_data = self._client_handle.take_cancel_response(
self._action_type.Impl.CancelGoalService.Response)
# If take fails, then we get (None, None)
if all(taken_data):
data['cancel'] = taken_data
with self._lock:
taken_data = self._client_handle.take_cancel_response(
self._action_type.Impl.CancelGoalService.Response)
# If take fails, then we get (None, None)
if all(taken_data):
data['cancel'] = taken_data

if self._is_result_response_ready:
taken_data = self._client_handle.take_result_response(
self._action_type.Impl.GetResultService.Response)
# If take fails, then we get (None, None)
if all(taken_data):
data['result'] = taken_data
with self._lock:
taken_data = self._client_handle.take_result_response(
self._action_type.Impl.GetResultService.Response)
# If take fails, then we get (None, None)
if all(taken_data):
data['result'] = taken_data

if self._is_feedback_ready:
taken_data = self._client_handle.take_feedback(
self._action_type.Impl.FeedbackMessage)
# If take fails, then we get None
if taken_data is not None:
data['feedback'] = taken_data
with self._lock:
taken_data = self._client_handle.take_feedback(
self._action_type.Impl.FeedbackMessage)
# If take fails, then we get None
if taken_data is not None:
data['feedback'] = taken_data

if self._is_status_ready:
taken_data = self._client_handle.take_status(
self._action_type.Impl.GoalStatusMessage)
# If take fails, then we get None
if taken_data is not None:
data['status'] = taken_data
with self._lock:
taken_data = self._client_handle.take_status(
self._action_type.Impl.GoalStatusMessage)
# If take fails, then we get None
if taken_data is not None:
data['status'] = taken_data

return data

Expand Down Expand Up @@ -364,12 +371,14 @@ async def execute(self, taken_data: ClientGoalHandleDict) -> None:

def get_num_entities(self):
"""Return number of each type of entity used in the wait set."""
num_entities = self._client_handle.get_num_entities()
with self._lock:
num_entities = self._client_handle.get_num_entities()
return NumberOfEntities(*num_entities)

def add_to_wait_set(self, wait_set):
"""Add entities to wait set."""
self._client_handle.add_to_waitset(wait_set)
with self._lock:
self._client_handle.add_to_waitset(wait_set)

def __enter__(self):
return self._client_handle.__enter__()
Expand Down Expand Up @@ -447,23 +456,23 @@ def send_goal_async(self, goal, feedback_callback=None, goal_uuid=None):
request = self._action_type.Impl.SendGoalService.Request()
request.goal_id = self._generate_random_uuid() if goal_uuid is None else goal_uuid
request.goal = goal
sequence_number = self._client_handle.send_goal_request(request)
if sequence_number in self._pending_goal_requests:
raise RuntimeError(
'Sequence ({}) conflicts with pending goal request'.format(sequence_number))
future = Future()
with self._lock:
sequence_number = self._client_handle.send_goal_request(request)
if sequence_number in self._pending_goal_requests:
raise RuntimeError(
'Sequence ({}) conflicts with pending goal request'.format(sequence_number))
self._pending_goal_requests[sequence_number] = future
self._goal_sequence_number_to_goal_id[sequence_number] = request.goal_id
future.add_done_callback(self._remove_pending_goal_request)
# Add future so executor is aware
self.add_future(future)

if feedback_callback is not None:
# TODO(jacobperron): Move conversion function to a general-use package
goal_uuid = bytes(request.goal_id.uuid)
self._feedback_callbacks[goal_uuid] = feedback_callback

future = Future()
self._pending_goal_requests[sequence_number] = future
self._goal_sequence_number_to_goal_id[sequence_number] = request.goal_id
future.add_done_callback(self._remove_pending_goal_request)
# Add future so executor is aware
self.add_future(future)

return future

def _cancel_goal(self, goal_handle):
Expand Down Expand Up @@ -505,16 +514,17 @@ def _cancel_goal_async(self, goal_handle):

cancel_request = CancelGoal.Request()
cancel_request.goal_info.goal_id = goal_handle.goal_id
sequence_number = self._client_handle.send_cancel_request(cancel_request)
if sequence_number in self._pending_cancel_requests:
raise RuntimeError(
'Sequence ({}) conflicts with pending cancel request'.format(sequence_number))

future = Future()
self._pending_cancel_requests[sequence_number] = future
future.add_done_callback(self._remove_pending_cancel_request)
# Add future so executor is aware
self.add_future(future)
with self._lock:
sequence_number = self._client_handle.send_cancel_request(cancel_request)
if sequence_number in self._pending_cancel_requests:
raise RuntimeError(
'Sequence ({}) conflicts with pending cancel request'.format(sequence_number))

self._pending_cancel_requests[sequence_number] = future
future.add_done_callback(self._remove_pending_cancel_request)
# Add future so executor is aware
self.add_future(future)

return future

Expand Down Expand Up @@ -557,17 +567,18 @@ def _get_result_async(self, goal_handle):

result_request = self._action_type.Impl.GetResultService.Request()
result_request.goal_id = goal_handle.goal_id
sequence_number = self._client_handle.send_result_request(result_request)
if sequence_number in self._pending_result_requests:
raise RuntimeError(
'Sequence ({}) conflicts with pending result request'.format(sequence_number))

future = Future()
self._pending_result_requests[sequence_number] = future
self._result_sequence_number_to_goal_id[sequence_number] = result_request.goal_id
future.add_done_callback(self._remove_pending_result_request)
# Add future so executor is aware
self.add_future(future)
with self._lock:
sequence_number = self._client_handle.send_result_request(result_request)
if sequence_number in self._pending_result_requests:
raise RuntimeError(
'Sequence ({}) conflicts with pending result request'.format(sequence_number))

self._pending_result_requests[sequence_number] = future
self._result_sequence_number_to_goal_id[sequence_number] = result_request.goal_id
future.add_done_callback(self._remove_pending_result_request)
# Add future so executor is aware
self.add_future(future)

return future

Expand Down
15 changes: 10 additions & 5 deletions rclpy/rclpy/action/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ async def _execute_goal_request(self, request_header_and_message):
try:
# If the client goes away anytime before this, sending the goal response may fail.
# Catch the exception here and go on so we don't crash.
self._handle.send_goal_response(request_header, response_msg)
with self._lock:
self._handle.send_goal_response(request_header, response_msg)
except RCLError:
self._logger.warn('Failed to send goal response (the client may have gone away)')
return
Expand Down Expand Up @@ -399,7 +400,8 @@ async def _execute_cancel_request(self, request_header_and_message):
try:
# If the client goes away anytime before this, sending the goal response may fail.
# Catch the exception here and go on so we don't crash.
self._handle.send_cancel_response(request_header, cancel_response)
with self._lock:
self._handle.send_cancel_response(request_header, cancel_response)
except RCLError:
self._logger.warn('Failed to send cancel response (the client may have gone away)')

Expand All @@ -417,7 +419,8 @@ async def _execute_get_result_request(self, request_header_and_message):
'Sending result response for unknown or expired goal ID: {0}'.format(goal_uuid))
result_response = self._action_type.Impl.GetResultService.Response()
result_response.status = GoalStatus.STATUS_UNKNOWN
self._handle.send_result_response(request_header, result_response)
with self._lock:
self._handle.send_result_response(request_header, result_response)
return

# There is an accepted goal matching the goal ID, register a callback to send the
Expand All @@ -437,7 +440,8 @@ def _send_result_response(self, request_header, future):
try:
# If the client goes away anytime before this, sending the result response may fail.
# Catch the exception here and go on so we don't crash.
self._handle.send_result_response(request_header, future.result())
with self._lock:
self._handle.send_result_response(request_header, future.result())
except RCLError:
self._logger.warn('Failed to send result response (the client may have gone away)')

Expand Down Expand Up @@ -513,7 +517,8 @@ async def execute(self, taken_data: ServerGoalHandleDict) -> None:

def get_num_entities(self):
"""Return number of each type of entity used in the wait set."""
num_entities = self._handle.get_num_entities()
with self._lock:
num_entities = self._handle.get_num_entities()
return NumberOfEntities(
num_entities[0],
num_entities[1],
Expand Down

0 comments on commit fc563df

Please sign in to comment.