From a066ffb2f971388fe6867613a01a55ac62c3bdb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?I=C3=B1igo=20Gabirondo?= Date: Mon, 20 May 2024 21:07:21 +0200 Subject: [PATCH] Refactor code for determining if two graphs are equal --- tests/to_pyg_test.py | 60 +++++++++++++------------------------------- 1 file changed, 17 insertions(+), 43 deletions(-) diff --git a/tests/to_pyg_test.py b/tests/to_pyg_test.py index c42f2380..56cdfe57 100644 --- a/tests/to_pyg_test.py +++ b/tests/to_pyg_test.py @@ -37,42 +37,17 @@ def graph2() -> pg.ProgramGraph: def graph3() -> pg.ProgramGraph: return pg.from_cpp("int B(int x) { return x + 1; }") -def assert_equal_graphs( +def graphs_are_equal( graph1: HeteroData, graph2: HeteroData, - equality: bool = True ): - if equality: - assert graph1['nodes']['full_text'] == graph2['nodes']['full_text'] - - assert graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index) - assert graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index) - assert graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index) - assert graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index) - - else: - text_different = graph1['nodes']['full_text'] != graph2['nodes']['full_text'] - - control_edges_different = not graph1['nodes', 'control', 'nodes'].edge_index.equal( - graph2['nodes', 'control', 'nodes'].edge_index - ) - data_edges_different = not graph1['nodes', 'data', 'nodes'].edge_index.equal( - graph2['nodes', 'data', 'nodes'].edge_index - ) - call_edges_different = not graph1['nodes', 'call', 'nodes'].edge_index.equal( - graph2['nodes', 'call', 'nodes'].edge_index - ) - type_edges_different = not graph1['nodes', 'type', 'nodes'].edge_index.equal( - graph2['nodes', 'type', 'nodes'].edge_index - ) - - assert ( - text_different - or control_edges_different - or data_edges_different - or call_edges_different - or type_edges_different - ) + return ( + (graph1['nodes']['full_text'] == graph2['nodes']['full_text']) + and (graph1['nodes', 'control', 'nodes'].edge_index.equal(graph2['nodes', 'control', 'nodes'].edge_index)) + and (graph1['nodes', 'data', 'nodes'].edge_index.equal(graph2['nodes', 'data', 'nodes'].edge_index)) + and (graph1['nodes', 'call', 'nodes'].edge_index.equal(graph2['nodes', 'call', 'nodes'].edge_index)) + and (graph1['nodes', 'type', 'nodes'].edge_index.equal(graph2['nodes', 'type', 'nodes'].edge_index)) + ) def test_to_pyg_simple_graph(graph: pg.ProgramGraph): graphs = list(pg.to_pyg([graph])) @@ -90,8 +65,8 @@ def test_to_pyg_different_two_different_inputs( pyg_graph = pg.to_pyg(graph) pyg_graph2 = pg.to_pyg(graph2) - # Ensure that the graphs are different - assert_equal_graphs(pyg_graph, pyg_graph2, equality=False) + # Ensure that the graphs are different + assert not graphs_are_equal(pyg_graph, pyg_graph2) def test_to_pyg_different_inputs( graph: pg.ProgramGraph, @@ -102,21 +77,21 @@ def test_to_pyg_different_inputs( pyg_graph2 = pg.to_pyg(graph2) pyg_graph3 = pg.to_pyg(graph3) - # Ensure that the graphs are different - assert_equal_graphs(pyg_graph, pyg_graph2, equality=False) - assert_equal_graphs(pyg_graph, pyg_graph3, equality=False) - assert_equal_graphs(pyg_graph2, pyg_graph3, equality=False) + # Ensure that the graphs are different + assert not graphs_are_equal(pyg_graph, pyg_graph2) + assert not graphs_are_equal(pyg_graph, pyg_graph3) + assert not graphs_are_equal(pyg_graph2, pyg_graph3) def test_to_pyg_two_inputs(graph: pg.ProgramGraph): graphs = list(pg.to_pyg([graph, graph])) assert len(graphs) == 2 - assert_equal_graphs(graphs[0], graphs[1], equality=True) + assert graphs_are_equal(graphs[0], graphs[1]) def test_to_pyg_generator(graph: pg.ProgramGraph): graphs = list(pg.to_pyg((graph for _ in range(10)), chunksize=3)) assert len(graphs) == 10 for x in graphs[1:]: - assert_equal_graphs(graphs[0], x, equality=True) + assert graphs_are_equal(graphs[0], x) def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph): with ThreadPoolExecutor() as executor: @@ -125,8 +100,7 @@ def test_to_pyg_generator_parallel_executor(graph: pg.ProgramGraph): ) assert len(graphs) == 10 for x in graphs[1:]: - assert_equal_graphs(graphs[0], x, equality=True) - + assert graphs_are_equal(graphs[0], x) def test_to_pyg_smoke_test(llvm_program_graph: pg.ProgramGraph): graphs = list(pg.to_pyg([llvm_program_graph]))