Skip to content

Commit

Permalink
#81: nodes store names instead of other nodes. Implemented subgraph g…
Browse files Browse the repository at this point in the history
…etters in GGraph
  • Loading branch information
marco-biasion committed Sep 9, 2024
1 parent ff1492b commit df0014c
Showing 1 changed file with 64 additions and 32 deletions.
96 changes: 64 additions & 32 deletions sxpat/newGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,27 @@

import networkx as nx
import functools as ft
import itertools as it
import re

from sxpat.utils.collections import InheritanceMapping


# > precursors

@dc.dataclass(frozen=True, repr=False)
@dc.dataclass(frozen=True)
class Node:
name: str
weight: int
in_subgraph: bool
SYMBOL: ClassVar[str] = 'MISSING'

def __repr__(self, depth: int = 1) -> str:
return f'{self.__class__.__name__}(name={self.name!r}, weight={self.weight}, in_subgraph={1})'
def __post_init__(self) -> None:
object.__setattr__(self, 'weight', int(self.weight))
object.__setattr__(self, 'in_subgraph', bool(self.in_subgraph))

def copy(self, **update):
return type(self)(**{**vars(self), **update})


@dc.dataclass(frozen=True, repr=False)
Expand All @@ -32,33 +37,29 @@ class IntNode(Node):
pass


@dc.dataclass(frozen=True, repr=False)
@dc.dataclass(frozen=True)
class OperationNode(Node):
_REQUIRED_CLASSES: ClassVar[Mapping[int, type]] = dict()
_items: Collection[Node] = tuple()
_items: Iterable[Node] = tuple()

@property
def items(self) -> Collection[Node]:
def items(self) -> Collection[str]:
return self._items

def __post_init__(self):
object.__setattr__(self, '_items', tuple(self._items))

covered_positions = set()
# assert that all REQUIRED_CLASSES are respected
n = len(self._items)
items = tuple(self._items)
covered_positions = set()
for pos, type in self._REQUIRED_CLASSES.items():
if pos is None:
assert all(isinstance(item, type) for i, item in enumerate(self._items) if i not in covered_positions), f'Wrong item type in node {self.name} of class {self.__class__.__name__}'
assert all(isinstance(item, type) for i, item in enumerate(items) if i not in covered_positions), f'Wrong item type in node {self.name} of class {self.__class__.__name__}'
else:
covered_positions.add((n + pos) % n)
assert isinstance(self._items[pos], type), f'Wrong item type(item {pos}) in node {self.name} of class {self.__class__.__name__}'
assert isinstance(items[pos], type), f'Wrong item type(item {pos}) in node {self.name} of class {self.__class__.__name__}'

def __repr__(self, depth: int = 1) -> str:
if depth == 0:
return super().__repr__()
else:
items = ', '.join(_i.__repr__(depth - 1) for _i in self._items)
return f'{self.__class__.__name__}(name={self.name!r}, weight={self.weight}, in_subgraph={self.in_subgraph}, _items=[{items}])'
# store items names
object.__setattr__(self, '_items', tuple(node.name for node in self._items))


@dc.dataclass(frozen=True, repr=False)
Expand Down Expand Up @@ -94,7 +95,7 @@ class IntInput(IntNode):
SYMBOL: ClassVar[str] = 'inI'


@dc.dataclass(frozen=True, repr=False)
@dc.dataclass(frozen=True)
class IntConstant(IntNode):
SYMBOL: ClassVar[str] = 'constI'
value: int
Expand Down Expand Up @@ -125,7 +126,7 @@ class BoolInput(BoolNode):
SYMBOL: ClassVar[str] = 'inB'


@dc.dataclass(frozen=True, repr=False)
@dc.dataclass(frozen=True)
class BoolConstant(BoolNode):
SYMBOL: ClassVar[str] = 'constB'
value: bool
Expand Down Expand Up @@ -285,39 +286,70 @@ class GGraph:

K = object()

def __init__(self, nodes: Iterable[Node], /, inputs: Iterable[str] = (), outputs: Iterable[str] = ()) -> None:
def __init__(self, nodes: Iterable[Node], /,
input_names: Iterable[str] = (), output_names: Iterable[str] = ()
) -> None:
# generate inner mutable structure
self._graph = nx.DiGraph()
self._graph.add_nodes_from(
(node.name, {self.K: node})
(node.name, {'type': type(node), **vars(node)})
for node in nodes
)
self._graph.add_edges_from(
(src.name, dst_name)
(src_name, dst_name)
for dst_name, data in self._graph.nodes(data=True)
if isinstance(data[self.K], OperationNode)
for src in data[self.K].items
for src_name in data[self.K].items
)

# freeze local instances
self._graph = nx.freeze(self._graph)
self._inputs = tuple(inputs)
self._outputs = tuple(outputs)
self._input_names = tuple(input_names)
self._output_names = tuple(output_names)

def __getitem__(self, key: str) -> Node:
return self._graph.nodes[key][self.K]

def predecessors(self, node_or_name: Union[Node, str]) -> Collection[Node]:
name = node_or_name if isinstance(node_or_name, str) else node_or_name.name
return tuple(self._graph.nodes[_name][self.K] for _name in self._graph.predecessors(name))

def successors(self, node_or_name: Union[Node, str]) -> Collection[Node]:
name = node_or_name if isinstance(node_or_name, str) else node_or_name.name
return tuple(self._graph.nodes[_name][self.K] for _name in self._graph.successors(name))

@ft.cached_property
def nodes(self) -> Iterable[Node]:
def nodes(self) -> Collection[Node]:
return tuple(self._graph.nodes[name][self.K] for name in self._graph.nodes)

@ft.cached_property
def inputs(self) -> Iterable[Node]:
return tuple(self._graph.nodes[name][self.K] for name in self._inputs)
def inputs(self) -> Collection[Node]:
return tuple(self._graph.nodes[name][self.K] for name in self._input_names)

@ft.cached_property
def outputs(self) -> Iterable[Node]:
return tuple(self._graph.nodes[name][self.K] for name in self._outputs)
def outputs(self) -> Collection[Node]:
return tuple(self._graph.nodes[name][self.K] for name in self._output_names)

def __str__(self) -> str:
pass
@ft.cached_property
def subgraph_nodes(self) -> Collection[Node]:
return tuple(node for node in self.nodes if node.in_subgraph)

@ft.cached_property
def subgraph_inputs(self) -> Collection[Node]:
# if a node is a subgraph input if it is not in the subgraph and any successor is in the subgraph
return tuple(it.chain.from_iterable(
(pred for pred in self.predecessors(node) if not pred.in_subgraph)
for node in self.subgraph_nodes
))

@ft.cached_property
def subgraph_outputs(self) -> Collection[Node]:
# if a subgraph node is a subgraph output if any successor is not in the subgraph
return tuple(
node
for node in self.subgraph_nodes
if any(not succ.in_subgraph for succ in self.successors(node))
)

def to_nx_digraph(self, /, mutable: bool = False) -> nx.DiGraph:
return nx.DiGraph(self._graph) if mutable else self._graph
Expand Down

0 comments on commit df0014c

Please sign in to comment.