diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 78544c91c..336a2b0f5 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -317,14 +317,24 @@ def spin_until_future_complete( future.add_done_callback(lambda x: self.wake()) if timeout_sec is None or timeout_sec < 0: - while self._context.ok() and not future.done() and not self._is_shutdown: + while ( + self._context.ok() + and not future.done() + and not future.cancelled() + and not self._is_shutdown + ): self.spin_once_until_future_complete(future, timeout_sec) else: start = time.monotonic() end = start + timeout_sec timeout_left = TimeoutObject(timeout_sec) - while self._context.ok() and not future.done() and not self._is_shutdown: + while ( + self._context.ok() + and not future.done() + and not future.cancelled() + and not self._is_shutdown + ): self.spin_once_until_future_complete(future, timeout_left) now = time.monotonic() @@ -577,6 +587,8 @@ def _wait_for_ready_callbacks( with self._tasks_lock: # Get rid of any tasks that are done self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) + # Get rid of any tasks that are cancelled + self._tasks = list(filter(lambda t_e_n: not t_e_n[0].cancelled(), self._tasks)) # Gather entities that can be waited on subscriptions: List[Subscription] = [] diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py index 8e701acea..b602e6366 100644 --- a/rclpy/rclpy/task.py +++ b/rclpy/rclpy/task.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum import inspect import sys import threading @@ -24,14 +25,19 @@ def _fake_weakref(): return None +class FutureState(Enum): + """States defining the lifecycle of a future.""" + + PENDING = 'PENDING' + CANCELLED = 'CANCELLED' + FINISHED = 'FINISHED' + + class Future: """Represent the outcome of a task in the future.""" def __init__(self, *, executor=None): - # true if the task is done or cancelled - self._done = False - # true if the task is cancelled - self._cancelled = False + self._state = FutureState.PENDING # the final return value of the handler self._result = None # An exception raised by the handler when called @@ -53,15 +59,20 @@ def __del__(self): def __await__(self): # Yield if the task is not finished - while not self._done: + while self._pending(): yield return self.result() + def _pending(self) -> bool: + return self._state == FutureState.PENDING + def cancel(self): """Request cancellation of the running task if it is not done already.""" with self._lock: - if not self._done: - self._cancelled = True + if not self._pending(): + return + + self._state = FutureState.CANCELLED self._schedule_or_invoke_done_callbacks() def cancelled(self): @@ -71,7 +82,7 @@ def cancelled(self): :return: True if the task was cancelled :rtype: bool """ - return self._cancelled + return self._state == FutureState.CANCELLED def done(self): """ @@ -80,7 +91,7 @@ def done(self): :return: True if the task is finished or raised while it was executing :rtype: bool """ - return self._done + return self._state == FutureState.FINISHED def result(self): """ @@ -111,8 +122,8 @@ def set_result(self, result): """ with self._lock: self._result = result - self._done = True - self._cancelled = False + self._state = FutureState.FINISHED + self._schedule_or_invoke_done_callbacks() def set_exception(self, exception): @@ -124,8 +135,8 @@ def set_exception(self, exception): with self._lock: self._exception = exception self._exception_fetched = False - self._done = True - self._cancelled = False + self._state = FutureState.FINISHED + self._schedule_or_invoke_done_callbacks() def _schedule_or_invoke_done_callbacks(self): @@ -173,7 +184,7 @@ def add_done_callback(self, callback): """ invoke = False with self._lock: - if self._done: + if not self._pending(): executor = self._executor() if executor is not None: executor.create_task(callback, self) @@ -226,10 +237,14 @@ def __call__(self): The return value of the handler is stored as the task result. """ - if self._done or self._executing or not self._task_lock.acquire(blocking=False): + if ( + not self._pending() or + self._executing or + not self._task_lock.acquire(blocking=False) + ): return try: - if self._done: + if not self._pending(): return self._executing = True @@ -239,7 +254,6 @@ def __call__(self): self._handler.send(None) except StopIteration as e: # The coroutine finished; store the result - self._handler.close() self.set_result(e.value) self._complete_task() except Exception as e: @@ -271,3 +285,9 @@ def executing(self): :rtype: bool """ return self._executing + + def cancel(self) -> None: + if self._pending() and inspect.iscoroutine(self._handler): + self._handler.close() + + super().cancel() diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index 1d3d8d975..0dad384ec 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -255,6 +255,26 @@ async def coroutine(): self.assertTrue(future.done()) self.assertEqual('Sentinel Result', future.result()) + def test_create_task_coroutine_cancel(self) -> None: + self.assertIsNotNone(self.node.handle) + executor = SingleThreadedExecutor(context=self.context) + executor.add_node(self.node) + + async def coroutine(): + return 'Sentinel Result' + + future = executor.create_task(coroutine) + self.assertFalse(future.done()) + self.assertFalse(future.cancelled()) + + future.cancel() + self.assertTrue(future.cancelled()) + + executor.spin_until_future_complete(future) + self.assertFalse(future.done()) + self.assertTrue(future.cancelled()) + self.assertEqual(None, future.result()) + def test_create_task_normal_function(self): self.assertIsNotNone(self.node.handle) executor = SingleThreadedExecutor(context=self.context) diff --git a/rclpy/test/test_task.py b/rclpy/test/test_task.py index f0e92ccf9..d94a74764 100644 --- a/rclpy/test/test_task.py +++ b/rclpy/test/test_task.py @@ -322,6 +322,39 @@ def cb(fut): f.add_done_callback(cb) assert called + def test_set_result_on_done_future_without_exception(self) -> None: + f = Future() + f.set_result(None) + self.assertTrue(f.done()) + self.assertFalse(f.cancelled()) + f.set_result(None) + self.assertTrue(f.done()) + self.assertFalse(f.cancelled()) + + def test_set_result_on_cancelled_future_without_exception(self) -> None: + f = Future() + f.cancel() + self.assertTrue(f.cancelled()) + self.assertFalse(f.done()) + f.set_result(None) + self.assertTrue(f.done()) + + def test_set_exception_on_done_future_without_exception(self) -> None: + f = Future() + f.set_result(None) + self.assertIsNone(f.exception()) + f.set_exception(Exception()) + f.set_result(None) + self.assertIsNotNone(f.exception()) + + def test_set_exception_on_cancelled_future_without_exception(self) -> None: + f = Future() + f.cancel() + self.assertTrue(f.cancelled()) + self.assertIsNone(f.exception()) + f.set_exception(Exception()) + self.assertIsNotNone(f.exception()) + if __name__ == '__main__': unittest.main()