Skip to content

Commit

Permalink
#81: removed Node.SYMBOL and delegated it to DotConverter
Browse files Browse the repository at this point in the history
  • Loading branch information
marco-biasion committed Sep 12, 2024
1 parent 454f79d commit 281f910
Showing 1 changed file with 149 additions and 33 deletions.
182 changes: 149 additions & 33 deletions sxpat/newGraph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations
from typing import ClassVar, Dict, FrozenSet, Iterable, List, Mapping, NoReturn, Tuple, Union
from typing import Callable, Dict, FrozenSet, Iterable, List, NoReturn, Tuple, Type, Union

from collections import defaultdict
from bidict import bidict
import dataclasses as dc

import networkx as nx
Expand Down Expand Up @@ -34,7 +35,6 @@ class Node:
name: str
weight: int = None
in_subgraph: bool = None
SYMBOL: ClassVar[str] = 'MISSING'

def __post_init__(self) -> None:
assert re.match(r'^\w+$', self.name), f'The name `{self.name}` is invalid, it must match regex `\w+`.'
Expand Down Expand Up @@ -96,76 +96,70 @@ def right(self) -> str:

@dc.dataclass(frozen=True, repr=False)
class IntInput(IntNode):
SYMBOL: ClassVar[str] = 'inI'
pass


@dc.dataclass(frozen=True)
class IntConstant(IntNode):
SYMBOL: ClassVar[str] = 'constI'
value: int = 0


@dc.dataclass(frozen=True, repr=False)
class ToInt(IntNode, OperationNode):
SYMBOL: ClassVar[str] = 'toInt'
pass


@dc.dataclass(frozen=True, repr=False)
class Sum(IntNode, OperationNode):
SYMBOL: ClassVar[str] = 'sum'
pass


@dc.dataclass(frozen=True, repr=False)
class AbsDiff(IntNode, Op2Node):
SYMBOL: ClassVar[str] = 'absdiff'
pass


# > bool


@dc.dataclass(frozen=True, repr=False)
class BoolInput(BoolNode):
SYMBOL: ClassVar[str] = 'inB'
pass


@dc.dataclass(frozen=True)
class BoolConstant(BoolNode):
SYMBOL: ClassVar[str] = 'constB'
value: bool = False


@dc.dataclass(frozen=True, repr=False)
class Not(BoolNode, Op1Node):
SYMBOL: ClassVar[str] = 'not'
pass


@dc.dataclass(frozen=True, repr=False)
class And(BoolNode, OperationNode):
SYMBOL: ClassVar[str] = 'and'
pass


@dc.dataclass(frozen=True, repr=False)
class Or(BoolNode, OperationNode):
SYMBOL: ClassVar[str] = 'or'
pass


@dc.dataclass(frozen=True, repr=False)
class Implies(BoolNode, Op2Node):
SYMBOL: ClassVar[str] = 'impl'
pass


@dc.dataclass(frozen=True, repr=False)
class Equals(BoolNode, Op2Node):
SYMBOL: ClassVar[str] = '=='

def __post_init__(self):
super().__post_init__()


@dc.dataclass(frozen=True, repr=False)
class AtLeast(BoolNode, OperationNode):
SYMBOL: ClassVar[str] = 'atleast'

@property
def items(self) -> Tuple[str, ...]:
return self._items[:-1]
Expand All @@ -177,8 +171,6 @@ def value(self) -> str:

@dc.dataclass(frozen=True, repr=False)
class AtMost(BoolNode, OperationNode):
SYMBOL: ClassVar[str] = 'atmost'

@property
def items(self) -> Tuple[str, ...]:
return self._items[:-1]
Expand All @@ -190,28 +182,26 @@ def value(self) -> str:

@dc.dataclass(frozen=True, repr=False)
class LessThan(BoolNode, Op2Node):
SYMBOL: ClassVar[str] = '<'
pass


@dc.dataclass(frozen=True, repr=False)
class LessEqualThan(BoolNode, Op2Node):
SYMBOL: ClassVar[str] = '<='
pass


@dc.dataclass(frozen=True, repr=False)
class GreaterThan(BoolNode, Op2Node):
SYMBOL: ClassVar[str] = '>'
pass


@dc.dataclass(frozen=True, repr=False)
class GreaterEqualThan(BoolNode, Op2Node):
SYMBOL: ClassVar[str] = '>='
pass


@dc.dataclass(frozen=True, repr=False)
class Multiplexer(BoolNode, OperationNode):
SYMBOL: ClassVar[str] = 'mux'

def __post_init__(self):
super().__post_init__()
assert len(self._items) == 3, f'Wrong items count in node {self.name} of class {self.__class__.__name__}'
Expand All @@ -231,8 +221,7 @@ def parameter_2(self) -> str:

@dc.dataclass(frozen=True, repr=False)
class Switch(BoolNode, Op2Node):
SYMBOL: ClassVar[str] = 'switch'
off_value: bool = None
value: bool = None

@property
def origin(self) -> str:
Expand All @@ -248,8 +237,6 @@ def parameter(self) -> str:

@dc.dataclass(frozen=True, repr=False)
class If(OperationNode):
SYMBOL: ClassVar[str] = 'if'

def __post_init__(self):
super().__post_init__()
assert len(self._items) == 3, f'Wrong items count in node {self.name} of class {self.__class__.__name__}'
Expand All @@ -269,12 +256,12 @@ def if_false(self) -> str:

@dc.dataclass(frozen=True, repr=False)
class Copy(Op1Node, BoolNode, IntNode):
SYMBOL: ClassVar[str] = 'copy'
pass


@dc.dataclass(frozen=True, repr=False)
class PlaceHolder(BoolNode, IntNode):
SYMBOL: ClassVar[str] = 'holder'
pass


class Graph:
Expand Down Expand Up @@ -439,6 +426,126 @@ def placeholders(self) -> FrozenSet[PlaceHolder]:
class DotConverter:
# TODO: update / simplify

NODE_SYMBOL = bidict({
# inputs
BoolInput: 'inB',
IntInput: 'inI',
# constants
BoolConstant: 'constB',
IntConstant: 'constI',
# output
Copy: 'copy',
# placeholder
PlaceHolder: 'holder',
# int operations
ToInt: 'toInt',
Sum: 'sum',
AbsDiff: 'absdiff',
# bool operations
Not: 'not',
And: 'and',
Or: 'or',
Implies: 'impl',
# comparison operations
Equals: '==',
AtLeast: 'atleast',
AtMost: 'atmost',
LessThan: '<',
LessEqualThan: '<=',
GreaterThan: '>',
GreaterEqualThan: '>=',
# branching operations
Multiplexer: 'mux',
Switch: 'switch',
If: 'if',
})
NODE_SHAPE = {
# inputs
BoolInput: 'circle',
IntInput: 'circle',
# constants
BoolConstant: 'square',
IntConstant: 'square',
# output
Copy: 'doublecircle',
# placeholder
PlaceHolder: 'octagon',
# int operations
ToInt: 'invtrapezium',
Sum: 'invtrapezium',
AbsDiff: 'invtrapezium',
# bool operations
Not: 'invhouse',
And: 'invhouse',
Or: 'invhouse',
Implies: 'invhouse',
# comparison operations
Equals: 'invtriangle',
AtLeast: 'invtriangle',
AtMost: 'invtriangle',
LessThan: 'invtriangle',
LessEqualThan: 'invtriangle',
GreaterThan: 'invtriangle',
GreaterEqualThan: 'invtriangle',
# branching operations
Multiplexer: 'diamond',
Switch: 'diamond',
If: 'diamond',
}
NODE_LABEL_EXTRA: Dict[Type, Callable[[Node], str]] = {
# constants
BoolConstant: lambda n: rf'\nv={n.value}',
IntConstant: lambda n: rf'\nv={n.value}',
# branching operations
Switch: lambda n: rf'\nv={n.off_value}',
}

@classmethod
def _get_label(cls, node: Node) -> str:
# base informations
string = rf'{cls.NODE_SYMBOL[type(node)]}\n{node.name}'

# extra informations
if node.weight is not None:
string += rf'\nw={node.weight}'
if isinstance(Node, OperationNode):
string += rf'\ni={node._items}'
if isinstance(node, (BoolConstant, IntConstant, Switch)):
string += rf'\nv={node.value}'

return string

@classmethod
def _parse_label(cls, string: str) -> Node:
m = re.match(
''.join((
rf'({"|".join(cls.NODE_SYMBOL.inverse)})', # class
rf'\n(\w+)', # name
rf'(?:\nw=([-+]?\d+))', # weight
rf'(?:\nv=([-+]?\d+|True|False))?', # value
rf'(?:\ni=(\w+(?:,\w+)*))?', # items
)),
string
)

type = cls.NODE_SYMBOL.inverse[m[1]]
name = m[2]
weight = m[3] and int(m[3])
value = m[4] and {
BoolConstant: lambda: {'True': True, 'False': False}[m[4]],
Switch: lambda: {'True': True, 'False': False}[m[4]],
IntConstant: lambda: int(m[4])
}[type]()
items = m[5] and m[5].split(',')

arguments = {
k: v
for k, v in [('name', name), ('weight', weight), ('value', value), ('_items', items)]
if v is not None
}

return type(**arguments)

def __new__(cls) -> NoReturn:
raise TypeError(f'Cannot create instances of class {cls.__name__}')

Expand Down Expand Up @@ -566,7 +673,7 @@ def to_string(cls, graph: Graph) -> str:
fillcolor_f = ', fillcolor=red' if n.in_subgraph else ', fillcolor=white'

items_f = ''
symbol_l = n.SYMBOL
symbol_l = cls.NODE_SYMBOL[type(n)]
if isinstance(n, OperationNode):
items_s = ','.join(n._items)
items_f = f', items="{items_s}"'
Expand All @@ -578,13 +685,22 @@ def to_string(cls, graph: Graph) -> str:
edge_lines.extend(f' {src_name} -> {n.name};' for src_name in n._items)

return '\n'.join((
'strict digraph GGraph {',
f'strict digraph {type(graph).__name__} {{',
' node [style=filled, fillcolor=white];',
*node_lines,
*edge_lines,
'}',
))

@classmethod
def to_string(cls, graph: Graph) -> str:
node_lines = []
for node in graph.nodes:
node_lines.append(
f''
)
pass


class JSONConverter:
import json
Expand Down

0 comments on commit 281f910

Please sign in to comment.