From 281f910a95ca3258a028737059e95162188a80e7 Mon Sep 17 00:00:00 2001 From: Marco Biasion Date: Thu, 12 Sep 2024 20:52:43 +0200 Subject: [PATCH] #81: removed Node.SYMBOL and delegated it to DotConverter --- sxpat/newGraph.py | 182 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 149 insertions(+), 33 deletions(-) diff --git a/sxpat/newGraph.py b/sxpat/newGraph.py index bcc69a75b..32d30ed27 100644 --- a/sxpat/newGraph.py +++ b/sxpat/newGraph.py @@ -1,7 +1,8 @@ from __future__ import annotations -from typing import ClassVar, Dict, FrozenSet, Iterable, List, Mapping, NoReturn, Tuple, Union +from typing import Callable, Dict, FrozenSet, Iterable, List, NoReturn, Tuple, Type, Union from collections import defaultdict +from bidict import bidict import dataclasses as dc import networkx as nx @@ -34,7 +35,6 @@ class Node: name: str weight: int = None in_subgraph: bool = None - SYMBOL: ClassVar[str] = 'MISSING' def __post_init__(self) -> None: assert re.match(r'^\w+$', self.name), f'The name `{self.name}` is invalid, it must match regex `\w+`.' @@ -96,28 +96,27 @@ def right(self) -> str: @dc.dataclass(frozen=True, repr=False) class IntInput(IntNode): - SYMBOL: ClassVar[str] = 'inI' + pass @dc.dataclass(frozen=True) class IntConstant(IntNode): - SYMBOL: ClassVar[str] = 'constI' value: int = 0 @dc.dataclass(frozen=True, repr=False) class ToInt(IntNode, OperationNode): - SYMBOL: ClassVar[str] = 'toInt' + pass @dc.dataclass(frozen=True, repr=False) class Sum(IntNode, OperationNode): - SYMBOL: ClassVar[str] = 'sum' + pass @dc.dataclass(frozen=True, repr=False) class AbsDiff(IntNode, Op2Node): - SYMBOL: ClassVar[str] = 'absdiff' + pass # > bool @@ -125,47 +124,42 @@ class AbsDiff(IntNode, Op2Node): @dc.dataclass(frozen=True, repr=False) class BoolInput(BoolNode): - SYMBOL: ClassVar[str] = 'inB' + pass @dc.dataclass(frozen=True) class BoolConstant(BoolNode): - SYMBOL: ClassVar[str] = 'constB' value: bool = False @dc.dataclass(frozen=True, repr=False) class Not(BoolNode, Op1Node): - SYMBOL: ClassVar[str] = 'not' + pass @dc.dataclass(frozen=True, repr=False) class And(BoolNode, OperationNode): - SYMBOL: ClassVar[str] = 'and' + pass @dc.dataclass(frozen=True, repr=False) class Or(BoolNode, OperationNode): - SYMBOL: ClassVar[str] = 'or' + pass @dc.dataclass(frozen=True, repr=False) class Implies(BoolNode, Op2Node): - SYMBOL: ClassVar[str] = 'impl' + pass @dc.dataclass(frozen=True, repr=False) class Equals(BoolNode, Op2Node): - SYMBOL: ClassVar[str] = '==' - def __post_init__(self): super().__post_init__() @dc.dataclass(frozen=True, repr=False) class AtLeast(BoolNode, OperationNode): - SYMBOL: ClassVar[str] = 'atleast' - @property def items(self) -> Tuple[str, ...]: return self._items[:-1] @@ -177,8 +171,6 @@ def value(self) -> str: @dc.dataclass(frozen=True, repr=False) class AtMost(BoolNode, OperationNode): - SYMBOL: ClassVar[str] = 'atmost' - @property def items(self) -> Tuple[str, ...]: return self._items[:-1] @@ -190,28 +182,26 @@ def value(self) -> str: @dc.dataclass(frozen=True, repr=False) class LessThan(BoolNode, Op2Node): - SYMBOL: ClassVar[str] = '<' + pass @dc.dataclass(frozen=True, repr=False) class LessEqualThan(BoolNode, Op2Node): - SYMBOL: ClassVar[str] = '<=' + pass @dc.dataclass(frozen=True, repr=False) class GreaterThan(BoolNode, Op2Node): - SYMBOL: ClassVar[str] = '>' + pass @dc.dataclass(frozen=True, repr=False) class GreaterEqualThan(BoolNode, Op2Node): - SYMBOL: ClassVar[str] = '>=' + pass @dc.dataclass(frozen=True, repr=False) class Multiplexer(BoolNode, OperationNode): - SYMBOL: ClassVar[str] = 'mux' - def __post_init__(self): super().__post_init__() assert len(self._items) == 3, f'Wrong items count in node {self.name} of class {self.__class__.__name__}' @@ -231,8 +221,7 @@ def parameter_2(self) -> str: @dc.dataclass(frozen=True, repr=False) class Switch(BoolNode, Op2Node): - SYMBOL: ClassVar[str] = 'switch' - off_value: bool = None + value: bool = None @property def origin(self) -> str: @@ -248,8 +237,6 @@ def parameter(self) -> str: @dc.dataclass(frozen=True, repr=False) class If(OperationNode): - SYMBOL: ClassVar[str] = 'if' - def __post_init__(self): super().__post_init__() assert len(self._items) == 3, f'Wrong items count in node {self.name} of class {self.__class__.__name__}' @@ -269,12 +256,12 @@ def if_false(self) -> str: @dc.dataclass(frozen=True, repr=False) class Copy(Op1Node, BoolNode, IntNode): - SYMBOL: ClassVar[str] = 'copy' + pass @dc.dataclass(frozen=True, repr=False) class PlaceHolder(BoolNode, IntNode): - SYMBOL: ClassVar[str] = 'holder' + pass class Graph: @@ -439,6 +426,126 @@ def placeholders(self) -> FrozenSet[PlaceHolder]: class DotConverter: # TODO: update / simplify + NODE_SYMBOL = bidict({ + # inputs + BoolInput: 'inB', + IntInput: 'inI', + # constants + BoolConstant: 'constB', + IntConstant: 'constI', + # output + Copy: 'copy', + # placeholder + PlaceHolder: 'holder', + # int operations + ToInt: 'toInt', + Sum: 'sum', + AbsDiff: 'absdiff', + # bool operations + Not: 'not', + And: 'and', + Or: 'or', + Implies: 'impl', + # comparison operations + Equals: '==', + AtLeast: 'atleast', + AtMost: 'atmost', + LessThan: '<', + LessEqualThan: '<=', + GreaterThan: '>', + GreaterEqualThan: '>=', + # branching operations + Multiplexer: 'mux', + Switch: 'switch', + If: 'if', + }) + NODE_SHAPE = { + # inputs + BoolInput: 'circle', + IntInput: 'circle', + # constants + BoolConstant: 'square', + IntConstant: 'square', + # output + Copy: 'doublecircle', + # placeholder + PlaceHolder: 'octagon', + # int operations + ToInt: 'invtrapezium', + Sum: 'invtrapezium', + AbsDiff: 'invtrapezium', + # bool operations + Not: 'invhouse', + And: 'invhouse', + Or: 'invhouse', + Implies: 'invhouse', + # comparison operations + Equals: 'invtriangle', + AtLeast: 'invtriangle', + AtMost: 'invtriangle', + LessThan: 'invtriangle', + LessEqualThan: 'invtriangle', + GreaterThan: 'invtriangle', + GreaterEqualThan: 'invtriangle', + # branching operations + Multiplexer: 'diamond', + Switch: 'diamond', + If: 'diamond', + } + NODE_LABEL_EXTRA: Dict[Type, Callable[[Node], str]] = { + # constants + BoolConstant: lambda n: rf'\nv={n.value}', + IntConstant: lambda n: rf'\nv={n.value}', + # branching operations + Switch: lambda n: rf'\nv={n.off_value}', + } + + @classmethod + def _get_label(cls, node: Node) -> str: + # base informations + string = rf'{cls.NODE_SYMBOL[type(node)]}\n{node.name}' + + # extra informations + if node.weight is not None: + string += rf'\nw={node.weight}' + if isinstance(Node, OperationNode): + string += rf'\ni={node._items}' + if isinstance(node, (BoolConstant, IntConstant, Switch)): + string += rf'\nv={node.value}' + + return string + + @classmethod + def _parse_label(cls, string: str) -> Node: + m = re.match( + ''.join(( + rf'({"|".join(cls.NODE_SYMBOL.inverse)})', # class + rf'\n(\w+)', # name + rf'(?:\nw=([-+]?\d+))', # weight + rf'(?:\nv=([-+]?\d+|True|False))?', # value + rf'(?:\ni=(\w+(?:,\w+)*))?', # items + )), + string + ) + + type = cls.NODE_SYMBOL.inverse[m[1]] + name = m[2] + weight = m[3] and int(m[3]) + value = m[4] and { + BoolConstant: lambda: {'True': True, 'False': False}[m[4]], + Switch: lambda: {'True': True, 'False': False}[m[4]], + IntConstant: lambda: int(m[4]) + }[type]() + items = m[5] and m[5].split(',') + + arguments = { + k: v + for k, v in [('name', name), ('weight', weight), ('value', value), ('_items', items)] + if v is not None + } + + return type(**arguments) + def __new__(cls) -> NoReturn: raise TypeError(f'Cannot create instances of class {cls.__name__}') @@ -566,7 +673,7 @@ def to_string(cls, graph: Graph) -> str: fillcolor_f = ', fillcolor=red' if n.in_subgraph else ', fillcolor=white' items_f = '' - symbol_l = n.SYMBOL + symbol_l = cls.NODE_SYMBOL[type(n)] if isinstance(n, OperationNode): items_s = ','.join(n._items) items_f = f', items="{items_s}"' @@ -578,13 +685,22 @@ def to_string(cls, graph: Graph) -> str: edge_lines.extend(f' {src_name} -> {n.name};' for src_name in n._items) return '\n'.join(( - 'strict digraph GGraph {', + f'strict digraph {type(graph).__name__} {{', ' node [style=filled, fillcolor=white];', *node_lines, *edge_lines, '}', )) + @classmethod + def to_string(cls, graph: Graph) -> str: + node_lines = [] + for node in graph.nodes: + node_lines.append( + f'' + ) + pass + class JSONConverter: import json