Skip to content

Commit

Permalink
#81: added JSONConverter for newGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
marco-biasion committed Sep 12, 2024
1 parent 087824c commit 110c794
Showing 1 changed file with 53 additions and 5 deletions.
58 changes: 53 additions & 5 deletions sxpat/newGraph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import ClassVar, Collection, Dict, FrozenSet, Iterable, List, Mapping, NoReturn, Tuple, Union
from typing import ClassVar, Dict, FrozenSet, Iterable, List, Mapping, NoReturn, Tuple, Union

from collections import defaultdict
import dataclasses as dc
Expand All @@ -9,7 +9,7 @@
import itertools as it
import re

# from utils.collections import InheritanceMapping
from sxpat.utils.inheritance import get_all_leaves_subclasses, get_all_subclasses
from sxpat.utils.collections import InheritanceMapping


Expand All @@ -26,9 +26,6 @@ def copy(self, **update):
return type(self)(**{**vars(self), **update})


# NodeOrName = Union[Node, str]


@dc.dataclass(frozen=True, repr=False)
class BoolNode(Node):
pass
Expand Down Expand Up @@ -625,6 +622,57 @@ def to_string(cls, graph: Graph) -> str:
))


class JSONConverter:
import json

_CLASS_F = 'class'
_NODES_F = 'nodes'

_G_CLSS = {c.__name__: c for c in get_all_subclasses(Graph)}
_N_CLSS = {c.__name__: c for c in get_all_leaves_subclasses(Node)}

@classmethod
def dict_factory(cls, obj: object) -> dict:
return {cls._CLASS_F: obj.__class__.__name__, **vars(obj)}

@classmethod
def node_factory(cls, dct: dict) -> Node:
return cls._N_CLSS[dct.pop(cls._CLASS_F)](**dct)

@classmethod
def load_file(cls, filename: str) -> Graph:
with open(filename, 'r') as f:
string = f.read()
return cls.from_string(string)

@classmethod
def save_file(cls, graph: Graph, filename: str) -> None:
string = cls.to_string(graph)
with open(filename, 'w') as f:
f.write(string)

@classmethod
def from_string(cls, string: str) -> Graph:
_g: dict = cls.json.loads(string)
nodes = [cls.node_factory(n) for n in _g.pop(cls._NODES_F)]
return cls._G_CLSS[_g.pop(cls._CLASS_F)](nodes=nodes, **_g)

@classmethod
def to_string(cls, graph: Graph) -> str:
_g = {
cls._CLASS_F: graph.__class__.__name__,
cls._NODES_F: [cls.dict_factory(node) for node in graph.nodes],
}

if isinstance(graph, GGraph):
_g['inputs_names'] = graph.inputs_names
_g['outputs_names'] = graph.outputs_names
if isinstance(graph, TGraph):
_g['parameters_names'] = graph.parameters_names

return cls.json.dumps(_g, indent=4)


if __name__ == '__main__':

g = DotConverter.load_file_clean('output/gv/adder_i8_o5_Sop1_encz3bvec_si6m0_i7m1.gv')
Expand Down

0 comments on commit 110c794

Please sign in to comment.