# Test code from the patch "new test for n2 matching" (commit 15e1a277), which
# appends these two definitions to tests/test_methods.py directly after
# `test_extra_match`. Both take `self`: in that file they are methods of a test
# class whose header is not part of the patch. They use `tskit`, `tscompare`
# and `np`, which the patch does not import -- presumably the target module
# already does; confirm when applying.


def get_n2_match_ex(self, samples=None, extra_nodes=False):
    """Build a small single-tree example over a sequence of length 3.

    Nodes 0, 1 and 2 are at time 0, node 3 at time 100 and node 4 at
    time 200. Every edge covers the whole interval [0, 3).

    * Default topology: 3 -> {0, 1} and 4 -> {2, 3} (4 edges).
    * With ``extra_nodes``: node 5 is added at time 300 and the edges become
      3 -> {0, 1}, 4 -> {3}, 5 -> {2, 4} (5 edges), so node 4 has a single
      child and node 5 is the root.

    :param samples: Iterable of node ids to flag as samples. Defaults to
        ``[0, 1, 2]`` when ``None``.
    :param extra_nodes: If truthy, build the 6-node variant described above.
    :return: The tree sequence produced by ``tables.tree_sequence()``.
    :raises AssertionError: If the resulting tree sequence does not contain
        exactly the edges listed here.
    """
    node_times = {
        0: 0.0,
        1: 0.0,
        2: 0.0,
        3: 100.0,
        4: 200.0,
    }
    # Edge rows are (parent, child, left, right), listed in order of parent
    # time. The tables are not sorted before tree_sequence() is called, so
    # keep the rows in this order.
    if extra_nodes:
        node_times[5] = 300.0
        edges = [
            (3, 0, 0, 3),
            (3, 1, 0, 3),
            (4, 3, 0, 3),
            (5, 2, 0, 3),
            (5, 4, 0, 3),
        ]
    else:
        edges = [
            (3, 0, 0, 3),
            (3, 1, 0, 3),
            (4, 2, 0, 3),
            (4, 3, 0, 3),
        ]

    # Materialise once so that any iterable (including a one-shot iterator)
    # can be used for the repeated membership tests below.
    sample_ids = {0, 1, 2} if samples is None else set(samples)

    tables = tskit.TableCollection(sequence_length=3)
    # node_times is keyed 0..N-1 in insertion order, so each row's index in
    # the node table equals its key.
    for node_id, time in node_times.items():
        flags = tskit.NODE_IS_SAMPLE if node_id in sample_ids else 0
        tables.nodes.add_row(time=time, flags=flags)
    for parent, child, left, right in edges:
        tables.edges.add_row(parent=parent, child=child, left=left, right=right)

    ts = tables.tree_sequence()
    # Check against the edge list actually used. The earlier identity tests
    # (`is True` / `is False`) silently skipped this check for truthy or
    # falsy non-bool arguments such as 1 or 0.
    assert ts.num_edges == len(edges)
    return ts


def test_n2_matching(self):
    """Compare the default example with its ``extra_nodes`` variant.

    Calls ``tscompare.compare`` with ``transform=None`` and checks the
    reported ARF, TPR, dissimilarity and inverse dissimilarity against
    hand-computed spans.
    """
    ts = self.get_n2_match_ex()
    other = self.get_n2_match_ex(extra_nodes=True)
    dis = tscompare.compare(ts, other, transform=None)
    # Expected totals for (ts, other). Presumably each node spans the whole
    # length-3 sequence, giving 5 * 3 = 15 and 6 * 3 = 18 -- TODO confirm
    # against tscompare's definition of node span.
    true_spans = (15, 18)
    # Expected matched span for (ts, other).
    match_spans = (15, 15)
    assert np.isclose(dis.arf, 1 - match_spans[0] / true_spans[0])
    assert np.isclose(dis.tpr, match_spans[1] / true_spans[1])
    assert np.isclose(dis.dissimilarity, true_spans[0] - match_spans[0])
    assert np.isclose(dis.inverse_dissimilarity, true_spans[1] - match_spans[1])