diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py deleted file mode 100644 index 857c0717..00000000 --- a/tests/test_evaluation.py +++ /dev/null @@ -1,148 +0,0 @@ -# MIT License -# -# Copyright (c) 2021-23 Tskit Developers -# Copyright (c) 2020-21 University of Oxford -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -""" -Test tools for mapping between node sets of different tree sequences -""" - -from collections import defaultdict -from itertools import combinations - -import msprime -import numpy as np -import pytest -import scipy.sparse -import tsinfer - -from tsdate import evaluation - -# --- simulate test case --- -demo = msprime.Demography.isolated_model([1e4]) -for t in np.linspace(500, 10000, 20): - demo.add_census(time=t) -true_unary = msprime.sim_ancestry( - samples=10, - sequence_length=1e6, - demography=demo, - recombination_rate=1e-8, - random_seed=1024, -) -true_unary = msprime.sim_mutations(true_unary, rate=2e-8, random_seed=1024) -assert true_unary.num_trees > 1 -true_simpl = true_unary.simplify(filter_sites=False) -sample_dat = tsinfer.SampleData.from_tree_sequence(true_simpl) -infr_unary = tsinfer.infer(sample_dat) -infr_simpl = infr_unary.simplify(filter_sites=False) - - -def naive_shared_node_spans(ts, other): - """ - Inefficient but transparent function to get span where nodes from two tree - sequences subtend the same sample set - """ - - def _clade_dict(tree): - clade_to_node = defaultdict(set) - for node in tree.nodes(): - clade = frozenset(tree.samples(node)) - clade_to_node[clade].add(node) - return clade_to_node - - assert ts.sequence_length == other.sequence_length - assert ts.num_samples == other.num_samples - out = np.zeros((ts.num_nodes, other.num_nodes)) - for interval, query_tree, target_tree in ts.coiterate(other): - query = _clade_dict(query_tree) - target = _clade_dict(target_tree) - span = interval.right - interval.left - for clade, nodes in query.items(): - if clade in target: - for i in nodes: - for j in target[clade]: - out[i, j] += span - return scipy.sparse.csr_matrix(out) - - -@pytest.mark.parametrize("ts", [true_unary, infr_unary, true_simpl, infr_simpl]) -class TestCladeMap: - def test_map(self, ts): - """ - test that clade map has correct nodes, clades - """ - clade_map = evaluation.CladeMap(ts) - for tree in ts.trees(): - for node in tree.nodes(): - clade = frozenset(tree.samples(node)) - assert node in clade_map._nodes[clade] - assert clade_map._clades[node] == clade - clade_map.next() - - def test_diff(self, ts): - """ - test difference in clades between adjacent trees - """ - clade_map = evaluation.CladeMap(ts) - tree_1 = ts.first() - tree_2 = ts.first() - while True: - tree_2.next() - diff = clade_map.next() - diff_test = {} - for n in set(tree_1.nodes()) | set(tree_2.nodes()): - prev = frozenset(tree_1.samples(n)) - curr = frozenset(tree_2.samples(n)) - if prev != curr: - diff_test[n] = (prev, curr) - for node in diff_test.keys() | diff.keys(): - assert diff_test[node][0] == diff[node][0] - assert diff_test[node][1] == diff[node][1] - if tree_2.index == ts.num_trees - 1: - break - tree_1.next() - - -class TestNodeMatching: - @pytest.mark.parametrize( - "pair", combinations([infr_simpl, true_simpl, infr_unary, true_unary], 2) - ) - def test_shared_spans(self, pair): - """ - Check that efficient implementation returns same answer as naive - implementation - """ - check = naive_shared_node_spans(pair[0], pair[1]) - test = evaluation.shared_node_spans(pair[0], pair[1]) - assert check.shape == test.shape - assert check.nnz == test.nnz - assert np.allclose(check.data, test.data) - - @pytest.mark.parametrize("ts", [infr_simpl, true_simpl]) - def test_match_self(self, ts): - """ - Check that matching against self returns node ids - - TODO: this'll only work reliably when there's not unary nodes. - """ - time, _, hit = evaluation.match_node_ages(ts, ts) - assert np.allclose(time, ts.nodes_time) - assert np.array_equal(hit, np.arange(ts.num_nodes)) diff --git a/tsdate/discrepancy.py b/tsdate/discrepancy.py index 77478aa9..1ddabdd9 100644 --- a/tsdate/discrepancy.py +++ b/tsdate/discrepancy.py @@ -23,236 +23,10 @@ Tools for comparing node times between tree sequences with different node sets """ -import copy -from collections import defaultdict -from itertools import product - import numpy as np import scipy.sparse -import tskit - - -class CladeMap: - """ - An iterator across trees that maintains a mapping from a clade (a `frozenset` of - sample IDs) to a `set` of nodes. When there are unary nodes, there may be multiple - nodes associated with each clade. - """ - - def __init__(self, ts): - self._nil = frozenset() - self._nodes = defaultdict(set) # nodes[clade] = {node ids} - self._clades = defaultdict(frozenset) # clades[node] = {sample ids} - self.tree_sequence = ts - self.tree = ts.first(sample_lists=True) - for node in self.tree.nodes(): - clade = frozenset(self.tree.samples(node)) - self._nodes[clade].add(node) - self._clades[node] = clade - self._prev = copy.deepcopy(self._clades) - self._diff = ts.edge_diffs() - next(self._diff) - - def _propagate(self, edge, downdate=False): - """ - Traverse path from `edge.parent` to root, either adding or removing the - state (clade) associated with `edge.child` from the state of each - visited node. Return a set with the node ids encountered during - traversal. - """ - nodes = set() - node = edge.parent - clade = self._clades[edge.child] - while node != tskit.NULL: - last = self._clades[node] - self._clades[node] = last - clade if downdate else last | clade - if len(last): - self._nodes[last].remove(node) - if len(self._nodes[last]) == 0: - del self._nodes[last] - self._nodes[self._clades[node]].add(node) - nodes.add(node) - node = self.tree.parent(node) - return nodes - - def next(self): - """ - Advance to the next tree, returning the difference between trees as a - dictionary of the form `node : (last_clade, next_clade)` - """ - nodes = set() # nodes with potentially altered clades - diff = {} # diff[node] = (prev_clade, curr_clade) - - if self.tree.index + 1 == self.tree_sequence.num_trees: - return None - - # Subtract clades subtended by outgoing edges - edge_diff = next(self._diff) - for eo in edge_diff.edges_out: - nodes |= self._propagate(eo, downdate=True) - - # Prune nodes that are no longer in tree - for node in self._nodes[self._nil]: - diff[node] = (self._prev[node], self._nil) - del self._clades[node] - nodes -= self._nodes[self._nil] - self._nodes[self._nil].clear() - - # Add clades subtended by incoming edges - self.tree.next() - for ei in edge_diff.edges_in: - nodes |= self._propagate(ei, downdate=False) - - # Find difference in clades between adjacent trees - for node in nodes: - diff[node] = (self._prev[node], self._clades[node]) - if self._prev[node] == self._clades[node]: - del diff[node] - - # Sync previous and current states - for node, (_, curr) in diff.items(): - if curr == self._nil: - del self._prev[node] - else: - self._prev[node] = curr - - return diff - - @property - def interval(self): - """ - Return interval spanned by tree - """ - return self.tree.interval - - def clades(self): - """ - Return set of clades in tree - """ - return self._nodes.keys() - self._nil - - def __getitem__(self, clade): - """ - Return set of nodes associated with a given clade. - """ - return frozenset(self._nodes[clade]) if frozenset(clade) in self else self._nil - - def __contains__(self, clade): - """ - Check if a clade is present in the tree - """ - return clade in self._nodes - - -def shared_node_spans(ts, other): - """ - Calculate the spans over which pairs of nodes in two tree sequences are - ancestral to indentical sets of samples. - - Returns a sparse matrix where rows correspond to nodes in `ts` and columns - correspond to nodes in `other`. - """ - - if ts.sequence_length != other.sequence_length: - raise ValueError("Tree sequences must be of equal sequence length.") - - if ts.num_samples != other.num_samples: - raise ValueError("Tree sequences must have the same numbers of samples.") - - nil = frozenset() - - # Initialize clade iterators - query = CladeMap(ts) - target = CladeMap(other) - - # Initialize buffer[clade] = (query_nodes, target_nodes, left_coord) - modified = query.clades() | target.clades() - buffer = {} - - # Build sparse matrix of matches in triplet format - query_node = [] - target_node = [] - shared_span = [] - right = 0 - while True: - left = right - right = min(query.interval[1], target.interval[1]) - - # Flush pairs of nodes that no longer have matching clades - for clade in modified: # flush: - if clade in buffer: - n_i, n_j, start = buffer.pop(clade) - span = left - start - for i, j in product(n_i, n_j): - query_node.append(i) - target_node.append(j) - shared_span.append(span) - - # Add new pairs of nodes with matching clades - for clade in modified: - assert clade not in buffer - if clade in query and clade in target: - n_i, n_j = query[clade], target[clade] - buffer[clade] = (n_i, n_j, left) - - if right == ts.sequence_length: - break - - # Find difference in clades with advance to next tree - modified.clear() - for clade_map in (query, target): - if clade_map.interval[1] == right: - clade_diff = clade_map.next() - for prev, curr in clade_diff.values(): - if prev != nil: - modified.add(prev) - if curr != nil: - modified.add(curr) - - # Flush final tree - for clade in buffer: - n_i, n_j, start = buffer[clade] - span = right - start - for i, j in product(n_i, n_j): - query_node.append(i) - target_node.append(j) - shared_span.append(span) - - numer = scipy.sparse.coo_matrix( - (shared_span, (query_node, target_node)), - shape=(ts.num_nodes, other.num_nodes), - ).tocsr() - - return numer - - -def match_node_ages(ts, other): - """ - For each node in `ts`, return the age of a matched node from `other`. Node - matching is accomplished by calculating the intervals over which pairs of - nodes (one from `ts`, one from `other`) subtend the same set of samples. - - Returns three vectors of length `ts.num_nodes`: the age of the best - matching node in `other` (e.g. with the longest shared span); the - proportion of the node span in `ts` that is covered by the best match; and - the id of the best match in `other`. - - If either tree sequence contains unary nodes, then there may be multiple - matches with the same span for a single node. In this case, the returned - match is the node with the smallest integer id. - """ - - shared_spans = shared_node_spans(ts, other) - matched_span = shared_spans.max(axis=1).todense().A1 - best_match = shared_spans.argmax(axis=1).A1 - # NB: if there are multiple nodes with the largest span in a row, - # argmax returns the node with the smallest integer id - matched_time = other.nodes_time[best_match] - - best_match[matched_span == 0] = tskit.NULL - matched_time[matched_span == 0] = np.nan - return matched_time, matched_span, best_match +from .evaluation import shared_node_spans def node_spans(ts):