From 68fb7c58012c3663c09bfd065ef5f35307ae9e74 Mon Sep 17 00:00:00 2001 From: szhan Date: Sat, 5 Oct 2024 20:48:51 +0100 Subject: [PATCH] Count mutations above all nodes --- sc2ts/info.py | 37 +++++++++++++ tests/test_info.py | 133 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+) diff --git a/sc2ts/info.py b/sc2ts/info.py index fa1bbec..6046370 100644 --- a/sc2ts/info.py +++ b/sc2ts/info.py @@ -235,6 +235,43 @@ def get_path_mrca(path1, path2, node_time): ) +@numba.njit +def _get_num_muts( + ts_num_nodes, + tree_nodes_preorder, + tree_parent_array, + tree_nodes_num_mutations, +): + num_muts = np.zeros(ts_num_nodes, dtype=np.int32) + for node in tree_nodes_preorder: + pa = tree_parent_array[node] + if pa > -1: + num_muts[node] = num_muts[pa] + num_muts[node] += tree_nodes_num_mutations[node] + return num_muts + + +def get_num_muts(ts): + num_muts_all_trees = np.zeros(ts.num_nodes, dtype=np.int32) + for tree in ts.trees(): + tree_nodes_preorder = tree.preorder() + assert np.min(tree_nodes_preorder) >= 0 + tree_parent_array = tree.parent_array + mut_pos = ts.sites_position[ts.mutations_site] + is_mut_in_tree = (tree.interval.left <= mut_pos) & (mut_pos < tree.interval.right) + tree_nodes_num_muts = np.bincount( + ts.mutations_node[is_mut_in_tree], + minlength=ts.num_nodes, + ) + num_muts_all_trees += _get_num_muts( + ts_num_nodes=ts.num_nodes, + tree_nodes_preorder=tree_nodes_preorder, + tree_parent_array=tree_parent_array, + tree_nodes_num_mutations=tree_nodes_num_muts, + ) + return num_muts_all_trees + + def get_recombinant_edges(ts): """ Return the partial edges from the tree sequence grouped by child (which must diff --git a/tests/test_info.py b/tests/test_info.py index 79e0824..8c0a2a9 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -1,6 +1,10 @@ import pytest +import numpy as np import pandas as pd +import msprime +import tskit + from sc2ts import info @@ -23,3 +27,132 @@ def test_last_date(self, fx_ts_map, fx_metadata_db): ] assert list(df["db_count"]) == [26, 15, 4, 4, 1, 3, 1, 1, 1] assert list(df["arg_count"]) == [24, 15, 4, 3, 1, 1, 0, 0, 0] + + +class TestCountMutations: + def test_1tree_0mut(self): + # 2.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 1.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ + # 0 1 + ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence + expected = np.zeros(ts.num_nodes, dtype=np.int32) + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual) + + def test_1tree_1mut_below_root(self): + ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence + tables = ts.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + ts = tables.tree_sequence() + expected = np.zeros(ts.num_nodes, dtype=np.int32) + expected[0] = 1 + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual) + + def test_1tree_1mut_above_root(self): + ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence + tables = ts.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=6, derived_state="T") + ts = tables.tree_sequence() + expected = np.ones(ts.num_nodes, dtype=np.int32) + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual) + + def test_1tree_2mut_homoplasies(self): + ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence + tables = ts.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=0, node=3, derived_state="T") + ts = tables.tree_sequence() + expected = np.zeros(ts.num_nodes, dtype=np.int32) + expected[0] = 1 + expected[3] = 1 + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual) + + def test_1tree_2mut_reversion(self): + ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence + tables = ts.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="A") + tables.mutations.add_row(site=0, node=4, derived_state="T") + ts = tables.tree_sequence() + expected = np.zeros(ts.num_nodes, dtype=np.int32) + expected[0] = 2 + expected[1] = 1 + expected[4] = 1 + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual) + + def test_2trees_0mut(self): + ts = msprime.sim_ancestry( + 2, + recombination_rate=1e6, # Nearly guarantee recomb. + sequence_length=2, + ) + assert ts.num_trees == 2 + expected = np.zeros(ts.num_nodes, dtype=np.int32) + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual) + + def test_2trees_1mut(self): + ts = msprime.sim_ancestry( + 4, + ploidy=1, + recombination_rate=1e6, # Nearly guarantee recomb. + sequence_length=2, + ) + tables = ts.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + ts = tables.tree_sequence() + assert ts.num_trees == 2 + expected = np.zeros(ts.num_nodes, dtype=np.int32) + expected[0] = 1 + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual) + + def test_2trees_2mut_diff_trees(self): + ts = msprime.sim_ancestry( + 4, + ploidy=1, + recombination_rate=1e6, # Nearly guarantee recomb. + sequence_length=2, + ) + tables = ts.dump_tables() + tables.sites.add_row(0, "A") + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=1, node=0, derived_state="T") + ts = tables.tree_sequence() + assert ts.num_trees == 2 + expected = np.zeros(ts.num_nodes, dtype=np.int32) + expected[0] = 2 + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual) + + def test_2trees_2mut_same_tree(self): + ts = msprime.sim_ancestry( + 4, + ploidy=1, + recombination_rate=1e6, # Nearly guarantee recomb. + sequence_length=2, + ) + tables = ts.dump_tables() + tables.sites.add_row(0, "A") + tables.sites.add_row(1, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=1, node=3, derived_state="T") + ts = tables.tree_sequence() + assert ts.num_trees == 2 + expected = np.zeros(ts.num_nodes, dtype=np.int32) + expected[0] = 1 + expected[3] = 1 + actual = info.get_num_muts(ts) + np.testing.assert_equal(expected, actual)