From e7320f6a42448dbb03681950b8615ff5b60b3f51 Mon Sep 17 00:00:00 2001 From: peter Date: Fri, 9 Aug 2024 15:35:46 -0700 Subject: [PATCH 1/2] stuff from discrepancy --- tests/test_discrepancy.py | 391 ++++++++++++++++++++++++++++++++++++++ tsdate/discrepancy.py | 375 ++++++++++++++++++++++++++++++++++++ 2 files changed, 766 insertions(+) create mode 100644 tests/test_discrepancy.py create mode 100644 tsdate/discrepancy.py diff --git a/tests/test_discrepancy.py b/tests/test_discrepancy.py new file mode 100644 index 00000000..3070f810 --- /dev/null +++ b/tests/test_discrepancy.py @@ -0,0 +1,391 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# Copyright (c) 2020-21 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test tools for mapping between node sets of different tree sequences +""" + +from collections import defaultdict +from itertools import combinations + +import msprime +import numpy as np +import pytest +import scipy.sparse +import tsinfer +import tskit + +from tsdate import evaluation + +# --- simulate test case --- +demo = msprime.Demography.isolated_model([1e4]) +for t in np.linspace(500, 10000, 20): + demo.add_census(time=t) +true_unary = msprime.sim_ancestry( + samples=10, + sequence_length=1e6, + demography=demo, + recombination_rate=1e-8, + random_seed=1024, +) +true_unary = msprime.sim_mutations(true_unary, rate=2e-8, random_seed=1024) +assert true_unary.num_trees > 1 +true_simpl = true_unary.simplify(filter_sites=False) +sample_dat = tsinfer.SampleData.from_tree_sequence(true_simpl) +infr_unary = tsinfer.infer(sample_dat) +infr_simpl = infr_unary.simplify(filter_sites=False) +true_ext = true_simpl.extend_edges() + + +def naive_shared_node_spans(ts, other): + """ + Inefficient but transparent function to get span where nodes from two tree + sequences subtend the same sample set + """ + + def _clade_dict(tree): + clade_to_node = defaultdict(set) + for node in tree.nodes(): + clade = frozenset(tree.samples(node)) + clade_to_node[clade].add(node) + return clade_to_node + + assert ts.sequence_length == other.sequence_length + assert ts.num_samples == other.num_samples + out = np.zeros((ts.num_nodes, other.num_nodes)) + for interval, query_tree, target_tree in ts.coiterate(other): + query = _clade_dict(query_tree) + target = _clade_dict(target_tree) + span = interval.right - interval.left + for clade, nodes in query.items(): + if clade in target: + for i in nodes: + for j in target[clade]: + out[i, j] += span + return scipy.sparse.csr_matrix(out) + + +def naive_node_span(ts): + """ + Ineffiecient but transparent function to get total span + of each node in a tree sequence, including roots. + """ + node_spans = np.zeros(ts.num_nodes) + for t in ts.trees(): + for n in t.nodes(): + if t.parent(n) != tskit.NULL or t.num_children(n) > 0: + span = t.span + node_spans[n] += span + return node_spans + + +def naive_discrepancy(ts, other): + """ + Ineffiecient but transparent function to compute discrepancy + and root-mean-square-error between two tree sequences. + """ + shared_spans = naive_shared_node_spans(ts, other).toarray() + max_span = np.max(shared_spans, axis=1) + assert len(max_span) == ts.num_nodes + time_array = np.zeros((ts.num_nodes, other.num_nodes)) + discrepancy_matrix = np.zeros((ts.num_nodes, other.num_nodes)) + for i in range(ts.num_nodes): + # Skip nodes with no match in shared_spans + if max_span[i] == 0: + continue + else: + for j in range(other.num_nodes): + if shared_spans[i, j] == max_span[i]: + time_array[i, j] = np.abs(ts.nodes_time[i] - other.nodes_time[j]) + discrepancy_matrix[i, j] = 1 / (1 + time_array[i, j]) + best_match = np.argmax(discrepancy_matrix, axis=1) + best_match_spans = np.zeros((ts.num_nodes,)) + time_discrepancies = np.zeros((ts.num_nodes,)) + for i, j in enumerate(best_match): + best_match_spans[i] = shared_spans[i, j] + time_discrepancies[i] = time_array[i, j] + node_span = naive_node_span(ts) + total_node_spans = np.sum(node_span) + discrepancy = 1 - np.sum(best_match_spans) / total_node_spans + rmse = np.sqrt(np.sum(node_span * time_discrepancies**2) / total_node_spans) + return discrepancy, rmse + + +@pytest.mark.parametrize("ts", [true_unary, infr_unary, true_simpl, infr_simpl]) +class TestCladeMap: + def test_map(self, ts): + """ + test that clade map has correct nodes, clades + """ + clade_map = evaluation.CladeMap(ts) + for tree in ts.trees(): + for node in tree.nodes(): + clade = frozenset(tree.samples(node)) + assert node in clade_map._nodes[clade] + assert clade_map._clades[node] == clade + clade_map.next() + + def test_diff(self, ts): + """ + test difference in clades between adjacent trees. + """ + clade_map = evaluation.CladeMap(ts) + tree_1 = ts.first() + tree_2 = ts.first() + while True: + tree_2.next() + diff = clade_map.next() + diff_test = {} + for n in set(tree_1.nodes()) | set(tree_2.nodes()): + prev = frozenset(tree_1.samples(n)) + curr = frozenset(tree_2.samples(n)) + if prev != curr: + diff_test[n] = (prev, curr) + for node in diff_test.keys() | diff.keys(): + assert diff_test[node][0] == diff[node][0] + assert diff_test[node][1] == diff[node][1] + if tree_2.index == ts.num_trees - 1: + break + tree_1.next() + + +class TestNodeMatching: + @pytest.mark.parametrize( + "ts", + [infr_simpl, true_simpl, infr_unary, true_unary], + ) + def test_node_spans(self, ts): + eval_ns = evaluation.node_spans(ts) + naive_ns = naive_node_span(ts) + assert np.allclose(eval_ns, naive_ns) + + @pytest.mark.parametrize( + "pair", combinations([infr_simpl, true_simpl, infr_unary, true_unary], 2) + ) + def test_shared_spans(self, pair): + """ + Check that efficient implementation returns same answer as naive + implementation + """ + check = naive_shared_node_spans(pair[0], pair[1]) + test = evaluation.shared_node_spans(pair[0], pair[1]) + assert check.shape == test.shape + assert check.nnz == test.nnz + assert np.allclose(check.data, test.data) + + @pytest.mark.parametrize("ts", [infr_simpl, true_simpl]) + def test_match_self(self, ts): + """ + Check that matching against self returns node ids + + TODO: this'll only work reliably when there's not unary nodes. + """ + time, _, hit = evaluation.match_node_ages(ts, ts) + assert np.allclose(time, ts.nodes_time) + assert np.array_equal(hit, np.arange(ts.num_nodes)) + + @pytest.mark.parametrize( + "pair", + [(true_ext, true_ext), (true_simpl, true_ext), (true_simpl, true_unary)], + ) + def test_basic_discrepancy(self, pair): + """ + Check that efficient implementation reutrns the same answer as naive + implementation. + """ + check_dis, check_err = naive_discrepancy(pair[0], pair[1]) + test_dis, test_err = evaluation.tree_discrepancy(pair[0], pair[1]) + assert np.isclose(check_dis, test_dis) + assert np.isclose(check_err, test_err) + + @pytest.mark.parametrize( + "pair", + [(true_ext, true_ext), (true_simpl, true_ext), (true_simpl, true_unary)], + ) + def test_zero_discrepancy(self, pair): + dis, err = evaluation.tree_discrepancy(pair[0], pair[1]) + assert np.isclose(dis, 0) + assert np.isclose(err, 0) + + def get_simple_ts(self, samples=None, time=False, span=False, no_match=False): + # A simple tree sequence we can use to properly test various + # discrepancy and MSRE values. + # + # 6 6 6 + # +-+-+ +-+-+ +-+-+ + # | | 7 | | 8 + # | | ++-+ | | +-++ + # 4 5 4 | 5 4 | 5 + # +++ +++ +++ | | | | +++ + # 0 1 2 3 0 1 2 3 0 1 2 3 + # + # if time = False: + # with node times 0.0, 500.0, 750.0, 1000.0 for each tier, + # else: + # with node times 0.0, 200.0, 600.0, 1000.0 for each tier, + # + # if span = False: + # each tree spans (0,2), (2,4), and (4,6) respectively. + # else: + # each tree spans (0,1), (1,5), and (5,6) repectively. + if time is False: + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 500.0, + 5: 500.0, + 6: 1000.0, + 7: 750.0, + 8: 750.0, + } + else: + node_times = { + 0: 0, + 1: 0, + 2: 0, + 3: 0, + 4: 200.0, + 5: 200.0, + 6: 1000.0, + 7: 600.0, + 8: 600.0, + } + # (p, c, l, r) + if span is False: + edges = [ + (4, 0, 0, 6), + (4, 1, 0, 4), + (5, 2, 0, 2), + (5, 2, 4, 6), + (5, 3, 0, 6), + (7, 2, 2, 4), + (7, 4, 2, 4), + (8, 1, 4, 6), + (8, 5, 4, 6), + (6, 4, 0, 2), + (6, 4, 4, 6), + (6, 5, 0, 4), + (6, 7, 2, 4), + (6, 8, 4, 6), + ] + else: + edges = [ + (4, 0, 0, 6), + (4, 1, 0, 5), + (5, 2, 0, 1), + (5, 2, 5, 6), + (5, 3, 0, 6), + (7, 2, 1, 5), + (7, 4, 1, 5), + (8, 1, 5, 6), + (8, 5, 5, 6), + (6, 4, 0, 1), + (6, 4, 5, 6), + (6, 5, 0, 5), + (6, 7, 1, 5), + (6, 8, 5, 6), + ] + if no_match is True: + node_times[9] = 100.0 + if span is False: + edges = [ + (9, 0, 4, 6), + (4, 0, 0, 4), + (4, 1, 0, 6), + (4, 9, 4, 6), + (5, 2, 0, 2), + (5, 2, 4, 6), + (5, 3, 0, 6), + (7, 2, 2, 4), + (7, 4, 2, 4), + (6, 4, 0, 2), + (6, 4, 4, 6), + (6, 5, 0, 6), + (6, 7, 2, 4), + ] + else: + edges = [ + (9, 0, 5, 6), + (4, 0, 0, 5), + (4, 1, 0, 6), + (4, 9, 5, 6), + (5, 2, 0, 2), + (5, 2, 5, 6), + (5, 3, 0, 6), + (7, 2, 2, 5), + (7, 4, 2, 5), + (6, 4, 0, 2), + (6, 4, 5, 6), + (6, 5, 0, 6), + (6, 7, 2, 5), + ] + tables = tskit.TableCollection(sequence_length=6) + if samples is None: + samples = [0, 1, 2, 3] + for ( + n, + t, + ) in node_times.items(): + flags = tskit.NODE_IS_SAMPLE if n in samples else 0 + tables.nodes.add_row(time=t, flags=flags) + for p, c, l, r in edges: + tables.edges.add_row(parent=p, child=c, left=l, right=r) + ts = tables.tree_sequence() + if no_match is True: + assert ts.num_edges == 13 + if no_match is False: + assert ts.num_edges == 14 + return ts + + def test_discrepancy_value(self): + ts = self.get_simple_ts() + other = self.get_simple_ts(span=True) + dis, err = evaluation.tree_discrepancy(ts, other) + assert np.isclose(dis, 4 / 46) + assert np.isclose(err, 0.0) + + def test_discrepancy_error(self): + ts = self.get_simple_ts() + other = self.get_simple_ts(time=True) + dis, err = evaluation.tree_discrepancy(ts, other) + true_error = np.sqrt((2 * 6 * 300**2 + 2 * 2 * 150**2) / 46) + assert np.isclose(dis, 0.0) + assert np.isclose(err, true_error) + + def test_discrepancy_value_and_error(self): + ts = self.get_simple_ts() + other = self.get_simple_ts(span=True, time=True) + dis, err = evaluation.tree_discrepancy(ts, other) + true_error = np.sqrt((2 * 6 * 300**2 + 2 * 2 * 150**2) / 46) + assert np.isclose(dis, 4 / 46) + assert np.isclose(err, true_error) + + def test_discrepancy_and_naive_discrepancy_with_no_match(self): + ts = self.get_simple_ts() + other = self.get_simple_ts(span=True, time=True, no_match=True) + check_dis, check_err = naive_discrepancy(ts, other) + test_dis, test_err = evaluation.tree_discrepancy(ts, other) + assert np.isclose(check_dis, test_dis) + assert np.isclose(check_err, test_err) diff --git a/tsdate/discrepancy.py b/tsdate/discrepancy.py new file mode 100644 index 00000000..77478aa9 --- /dev/null +++ b/tsdate/discrepancy.py @@ -0,0 +1,375 @@ +# MIT License +# +# Copyright (c) 2021-23 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Tools for comparing node times between tree sequences with different node sets +""" + +import copy +from collections import defaultdict +from itertools import product + +import numpy as np +import scipy.sparse +import tskit + + +class CladeMap: + """ + An iterator across trees that maintains a mapping from a clade (a `frozenset` of + sample IDs) to a `set` of nodes. When there are unary nodes, there may be multiple + nodes associated with each clade. + """ + + def __init__(self, ts): + self._nil = frozenset() + self._nodes = defaultdict(set) # nodes[clade] = {node ids} + self._clades = defaultdict(frozenset) # clades[node] = {sample ids} + self.tree_sequence = ts + self.tree = ts.first(sample_lists=True) + for node in self.tree.nodes(): + clade = frozenset(self.tree.samples(node)) + self._nodes[clade].add(node) + self._clades[node] = clade + self._prev = copy.deepcopy(self._clades) + self._diff = ts.edge_diffs() + next(self._diff) + + def _propagate(self, edge, downdate=False): + """ + Traverse path from `edge.parent` to root, either adding or removing the + state (clade) associated with `edge.child` from the state of each + visited node. Return a set with the node ids encountered during + traversal. + """ + nodes = set() + node = edge.parent + clade = self._clades[edge.child] + while node != tskit.NULL: + last = self._clades[node] + self._clades[node] = last - clade if downdate else last | clade + if len(last): + self._nodes[last].remove(node) + if len(self._nodes[last]) == 0: + del self._nodes[last] + self._nodes[self._clades[node]].add(node) + nodes.add(node) + node = self.tree.parent(node) + return nodes + + def next(self): + """ + Advance to the next tree, returning the difference between trees as a + dictionary of the form `node : (last_clade, next_clade)` + """ + nodes = set() # nodes with potentially altered clades + diff = {} # diff[node] = (prev_clade, curr_clade) + + if self.tree.index + 1 == self.tree_sequence.num_trees: + return None + + # Subtract clades subtended by outgoing edges + edge_diff = next(self._diff) + for eo in edge_diff.edges_out: + nodes |= self._propagate(eo, downdate=True) + + # Prune nodes that are no longer in tree + for node in self._nodes[self._nil]: + diff[node] = (self._prev[node], self._nil) + del self._clades[node] + nodes -= self._nodes[self._nil] + self._nodes[self._nil].clear() + + # Add clades subtended by incoming edges + self.tree.next() + for ei in edge_diff.edges_in: + nodes |= self._propagate(ei, downdate=False) + + # Find difference in clades between adjacent trees + for node in nodes: + diff[node] = (self._prev[node], self._clades[node]) + if self._prev[node] == self._clades[node]: + del diff[node] + + # Sync previous and current states + for node, (_, curr) in diff.items(): + if curr == self._nil: + del self._prev[node] + else: + self._prev[node] = curr + + return diff + + @property + def interval(self): + """ + Return interval spanned by tree + """ + return self.tree.interval + + def clades(self): + """ + Return set of clades in tree + """ + return self._nodes.keys() - self._nil + + def __getitem__(self, clade): + """ + Return set of nodes associated with a given clade. + """ + return frozenset(self._nodes[clade]) if frozenset(clade) in self else self._nil + + def __contains__(self, clade): + """ + Check if a clade is present in the tree + """ + return clade in self._nodes + + +def shared_node_spans(ts, other): + """ + Calculate the spans over which pairs of nodes in two tree sequences are + ancestral to indentical sets of samples. + + Returns a sparse matrix where rows correspond to nodes in `ts` and columns + correspond to nodes in `other`. + """ + + if ts.sequence_length != other.sequence_length: + raise ValueError("Tree sequences must be of equal sequence length.") + + if ts.num_samples != other.num_samples: + raise ValueError("Tree sequences must have the same numbers of samples.") + + nil = frozenset() + + # Initialize clade iterators + query = CladeMap(ts) + target = CladeMap(other) + + # Initialize buffer[clade] = (query_nodes, target_nodes, left_coord) + modified = query.clades() | target.clades() + buffer = {} + + # Build sparse matrix of matches in triplet format + query_node = [] + target_node = [] + shared_span = [] + right = 0 + while True: + left = right + right = min(query.interval[1], target.interval[1]) + + # Flush pairs of nodes that no longer have matching clades + for clade in modified: # flush: + if clade in buffer: + n_i, n_j, start = buffer.pop(clade) + span = left - start + for i, j in product(n_i, n_j): + query_node.append(i) + target_node.append(j) + shared_span.append(span) + + # Add new pairs of nodes with matching clades + for clade in modified: + assert clade not in buffer + if clade in query and clade in target: + n_i, n_j = query[clade], target[clade] + buffer[clade] = (n_i, n_j, left) + + if right == ts.sequence_length: + break + + # Find difference in clades with advance to next tree + modified.clear() + for clade_map in (query, target): + if clade_map.interval[1] == right: + clade_diff = clade_map.next() + for prev, curr in clade_diff.values(): + if prev != nil: + modified.add(prev) + if curr != nil: + modified.add(curr) + + # Flush final tree + for clade in buffer: + n_i, n_j, start = buffer[clade] + span = right - start + for i, j in product(n_i, n_j): + query_node.append(i) + target_node.append(j) + shared_span.append(span) + + numer = scipy.sparse.coo_matrix( + (shared_span, (query_node, target_node)), + shape=(ts.num_nodes, other.num_nodes), + ).tocsr() + + return numer + + +def match_node_ages(ts, other): + """ + For each node in `ts`, return the age of a matched node from `other`. Node + matching is accomplished by calculating the intervals over which pairs of + nodes (one from `ts`, one from `other`) subtend the same set of samples. + + Returns three vectors of length `ts.num_nodes`: the age of the best + matching node in `other` (e.g. with the longest shared span); the + proportion of the node span in `ts` that is covered by the best match; and + the id of the best match in `other`. + + If either tree sequence contains unary nodes, then there may be multiple + matches with the same span for a single node. In this case, the returned + match is the node with the smallest integer id. + """ + + shared_spans = shared_node_spans(ts, other) + matched_span = shared_spans.max(axis=1).todense().A1 + best_match = shared_spans.argmax(axis=1).A1 + # NB: if there are multiple nodes with the largest span in a row, + # argmax returns the node with the smallest integer id + matched_time = other.nodes_time[best_match] + + best_match[matched_span == 0] = tskit.NULL + matched_time[matched_span == 0] = np.nan + + return matched_time, matched_span, best_match + + +def node_spans(ts): + """ + Returns the array of "node spans", i.e., the `j`th entry gives + the total span over which node `j` is in the tree (i.e., does + not have 'missing data' there). + """ + child_spans = np.bincount( + ts.edges_child, + weights=ts.edges_right - ts.edges_left, + minlength=ts.num_nodes, + ) + for t in ts.trees(): + span = t.span + for r in t.roots: + # do this check to exempt 'missing data' + if t.num_children(r) > 0: + child_spans[r] += span + return child_spans + + +def total_span(ts): + """ + Returns the total length of all "node spans", computed from + `node_spans(ts)`. + """ + ts_node_spans = node_spans(ts) + ts_total_span = np.sum(ts_node_spans) + return ts_total_span + + +def tree_discrepancy(ts, other): + """ + For two tree sequences `ts` and `other`, + this method returns three values, as a tuple: + 1. The fraction of the total span of `ts` over which each nodes' descendant + sample set does not match its' best match's descendant sample set. + 2. The root mean squared difference + between the times of the nodes in `ts` + and times of their best matching nodes in `other`, + with the average weighted by the nodes' spans in `ts`. + 3. The proportion of the span in `other` that is correctly + represented in `ts` (i.e., the total matching span divided + by the total span of `other`). + + This is done as follows: + + For each node in `ts` the best matching node(s) from `other` + has the longest matching span using `shared_node_spans`. + If there are multiple matches with the same longest shared span + for a single node, the best match is the match that is closest in time. + The discrepancy is: + ..math:: + + d(ts, other) = 1 - + \\left(sum_{x\\in \\operatorname{ts}} + \\min_{y\\in \\operatorname{other}} + |t_x - t_y| \\max{y \\in \\operatorname{other}} + \frac{1}{T}* \\operatorname{shared_span}(x,y)\right), + + where :math: `T` is the sum of spans of all nodes in `ts`. + + Returns three values: + `discrepancy` (float) the value computed above + `root-mean-squared discrepancy` (float) + `proportion of span of `other` correctly matching in `ts` (float) + + """ + + shared_spans = shared_node_spans(ts, other) + # Find all potential matches for a node based on max shared span length + max_span = shared_spans.max(axis=1).toarray().flatten() + col_ind = shared_spans.indices + row_ind = np.repeat( + np.arange(shared_spans.shape[0]), repeats=np.diff(shared_spans.indptr) + ) + # mask to find all potential node matches + match = shared_spans.data == max_span[row_ind] + # scale with difference in node times + # determine best matches with the best_match_matrix + ts_times = ts.nodes_time[row_ind[match]] + other_times = other.nodes_time[col_ind[match]] + time_difference = np.absolute(np.asarray(ts_times - other_times)) + # If a node x in `ts` has no match then we set time_difference to zero + # This node then does not effect the rmse + for j in range(len(shared_spans.data[match])): + if shared_spans.data[match][j] == 0: + time_difference[j] = 0.0 + # If two nodes have the same time, then + # time_difference is zero, which causes problems with argmin + # Instead we store data as 1/(1+x) and find argmax + best_match_matrix = scipy.sparse.coo_matrix( + ( + 1 / (1 + time_difference), + (row_ind[match], col_ind[match]), + ), + shape=(ts.num_nodes, other.num_nodes), + ) + # Between each pair of nodes, find the maximum shared span + best_match = best_match_matrix.argmax(axis=1).A1 + best_match_spans = shared_spans[np.arange(len(best_match)), best_match].reshape(-1) + # Return the discrepancy between ts and other + ts_node_spans = node_spans(ts) + total_node_spans_ts = total_span(ts) + total_node_spans_other = total_span(other) + discrepancy = 1 - np.sum(best_match_spans) / total_node_spans_ts + true_proportion = (1 - discrepancy) * total_node_spans_ts / total_node_spans_other + # Compute the root-mean-square discrepancy in time + # with averaged weighted by span in ts + time_matrix = scipy.sparse.csr_matrix( + (time_difference, (row_ind[match], col_ind[match])), + shape=(ts.num_nodes, other.num_nodes), + ) + time_discrepancies = np.asarray( + time_matrix[np.arange(len(best_match)), best_match].reshape(-1) + ) + product = np.multiply((time_discrepancies**2), ts_node_spans) + rmse = np.sqrt(np.sum(product) / total_node_spans_ts) + return discrepancy, rmse, true_proportion From cfbb8871429ebe997f6c4644814391a61b96b0b3 Mon Sep 17 00:00:00 2001 From: peter Date: Wed, 6 Nov 2024 15:35:47 -0800 Subject: [PATCH 2/2] removed duplicated code --- tests/test_evaluation.py | 148 ------------------------- tsdate/discrepancy.py | 228 +-------------------------------------- 2 files changed, 1 insertion(+), 375 deletions(-) delete mode 100644 tests/test_evaluation.py diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py deleted file mode 100644 index 857c0717..00000000 --- a/tests/test_evaluation.py +++ /dev/null @@ -1,148 +0,0 @@ -# MIT License -# -# Copyright (c) 2021-23 Tskit Developers -# Copyright (c) 2020-21 University of Oxford -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -""" -Test tools for mapping between node sets of different tree sequences -""" - -from collections import defaultdict -from itertools import combinations - -import msprime -import numpy as np -import pytest -import scipy.sparse -import tsinfer - -from tsdate import evaluation - -# --- simulate test case --- -demo = msprime.Demography.isolated_model([1e4]) -for t in np.linspace(500, 10000, 20): - demo.add_census(time=t) -true_unary = msprime.sim_ancestry( - samples=10, - sequence_length=1e6, - demography=demo, - recombination_rate=1e-8, - random_seed=1024, -) -true_unary = msprime.sim_mutations(true_unary, rate=2e-8, random_seed=1024) -assert true_unary.num_trees > 1 -true_simpl = true_unary.simplify(filter_sites=False) -sample_dat = tsinfer.SampleData.from_tree_sequence(true_simpl) -infr_unary = tsinfer.infer(sample_dat) -infr_simpl = infr_unary.simplify(filter_sites=False) - - -def naive_shared_node_spans(ts, other): - """ - Inefficient but transparent function to get span where nodes from two tree - sequences subtend the same sample set - """ - - def _clade_dict(tree): - clade_to_node = defaultdict(set) - for node in tree.nodes(): - clade = frozenset(tree.samples(node)) - clade_to_node[clade].add(node) - return clade_to_node - - assert ts.sequence_length == other.sequence_length - assert ts.num_samples == other.num_samples - out = np.zeros((ts.num_nodes, other.num_nodes)) - for interval, query_tree, target_tree in ts.coiterate(other): - query = _clade_dict(query_tree) - target = _clade_dict(target_tree) - span = interval.right - interval.left - for clade, nodes in query.items(): - if clade in target: - for i in nodes: - for j in target[clade]: - out[i, j] += span - return scipy.sparse.csr_matrix(out) - - -@pytest.mark.parametrize("ts", [true_unary, infr_unary, true_simpl, infr_simpl]) -class TestCladeMap: - def test_map(self, ts): - """ - test that clade map has correct nodes, clades - """ - clade_map = evaluation.CladeMap(ts) - for tree in ts.trees(): - for node in tree.nodes(): - clade = frozenset(tree.samples(node)) - assert node in clade_map._nodes[clade] - assert clade_map._clades[node] == clade - clade_map.next() - - def test_diff(self, ts): - """ - test difference in clades between adjacent trees - """ - clade_map = evaluation.CladeMap(ts) - tree_1 = ts.first() - tree_2 = ts.first() - while True: - tree_2.next() - diff = clade_map.next() - diff_test = {} - for n in set(tree_1.nodes()) | set(tree_2.nodes()): - prev = frozenset(tree_1.samples(n)) - curr = frozenset(tree_2.samples(n)) - if prev != curr: - diff_test[n] = (prev, curr) - for node in diff_test.keys() | diff.keys(): - assert diff_test[node][0] == diff[node][0] - assert diff_test[node][1] == diff[node][1] - if tree_2.index == ts.num_trees - 1: - break - tree_1.next() - - -class TestNodeMatching: - @pytest.mark.parametrize( - "pair", combinations([infr_simpl, true_simpl, infr_unary, true_unary], 2) - ) - def test_shared_spans(self, pair): - """ - Check that efficient implementation returns same answer as naive - implementation - """ - check = naive_shared_node_spans(pair[0], pair[1]) - test = evaluation.shared_node_spans(pair[0], pair[1]) - assert check.shape == test.shape - assert check.nnz == test.nnz - assert np.allclose(check.data, test.data) - - @pytest.mark.parametrize("ts", [infr_simpl, true_simpl]) - def test_match_self(self, ts): - """ - Check that matching against self returns node ids - - TODO: this'll only work reliably when there's not unary nodes. - """ - time, _, hit = evaluation.match_node_ages(ts, ts) - assert np.allclose(time, ts.nodes_time) - assert np.array_equal(hit, np.arange(ts.num_nodes)) diff --git a/tsdate/discrepancy.py b/tsdate/discrepancy.py index 77478aa9..1ddabdd9 100644 --- a/tsdate/discrepancy.py +++ b/tsdate/discrepancy.py @@ -23,236 +23,10 @@ Tools for comparing node times between tree sequences with different node sets """ -import copy -from collections import defaultdict -from itertools import product - import numpy as np import scipy.sparse -import tskit - - -class CladeMap: - """ - An iterator across trees that maintains a mapping from a clade (a `frozenset` of - sample IDs) to a `set` of nodes. When there are unary nodes, there may be multiple - nodes associated with each clade. - """ - - def __init__(self, ts): - self._nil = frozenset() - self._nodes = defaultdict(set) # nodes[clade] = {node ids} - self._clades = defaultdict(frozenset) # clades[node] = {sample ids} - self.tree_sequence = ts - self.tree = ts.first(sample_lists=True) - for node in self.tree.nodes(): - clade = frozenset(self.tree.samples(node)) - self._nodes[clade].add(node) - self._clades[node] = clade - self._prev = copy.deepcopy(self._clades) - self._diff = ts.edge_diffs() - next(self._diff) - - def _propagate(self, edge, downdate=False): - """ - Traverse path from `edge.parent` to root, either adding or removing the - state (clade) associated with `edge.child` from the state of each - visited node. Return a set with the node ids encountered during - traversal. - """ - nodes = set() - node = edge.parent - clade = self._clades[edge.child] - while node != tskit.NULL: - last = self._clades[node] - self._clades[node] = last - clade if downdate else last | clade - if len(last): - self._nodes[last].remove(node) - if len(self._nodes[last]) == 0: - del self._nodes[last] - self._nodes[self._clades[node]].add(node) - nodes.add(node) - node = self.tree.parent(node) - return nodes - - def next(self): - """ - Advance to the next tree, returning the difference between trees as a - dictionary of the form `node : (last_clade, next_clade)` - """ - nodes = set() # nodes with potentially altered clades - diff = {} # diff[node] = (prev_clade, curr_clade) - - if self.tree.index + 1 == self.tree_sequence.num_trees: - return None - - # Subtract clades subtended by outgoing edges - edge_diff = next(self._diff) - for eo in edge_diff.edges_out: - nodes |= self._propagate(eo, downdate=True) - - # Prune nodes that are no longer in tree - for node in self._nodes[self._nil]: - diff[node] = (self._prev[node], self._nil) - del self._clades[node] - nodes -= self._nodes[self._nil] - self._nodes[self._nil].clear() - - # Add clades subtended by incoming edges - self.tree.next() - for ei in edge_diff.edges_in: - nodes |= self._propagate(ei, downdate=False) - - # Find difference in clades between adjacent trees - for node in nodes: - diff[node] = (self._prev[node], self._clades[node]) - if self._prev[node] == self._clades[node]: - del diff[node] - - # Sync previous and current states - for node, (_, curr) in diff.items(): - if curr == self._nil: - del self._prev[node] - else: - self._prev[node] = curr - - return diff - - @property - def interval(self): - """ - Return interval spanned by tree - """ - return self.tree.interval - - def clades(self): - """ - Return set of clades in tree - """ - return self._nodes.keys() - self._nil - - def __getitem__(self, clade): - """ - Return set of nodes associated with a given clade. - """ - return frozenset(self._nodes[clade]) if frozenset(clade) in self else self._nil - - def __contains__(self, clade): - """ - Check if a clade is present in the tree - """ - return clade in self._nodes - - -def shared_node_spans(ts, other): - """ - Calculate the spans over which pairs of nodes in two tree sequences are - ancestral to indentical sets of samples. - - Returns a sparse matrix where rows correspond to nodes in `ts` and columns - correspond to nodes in `other`. - """ - - if ts.sequence_length != other.sequence_length: - raise ValueError("Tree sequences must be of equal sequence length.") - - if ts.num_samples != other.num_samples: - raise ValueError("Tree sequences must have the same numbers of samples.") - - nil = frozenset() - - # Initialize clade iterators - query = CladeMap(ts) - target = CladeMap(other) - - # Initialize buffer[clade] = (query_nodes, target_nodes, left_coord) - modified = query.clades() | target.clades() - buffer = {} - - # Build sparse matrix of matches in triplet format - query_node = [] - target_node = [] - shared_span = [] - right = 0 - while True: - left = right - right = min(query.interval[1], target.interval[1]) - - # Flush pairs of nodes that no longer have matching clades - for clade in modified: # flush: - if clade in buffer: - n_i, n_j, start = buffer.pop(clade) - span = left - start - for i, j in product(n_i, n_j): - query_node.append(i) - target_node.append(j) - shared_span.append(span) - - # Add new pairs of nodes with matching clades - for clade in modified: - assert clade not in buffer - if clade in query and clade in target: - n_i, n_j = query[clade], target[clade] - buffer[clade] = (n_i, n_j, left) - - if right == ts.sequence_length: - break - - # Find difference in clades with advance to next tree - modified.clear() - for clade_map in (query, target): - if clade_map.interval[1] == right: - clade_diff = clade_map.next() - for prev, curr in clade_diff.values(): - if prev != nil: - modified.add(prev) - if curr != nil: - modified.add(curr) - - # Flush final tree - for clade in buffer: - n_i, n_j, start = buffer[clade] - span = right - start - for i, j in product(n_i, n_j): - query_node.append(i) - target_node.append(j) - shared_span.append(span) - - numer = scipy.sparse.coo_matrix( - (shared_span, (query_node, target_node)), - shape=(ts.num_nodes, other.num_nodes), - ).tocsr() - - return numer - - -def match_node_ages(ts, other): - """ - For each node in `ts`, return the age of a matched node from `other`. Node - matching is accomplished by calculating the intervals over which pairs of - nodes (one from `ts`, one from `other`) subtend the same set of samples. - - Returns three vectors of length `ts.num_nodes`: the age of the best - matching node in `other` (e.g. with the longest shared span); the - proportion of the node span in `ts` that is covered by the best match; and - the id of the best match in `other`. - - If either tree sequence contains unary nodes, then there may be multiple - matches with the same span for a single node. In this case, the returned - match is the node with the smallest integer id. - """ - - shared_spans = shared_node_spans(ts, other) - matched_span = shared_spans.max(axis=1).todense().A1 - best_match = shared_spans.argmax(axis=1).A1 - # NB: if there are multiple nodes with the largest span in a row, - # argmax returns the node with the smallest integer id - matched_time = other.nodes_time[best_match] - - best_match[matched_span == 0] = tskit.NULL - matched_time[matched_span == 0] = np.nan - return matched_time, matched_span, best_match +from .evaluation import shared_node_spans def node_spans(ts):