Skip to content

Commit

Permalink
Refactor code for determining if two graphs are equal
Browse files Browse the repository at this point in the history
  • Loading branch information
Iñigo Gabirondo committed May 20, 2024
1 parent 2a93fe6 commit a066ffb
Showing 1 changed file with 17 additions and 43 deletions.
60 changes: 17 additions & 43 deletions tests/to_pyg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,42 +37,17 @@ def graph2() -> pg.ProgramGraph:
def graph3() -> pg.ProgramGraph:
return pg.from_cpp("int B(int x) { return x + 1; }")

def assert_equal_graphs(
def graphs_are_equal(
graph1: HeteroData,
graph2: HeteroData,
equality: bool = True
):
if equality:
assert graph1['nodes']['full_text'] == graph2['nodes']['full_text']

assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)
assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index)
assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index)
assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index)

else:
text_different = graph1['nodes']['full_text'] != graph2['nodes']['full_text']

control_edges_different = not graph1['nodes', 'control', 'nodes'].edge_index.equal(
graph2['nodes', 'control', 'nodes'].edge_index
)
data_edges_different = not graph1['nodes', 'data', 'nodes'].edge_index.equal(
graph2['nodes', 'data', 'nodes'].edge_index
)
call_edges_different = not graph1['nodes', 'call', 'nodes'].edge_index.equal(
graph2['nodes', 'call', 'nodes'].edge_index
)
type_edges_different = not graph1['nodes', 'type', 'nodes'].edge_index.equal(
graph2['nodes', 'type', 'nodes'].edge_index
)

assert (
text_different
or control_edges_different
or data_edges_different
or call_edges_different
or type_edges_different
)
return (
(graph1['nodes']['full_text'] == graph2['nodes']['full_text'])
and (graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index))
and (graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index))
and (graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index))
and (graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index))
)

def test_to_pyg_simple_graph(graph: pg.ProgramGraph):
graphs = list(pg.to_pyg([graph]))
Expand All @@ -90,8 +65,8 @@ def test_to_pyg_different_two_different_inputs(
pyg_graph = pg.to_pyg(graph)
pyg_graph2 = pg.to_pyg(graph2)

# Ensure that the graphs are different
assert_equal_graphs(pyg_graph, pyg_graph2, equality=False)
# Ensure that the graphs are different
assert not graphs_are_equal(pyg_graph, pyg_graph2)

def test_to_pyg_different_inputs(
graph: pg.ProgramGraph,
Expand All @@ -102,21 +77,21 @@ def test_to_pyg_different_inputs(
pyg_graph2 = pg.to_pyg(graph2)
pyg_graph3 = pg.to_pyg(graph3)

# Ensure that the graphs are different
assert_equal_graphs(pyg_graph, pyg_graph2, equality=False)
assert_equal_graphs(pyg_graph, pyg_graph3, equality=False)
assert_equal_graphs(pyg_graph2, pyg_graph3, equality=False)
# Ensure that the graphs are different
assert not graphs_are_equal(pyg_graph, pyg_graph2)
assert not graphs_are_equal(pyg_graph, pyg_graph3)
assert not graphs_are_equal(pyg_graph2, pyg_graph3)

def test_to_pyg_two_inputs(graph: pg.ProgramGraph):
graphs = list(pg.to_pyg([graph, graph]))
assert len(graphs) == 2
assert_equal_graphs(graphs[0], graphs[1], equality=True)
assert graphs_are_equal(graphs[0], graphs[1])

def test_to_pyg_generator(graph: pg.ProgramGraph):
graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3))
assert len(graphs) == 10
for x in graphs[1:]:
assert_equal_graphs(graphs[0], x, equality=True)
assert graphs_are_equal(graphs[0], x)

def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph):
with ThreadPoolExecutor() as executor:
Expand All @@ -125,8 +100,7 @@ def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph):
)
assert len(graphs) == 10
for x in graphs[1:]:
assert_equal_graphs(graphs[0], x, equality=True)

assert graphs_are_equal(graphs[0], x)

def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph):
graphs = list(pg.to_pyg([llvm_program_graph]))
Expand Down

0 comments on commit a066ffb

Please sign in to comment.