Skip to content

Commit

Permalink
introduce feature: final callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
aleneum committed Aug 29, 2022
1 parent 2e2fda9 commit 85137c2
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 14 deletions.
106 changes: 103 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,48 @@ machine = Machine(lump, states=['A', 'B', 'C'])

Now, any time `lump` transitions to state `A`, the `on_enter_A()` method defined in the `Matter` class will fire.

_Experimental in 0.9.0:_
When you use the [Tag feature](#state-features) (described in more detail below) you can make use of `on_final` callbacks.

```python
from transitions import Machine
from transitions.extensions.states import Tags as State

states = [State(name='idling'),
State(name='rescuing_kitten'),
State(name='offender_escaped', tags='final'),
State(name='offender_caught', tags='final')]

transitions = [["called", "idling", "rescuing_kitten"], # we will come when called
{"trigger": "intervene",
"source": "rescuing_kitten",
"dest": "offender_caught", # we will catch the offender
"unless": "offender_is_faster"}, # unless they are faster
["intervene", "rescuing_kitten", "offender_gone"]]


class FinalSuperhero(object):

def __init__(self, speed):
self.machine = Machine(self, states=states, transitions=transitions, initial="idling", on_final="claim_success")
self.speed = speed

def offender_is_faster(self):
self.speed < 15

def claim_success(self):
print("The kitten is safe.")


hero = FinalSuperhero(speed=10) # we are not in shape today
hero.called()
assert hero.is_rescuing_kitten()
hero.intervene()
# >>> 'The kitten is safe'
assert hero.machine.get_state(hero.state).is_final # it's over
assert hero.is_offender_gone() # maybe next time
```

#### <a name="checking-state"></a>Checking state

You can always check the current state of the model by either:
Expand Down Expand Up @@ -1417,7 +1459,6 @@ assert machine.is_C.s2() is False
assert machine.is_C.s2(allow_substates=True) # FunctionWrapper support allow_substate as well
```

_new in 0.8.0_
You can use enumerations in HSMs as well but keep in mind that `Enum` are compared by value.
If you have a value more than once in a state tree those states cannot be distinguished.

Expand All @@ -1429,7 +1470,6 @@ machine.to_B()
machine.is_GREEN() # returns True even though the actual state is B_GREEN
```

_new in 0.8.0_
`HierarchicalMachine` has been rewritten from scratch to support parallel states and better isolation of nested states.
This involves some tweaks based on community feedback.
To get an idea of processing order and configuration have a look at the following example:
Expand Down Expand Up @@ -1483,6 +1523,66 @@ m.to_B_1()
assert m.is_B(allow_substates=True)
```

_Experimental in 0.9.0:_
When you extend your states with a `is_final` property (for instance by using the [Tag feature](#state-features) described below) you can make use of `on_final` callbacks either in states or on the HSM itself. Callbacks will be triggered if a) the state itself is tagged with `final` and has just been entered or b) all substates are considered final and at least one substate just entered a final state. In case of b) all parents will be considered final as well if condition b) holds true for them. This might be useful in cases where processing happens in parallel and your HSM or any parent state should be notified when all substates have reached a final state:


```python
from transitions.extensions import HierarchicalMachine
from transitions.extensions.states import add_state_features, Tags


@add_state_features(Tags)
class FinalHSM(HierarchicalMachine):

def final_event_raised(self, event_data):
# one way to get the currently finalized state is via the scoped attribute of the machine passed
# with 'event_data'. However, this is done here to keep the example short. In most cases dedicated
# final callbacks will probably result in cleaner and more comprehensible code.
print("{} is final!".format(event_data.machine.scoped.name or "Machine"))


# We initialize this parallel HSM in state A:
# / X
# / / yI
# A -> B - Y - yII [final]
# \ Z - zI
# \ zII [final]

states = ['A', {'name': 'B', 'parallel': [{'name': 'X', 'tags': ['final'], 'on_final': 'final_event_raised'},
{'name': 'Y', 'transitions': [['final_Y', 'yI', 'yII']],
'initial': 'yI',
'on_final': 'final_event_raised',
'states':
['yI', {'name': 'yII', 'tags': ['final']}]
},
{'name': 'Z', 'transitions': [['final_Z', 'zI', 'zII']],
'initial': 'zI',
'on_final': 'final_event_raised',
'states':
['zI', {'name': 'zII', 'tags': ['final']}]
},
],
"on_final": 'final_event_raised'}]

machine = FinalHSM(states=states, on_final='final_event_raised', initial='A', send_event=True)
# X will emit a final event right away
machine.to_B()
# >>> X is final!
print(machine.state)
# >>> ['B_X', 'B_Y_yI', 'B_Z_zI']
# Y's substate is final now and will trigger 'on_final' on Y
machine.final_Y()
# >>> Y is final!
print(machine.state)
# >>> ['B_X', 'B_Y_yII', 'B_Z_zI']
# Z's substate becomes final which also makes all children of B final and thus machine itself
machine.final_Z()
# >>> Z is final!
# >>> B is final!
# >>> Machine is final!
```

#### Reuse of previously created HSMs

Besides semantic order, nested states are very handy if you want to specify state machines for specific tasks and plan to reuse them.
Expand Down Expand Up @@ -1874,7 +1974,7 @@ asyncio.run(asyncio.wait([m.to_B(), asyncio.sleep(0.3)]))
assert m.is_C() # now timeout should have been processed
```

You should consider passing `queued=True` to the `TimeoutMachine` constructor. This will make sure that events are processed sequentially and avoid asynchronous [racing conditions](https://github.com/pytransitions/transitions/issues/459) that may appear when timeout and event happen in close proximity.
You should consider passing `queued=True` to the `TimeoutMachine` constructor. This will make sure that events are processed sequentially and avoid asynchronous [racing conditions](https://github.com/pytransitions/transitions/issues/459) that may appear when timeout and event happen in proximity.

#### <a name="django-support"></a> Using transitions together with Django

Expand Down
83 changes: 75 additions & 8 deletions tests/test_states.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from transitions import Machine
from transitions import Machine, MachineError
from transitions.extensions.states import *
from transitions.extensions import MachineFactory
from time import sleep
Expand All @@ -14,10 +14,13 @@

class TestTransitions(TestCase):

def setUp(self):
self.machine_cls = Machine

def test_tags(self):

@add_state_features(Tags)
class CustomMachine(Machine):
class CustomMachine(self.machine_cls):
pass

states = [{"name": "A", "tags": ["initial", "success", "error_state"]}]
Expand All @@ -31,7 +34,7 @@ class CustomMachine(Machine):
def test_error(self):

@add_state_features(Error)
class CustomMachine(Machine):
class CustomMachine(self.machine_cls):
pass

states = ['A', 'B', 'F',
Expand All @@ -53,7 +56,7 @@ class CustomMachine(Machine):

def test_error_callback(self):
@add_state_features(Error)
class CustomMachine(Machine):
class CustomMachine(self.machine_cls):
pass

mock_callback = MagicMock()
Expand All @@ -72,7 +75,7 @@ def test_timeout(self):
mock = MagicMock()

@add_state_features(Timeout)
class CustomMachine(Machine):
class CustomMachine(self.machine_cls):

def timeout(self):
mock()
Expand Down Expand Up @@ -102,7 +105,7 @@ def test_timeout_callbacks(self):
counter = MagicMock()

@add_state_features(Timeout)
class CustomMachine(Machine):
class CustomMachine(self.machine_cls):
pass

class Model(object):
Expand Down Expand Up @@ -143,7 +146,7 @@ def test_timeout_transitioning(self):
timeout_mock = MagicMock()

@add_state_features(Timeout)
class CustomMachine(Machine):
class CustomMachine(self.machine_cls):
pass

states = ['A', {'name': 'B', 'timeout': 0.05, 'on_timeout': ['to_A', timeout_mock]}]
Expand All @@ -164,7 +167,7 @@ def increase(self):
self.value += 1

@add_state_features(Volatile)
class CustomMachine(Machine):
class CustomMachine(self.machine_cls):
pass

states = ['A', {'name': 'B', 'volatile': TemporalState}]
Expand All @@ -188,6 +191,70 @@ class CustomMachine(Machine):
# value should be reset
self.assertEqual(m.scope.value, 5)

def test_final_state(self):
final_mock = MagicMock()

@add_state_features(Tags)
class CustomMachine(self.machine_cls):
pass

machine = CustomMachine(states=['A', {'name': 'B', 'tags': ['final']}], on_final=final_mock, initial='A')
self.assertFalse(final_mock.called)
machine.to_B()
self.assertTrue(final_mock.called)
machine.to_A()
self.assertEqual(1, final_mock.call_count)
machine.to_B()
self.assertEqual(2, final_mock.call_count)


class TestStatesNested(TestTransitions):

def setUp(self):
self.machine_cls = MachineFactory.get_predefined(locked=True, nested=True, graph=True)

def test_final_state_nested(self):
final_mock_B = MagicMock()
final_mock_Y = MagicMock()
final_mock_Z = MagicMock()
final_mock_machine = MagicMock()
mocks = [final_mock_B, final_mock_Y, final_mock_Z, final_mock_machine]

@add_state_features(Tags)
class CustomMachine(self.machine_cls):
pass

states = ['A', {'name': 'B', 'parallel': [{'name': 'X', 'tags': ['final']},
{'name': 'Y', 'transitions': [['final_Y', 'yI', 'yII']],
'initial': 'yI',
'on_final': final_mock_Y,
'states':
['yI', {'name': 'yII', 'tags': ['final']}]
},
{'name': 'Z', 'transitions': [['final_Z', 'zI', 'zII']],
'initial': 'zI',
'on_final': final_mock_Z,
'states':
['zI', {'name': 'zII', 'tags': ['final']}]
},
],
"on_final": final_mock_B}]

machine = CustomMachine(states=states, on_final=final_mock_machine,
initial='A')
self.assertFalse(any(mock.called for mock in mocks))
machine.to_B()
self.assertFalse(any(mock.called for mock in mocks))
machine.final_Y()
self.assertTrue(final_mock_Y.called)
self.assertFalse(final_mock_Z.called or final_mock_B.called or final_mock_machine.called)
machine.final_Z()
self.assertTrue(all(mock.called for mock in mocks))
self.assertEqual(1, final_mock_Y.call_count)
self.assertEqual(1, final_mock_Z.call_count)
self.assertEqual(1, final_mock_B.call_count)
self.assertEqual(1, final_mock_machine.call_count)


class TestStatesDiagramsLockedNested(TestDiagramsLockedNested):

Expand Down
19 changes: 17 additions & 2 deletions transitions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,10 @@ def _change_state(self, event_data):
event_data.machine.get_state(self.source).exit(event_data)
event_data.machine.set_state(self.dest, event_data.model)
event_data.update(getattr(event_data.model, event_data.machine.model_attribute))
event_data.machine.get_state(self.dest).enter(event_data)
dest_state = event_data.machine.get_state(self.dest)
dest_state.enter(event_data)
if getattr(dest_state, 'is_final', False):
event_data.machine.callbacks(event_data.machine.on_final, event_data)

def add_callback(self, trigger, func):
""" Add a new before, after, or prepare callback.
Expand Down Expand Up @@ -508,7 +511,7 @@ def __init__(self, model=self_literal, states=None, initial='initial', transitio
ordered_transitions=False, ignore_invalid_triggers=None,
before_state_change=None, after_state_change=None, name=None,
queued=False, prepare_event=None, finalize_event=None, model_attribute='state', on_exception=None,
**kwargs):
on_final=None, **kwargs):
"""
Args:
model (object or list): The object(s) whose states we want to manage. If set to `Machine.self_literal`
Expand Down Expand Up @@ -573,6 +576,7 @@ def __init__(self, model=self_literal, states=None, initial='initial', transitio
self._prepare_event = []
self._finalize_event = []
self._on_exception = []
self._on_final = []
self._initial = None

self.states = OrderedDict()
Expand All @@ -585,6 +589,7 @@ def __init__(self, model=self_literal, states=None, initial='initial', transitio
self.after_state_change = after_state_change
self.finalize_event = finalize_event
self.on_exception = on_exception
self.on_final = on_final
self.name = name + ": " if name is not None else ""
self.model_attribute = model_attribute

Expand Down Expand Up @@ -740,6 +745,16 @@ def on_exception(self):
def on_exception(self, value):
self._on_exception = listify(value)

@property
def on_final(self):
"""Callbacks will be executed when the reached state is tagged with 'final'"""
return self._on_final

# this should make sure that finalize_event is always a list
@on_final.setter
def on_final(self, value):
self._on_final = listify(value)

def get_state(self, state):
""" Return the State instance with the passed name. """
if isinstance(state, Enum):
Expand Down
Loading

0 comments on commit 85137c2

Please sign in to comment.