Skip to content

Commit

Permalink
fix: improve timeout processing for better (raised) exception handling
Browse files Browse the repository at this point in the history
- see #682 for details
- renamed `AsyncMachine.switch_model_context` to `cancel_running_transitions`
  • Loading branch information
aleneum committed Aug 20, 2024
1 parent 19f1aea commit aa7b8dc
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 9 deletions.
74 changes: 70 additions & 4 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from asyncio import CancelledError

from transitions.extensions.factory import AsyncGraphMachine, HierarchicalAsyncGraphMachine
from transitions.extensions.states import add_state_features

try:
import asyncio
from transitions.extensions.asyncio import AsyncMachine, HierarchicalAsyncMachine, AsyncEventData, \
AsyncTransition
AsyncTransition, AsyncTimeout

except (ImportError, SyntaxError):
asyncio = None # type: ignore
Expand Down Expand Up @@ -342,9 +345,6 @@ async def run():
asyncio.run(run())

def test_async_timeout(self):
from transitions.extensions.states import add_state_features
from transitions.extensions.asyncio import AsyncTimeout

timeout_called = MagicMock()

@add_state_features(AsyncTimeout)
Expand Down Expand Up @@ -376,6 +376,72 @@ async def run():

asyncio.run(run())

def test_timeout_cancel(self):
error_mock = MagicMock()
timout_mock = MagicMock()
long_op_mock = MagicMock()

@add_state_features(AsyncTimeout)
class TimeoutMachine(self.machine_cls): # type: ignore
async def on_enter_B(self):
await asyncio.sleep(0.2)
long_op_mock() # should never be called

async def handle_timeout(self):
timout_mock()
await self.to_A()

machine = TimeoutMachine(states=["A", {"name": "B", "timeout": 0.1, "on_timeout": "handle_timeout"}],
initial="A", on_exception=error_mock)

async def run():
await machine.to_B()
assert timout_mock.called
assert error_mock.call_count == 1 # should only be one CancelledError
assert not long_op_mock.called
assert machine.is_A()
asyncio.run(run())

def test_queued_timeout_cancel(self):
error_mock = MagicMock()
timout_mock = MagicMock()
long_op_mock = MagicMock()

@add_state_features(AsyncTimeout)
class TimeoutMachine(self.machine_cls): # type: ignore
async def long_op(self, event_data):
await self.to_C()
await self.to_D()
await self.to_E()
await asyncio.sleep(1)
long_op_mock()

async def handle_timeout(self, event_data):
timout_mock()
raise TimeoutError()

async def handle_error(self, event_data):
if isinstance(event_data.error, CancelledError):
if error_mock.called:
raise RuntimeError()
error_mock()
raise event_data.error

machine = TimeoutMachine(states=["A", "C", "D", "E",
{"name": "B", "timeout": 0.1, "on_timeout": "handle_timeout",
"on_enter": "long_op"}],
initial="A", queued=True, send_event=True, on_exception="handle_error")

async def run():
await machine.to_B()
assert timout_mock.called
assert error_mock.called
assert not long_op_mock.called
assert machine.is_B()
with self.assertRaises(RuntimeError):
await machine.to_B()
asyncio.run(run())

def test_callback_order(self):
finished = []

Expand Down
32 changes: 28 additions & 4 deletions transitions/extensions/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import asyncio
import contextvars
import inspect
import warnings
from collections import deque
from functools import partial, reduce
import copy
Expand Down Expand Up @@ -116,7 +117,7 @@ async def execute(self, event_data):

machine = event_data.machine
# cancel running tasks since the transition will happen
await machine.switch_model_context(event_data.model)
await machine.cancel_running_transitions(event_data.model)

await event_data.machine.callbacks(event_data.machine.before_state_change, event_data)
await event_data.machine.callbacks(self.before, event_data)
Expand Down Expand Up @@ -189,7 +190,8 @@ async def _trigger(self, event_data):
if self._is_valid_source(event_data.state):
await self._process(event_data)
except BaseException as err: # pylint: disable=broad-except; Exception will be handled elsewhere
_LOGGER.error("%sException was raised while processing the trigger: %s", self.machine.name, err)
_LOGGER.error("%sException was raised while processing the trigger '%s': %s",
self.machine.name, event_data.event.name, repr(err))
event_data.error = err
if self.machine.on_exception:
await self.machine.callbacks(self.machine.on_exception, event_data)
Expand Down Expand Up @@ -374,6 +376,11 @@ async def await_all(callables):
return await asyncio.gather(*[func() for func in callables])

async def switch_model_context(self, model):
warnings.warn("Please replace 'AsyncMachine.switch_model_context' with "
"'AsyncMachine.cancel_running_transitions'.", category=DeprecationWarning)
await self.cancel_running_transitions(model)

async def cancel_running_transitions(self, model):
"""
This method is called by an `AsyncTransition` when all conditional tests have passed
and the transition will happen. This requires already running tasks to be cancelled.
Expand All @@ -399,7 +406,7 @@ async def process_context(self, func, model):
bool: returns the success state of the triggered event
"""
if self.current_context.get() is None:
self.current_context.set(asyncio.current_task())
token = self.current_context.set(asyncio.current_task())
if id(model) in self.async_tasks:
self.async_tasks[id(model)].append(asyncio.current_task())
else:
Expand All @@ -410,6 +417,7 @@ async def process_context(self, func, model):
res = False
finally:
self.async_tasks[id(model)].remove(asyncio.current_task())
self.current_context.reset(token)
if len(self.async_tasks[id(model)]) == 0:
del self.async_tasks[id(model)]
else:
Expand Down Expand Up @@ -682,7 +690,23 @@ async def _timeout():

async def _process_timeout(self, event_data):
_LOGGER.debug("%sTimeout state %s. Processing callbacks...", event_data.machine.name, self.name)
await event_data.machine.callbacks(self.on_timeout, event_data)
event_data = AsyncEventData(event_data.state, AsyncEvent("_timeout", event_data.machine),
event_data.machine, event_data.model, args=tuple(), kwargs={})
try:
await event_data.machine.callbacks(self.on_timeout, event_data)
except BaseException as err:
_LOGGER.warning("%sException raised while processing timeout!",
event_data.machine.name)
event_data.error = err
try:
if event_data.machine.on_exception:
await event_data.machine.callbacks(event_data.machine.on_exception, event_data)
else:
raise
except BaseException as err2:
_LOGGER.error("%sHandling timeout exception '%s' caused another exception: %s. "
"Cancel running transitions...", event_data.machine.name, repr(err), repr(err2))
await event_data.machine.cancel_running_transitions(event_data.model)
_LOGGER.info("%sTimeout state %s processed.", event_data.machine.name, self.name)

@property
Expand Down
3 changes: 2 additions & 1 deletion transitions/extensions/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,12 @@ class AsyncMachine(Machine):
async def callback(self, func: AsyncCallback, event_data: AsyncEventData) -> None: ... # type: ignore[override]
@staticmethod
async def await_all(callables: List[AsyncCallbackFunc]) -> List[Optional[bool]]: ...
async def cancel_running_transitions(self, model: object) -> None: ...
async def switch_model_context(self, model: object) -> None: ...
def get_state(self, state: Union[str, Enum]) -> AsyncState: ...
async def process_context(self, func: Callable[[], Awaitable[None]], model: object) -> bool: ...
def remove_model(self, model: object) -> None: ...
async def _process_async(self, trigger: Callable[[], Awaitable[None]], model: object) -> bool: ...
async def _process_async(self, trigger: Callable[[], Awaitable[None]], model: object, queued: bool) -> bool: ...


class HierarchicalAsyncMachine(HierarchicalMachine, AsyncMachine): # type: ignore
Expand Down

0 comments on commit aa7b8dc

Please sign in to comment.