Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve multimodel support and parallel fixes #40

Merged
merged 8 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

Release 0.9.0 has been developed to be compatible with `transitions` 0.9.1.

* Update all npm packages to their recent version
* Align minor version to transitions
* Add stub files and `py.typed`
* Extend documentation with styling example
* Updated all npm packages to their recent version
* Aligned minor version to transitions
* Added stub files and `py.typed`
* Extended documentation with styling example
* Added a legend and model names to graphs with multiple models
* Bugfix: Add default machine name to graph when none was passed by markup
* Bugfix #24: Edges with same source not correctly disambiguated for highlighting (thanks @Bilby42)

Expand Down
4 changes: 3 additions & 1 deletion examples/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
}]
}, 'preparing']

transitions = [['go', 'preparing', 'style']]
transitions = [['inc', 'preparing', 'style'],
['inc', 'style', 'preparing']]

machine = NestedWebMachine(states=states, transitions=transitions, initial='preparing',
name="Label Machine",
Expand All @@ -37,5 +38,6 @@
try:
while True:
time.sleep(5)
machine.inc()
except KeyboardInterrupt: # Ctrl + C will shutdown the machine
machine.stop_server()
36 changes: 18 additions & 18 deletions frontend/src/modules/graph.js
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ let defaultStyle = [
'label': 'data(label)',
'text-valign': 'center',
'text-halign': 'center',
'text-wrap': 'wrap',
'border-width': '2',
'border-color': 'black',
'background-color': '#fff',
Expand All @@ -58,7 +59,9 @@ let defaultStyle = [
selector: 'node[parallel]',
style: {
'background-opacity': 0,
'background-image': renderParallel
'background-image': function(ele) {
return renderParallel(ele).svg
}
}
},
{
Expand Down Expand Up @@ -185,22 +188,22 @@ export function initLegend(nodes, style) {
return cyLeg
}

const emptySvg = 'data:image/svg+xml,' + encodeURIComponent(`
<?xml version="1.0" encoding="UTF-8"?><!DOCTYPE svg>
<svg xmlns="http://www.w3.org/2000/svg">
</svg>`)

function renderParallel (ele) {
if (!ele) {
return ele
if (!ele || !ele._private.autoWidth) {
return {svg: emptySvg}
}

const width = ele._private.autoWidth
if (width === undefined) {
return ele
}

const height = ele._private.autoHeight + 20
const pos = ele._private.position
const left = pos.x - width / 2
const top = pos.y - height / 2
let lines = ""
if (ele._private.children.length > 1) {
let lines = ''
if (height > width) {
const cPosX = ele._private.children.map(c => c._private.position.x).sort(function (a, b) { return a - b })
for (let i = 1; i < cPosX.length; ++i) {
Expand All @@ -214,14 +217,11 @@ function renderParallel (ele) {
lines += `<line x1="0" y1="${y}" x2="${width}" y2="${y}" stroke="black" stroke-dasharray="4, 4" />\n`
}
}
var svg = `<?xml version="1.0" encoding="UTF-8"?><!DOCTYPE svg>
<svg xmlns="http://www.w3.org/2000/svg" width="${width}" height="${height}" stroke-width="2">
${lines}
</svg>`
// console.log(svg)
var payload = 'data:image/svg+xml,' + encodeURIComponent(svg)
// console.log(payload)
return payload
}
return ele
const svg = `<?xml version="1.0" encoding="UTF-8"?><!DOCTYPE svg>
<svg xmlns="http://www.w3.org/2000/svg" width="${width}" height="${height}">
${lines}
</svg>`
// console.log(svg)
return {svg: 'data:image/svg+xml,' + encodeURIComponent(svg), width: width, height: height}
}
32 changes: 24 additions & 8 deletions frontend/src/modules/machine.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ export default class WebMachine {
}
this.cy = initGraph(machine.nodes, machine.edges, layout, this.style)
let legendEntries = []
transitionsMarkup.models.forEach(model => {
legendEntries.push({name: model.name, class: model['class-name'].replace(/\W/g, ''), state: model.state})
});
this.updateLegend(legendEntries)
transitionsMarkup.models.forEach(model => {
this.modelClasses[model.name] = model['class-name'].replace(/\W/g, '')
this.modelStates[model.name] = []
this.modelTransitions[model.name] = []
this.selectState(model.name, model.state)
legendEntries.push({name: model.name, class: model['class-name'].replace(/\W/g, ''), state: model.state})
});
this.updateLegend(legendEntries)
}

updateLegend(entries) {
Expand All @@ -41,6 +43,7 @@ export default class WebMachine {
this.cyLegend = initLegend(nodes, this.style)
} else {
document.getElementById('legend').style.display = 'none'
this.cyLegend = undefined
}
}

Expand Down Expand Up @@ -83,20 +86,30 @@ export default class WebMachine {

selectState (modelName, state) {
let escapedName = modelName.replace(/\W/g, '')
// console.log(this.modelStates)
const hasLegend = this.cyLegend !== undefined
this.modelStates[modelName].forEach(node => {
node.removeClass('currentState')
if (!hasLegend) {
node.removeClass('currentState')
}
node.removeClass(escapedName)
node.removeClass(this.modelClasses[modelName])
if (hasLegend) {
node.data('label', node.data('label').replace(`\n${modelName}`, ''))
}
// console.log(node)
})
const states = (Array.isArray(state)) ? state : [state]
this.modelStates[modelName] = states.map(stateName => { return this.cy.getElementById(stateName) })
// console.log(this.modelStates[modelName])
this.modelStates[modelName].forEach(node => {
node.addClass('currentState')
if (!hasLegend) {
node.addClass('currentState')
}
node.addClass(escapedName)
node.addClass(this.modelClasses[modelName])
if (hasLegend) {
node.data('label', node.data('label') + `\n${modelName}`)
}
})
if (this.cyLegend) {
this.cyLegend.nodes(`#${modelName}`).css({content: `${modelName} <${this.modelClasses[modelName]}>\nState: ${states.join(', ')}`})
Expand All @@ -107,9 +120,12 @@ export default class WebMachine {
// console.log(this.modelTransitions)
// console.log(modelName)
// console.log(this.modelTransitions[modelName])
this.modelTransitions[modelName].forEach(edge => {
edge.removeClass('currentTransition')
})
if (this.modelTransitions[modelName]) {
this.modelTransitions[modelName].forEach(edge => {
edge.removeClass('currentTransition')
})
}

// console.log(transition)
const source = this.cy.nodes(`[id="${transition.source}"]`)
let edge = source.connectedEdges(`[trigger="${transition.trigger}"]`)
Expand Down
1 change: 1 addition & 0 deletions requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytest
pytest-cov
pytest-runner
pycodestyle
websocket-client
38 changes: 34 additions & 4 deletions tests/test_web.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
import json
from unittest import TestCase
from transitions_gui import WebMachine
import time
import threading
from websocket import create_connection

_SIMPLE_ARGS = dict(states=['A', 'B', 'C'], initial='A', name='Simple Machine',
ordered_transitions=True, ignore_invalid_triggers=True, auto_transitions=False)

_INIT_DELAY = 0.1


class TestWebMachine(TestCase):

def tearDown(self):
assert self.machine._thread is not None
self.machine.stop_server()
time.sleep(0.1)
time.sleep(_INIT_DELAY)

def test_server(self):
self.machine = WebMachine(**_SIMPLE_ARGS) # type: ignore
time.sleep(0.5)
time.sleep(_INIT_DELAY)

def test_threaded_server(self):

Expand All @@ -27,12 +31,38 @@ def run(self):

mf = MachineFactory()
mf.start()
time.sleep(0.5)
time.sleep(_INIT_DELAY)
self.machine = mf.machine

def test_trigger_event(self):
self.machine = WebMachine(**_SIMPLE_ARGS) # type: ignore
time.sleep(0.5)
time.sleep(_INIT_DELAY)
self.assertEqual('A', self.machine.state)
self.machine.process_message({"method": "trigger", "arg": "next_state"})
self.assertEqual('B', self.machine.state)

def test_client_connection(self):
self.machine = WebMachine(**_SIMPLE_ARGS) # type: ignore
time.sleep(_INIT_DELAY)
ws = create_connection(f"ws://localhost:{self.machine.port}/ws")
answer = json.loads(ws.recv())
assert answer["method"] == "update_machine"
config = answer["arg"]
assert config["initial"] == _SIMPLE_ARGS["initial"]
for state in config["states"]:
assert state["name"] in _SIMPLE_ARGS["states"] # type: ignore
ws.close()

def test_client_transition(self):
self.machine = WebMachine(**_SIMPLE_ARGS) # type: ignore
time.sleep(_INIT_DELAY)
ws = create_connection(f"ws://localhost:{self.machine.port}/ws")
_ = ws.recv()
self.machine.next_state()
answer = json.loads(ws.recv())
assert answer["method"] == "state_changed"
config = answer["arg"]
assert config["transition"]["source"] == _SIMPLE_ARGS["initial"]
assert config["transition"]["trigger"] == "next_state"
assert config["state"] != _SIMPLE_ARGS["initial"]
ws.close()
13 changes: 7 additions & 6 deletions transitions_gui/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tornado.websocket
import json
from collections import defaultdict
import logging

_LOGGER = logging.getLogger(__name__)
Expand All @@ -16,7 +17,7 @@ def get(self):


class WebSocketHandler(tornado.websocket.WebSocketHandler):
sockets = set()
sockets = defaultdict(set)

def __init__(
self,
Expand All @@ -26,16 +27,16 @@ def __init__(
super(WebSocketHandler, self).__init__(*args, **kwargs)

@classmethod
def send_message(cls, message):
for s in cls.sockets:
s.write_message(message, binary=False)
def send_message(cls, message, port):
for socket in cls.sockets[port]:
socket.write_message(message, binary=False)

def initialize(self, machine):
self.machine = machine

def open(self):
_LOGGER.info("WebSocket opened")
self.sockets.add(self)
self.sockets[self.machine.port].add(self)
self.write_message(
{
"method": "update_machine",
Expand All @@ -50,5 +51,5 @@ def on_message(self, message):
self.machine.process_message(message)

def on_close(self):
self.sockets.remove(self)
self.sockets[self.machine.port].remove(self)
_LOGGER.info("WebSocket closed")
5 changes: 2 additions & 3 deletions transitions_gui/handlers/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ class MainHandler(tornado.web.RequestHandler):
def get(self) -> None: ...

class WebSocketHandler(tornado.websocket.WebSocketHandler):
sockets: ClassVar[Set["WebSocketHandler"]]
sockets: ClassVar[Dict[int, Set["WebSocketHandler"]]]
machine: Optional[WebMachine]

@classmethod
def send_message(cls, message: Dict[str, Any]) -> None: ...
def send_message(self, message: Dict[str, Any], port: int) -> None: ...
def initialize(self, machine: WebMachine) -> None: ...
def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]: ...
def on_message(self, message: Union[str, bytes]) -> None: ...
Expand Down
2 changes: 1 addition & 1 deletion transitions_gui/static/js/main.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion transitions_gui/static/js/main.js.map

Large diffs are not rendered by default.

30 changes: 25 additions & 5 deletions transitions_gui/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def _change_state(self, event_data):
{
"method": "state_changed",
"arg": {"model": model_name, "transition": transition, "state": current_state},
}
},
port=event_data.machine.port
)


Expand Down Expand Up @@ -66,7 +67,7 @@ def process_message(self, message):

def start_server(self):
self._iloop = tornado.ioloop.IOLoop.current()
self._http_server = self._application.listen(self._port)
self._http_server = self._application.listen(self.port)
if self._iloop is not None:
self._iloop.start()
_LOGGER.info("Loop stopped")
Expand Down Expand Up @@ -104,7 +105,7 @@ def _init_default_handler(machine, port=8080, daemon=False):

_LOGGER.info("Initializing tornado web application")
machine._application = tornado.web.Application(handlers, **settings)
machine._port = port
machine.port = port
server_thread = threading.Thread(target=machine.start_server)
server_thread.daemon = daemon
machine._thread = server_thread
Expand All @@ -113,8 +114,27 @@ def _init_default_handler(machine, port=8080, daemon=False):
return WebSocketHandler


class NestedWebTransition(WebTransition, NestedTransition):
pass
class NestedWebTransition(NestedTransition):
def _change_state(self, event_data):
super(NestedWebTransition, self)._change_state(event_data)
model_name = (
event_data.model.name
if hasattr(event_data.model, "name")
else str(id(event_data.model))
)
src = event_data.machine.get_global_name(self.source)
dest = event_data.machine.get_global_name(self.dest)
transition = {"source": src, "dest": dest, "trigger": event_data.event.name}
current_state = (
self.dest if hasattr(event_data.model.state, "name") else event_data.model.state
)
event_data.machine.websocket_handler.send_message(
{
"method": "state_changed",
"arg": {"model": model_name, "transition": transition, "state": current_state},
},
port=event_data.machine.port
)


class NestedWebMachine(WebMachine, HierarchicalMachine):
Expand Down