Skip to content

Commit

Permalink
Merge pull request jeromekelleher#335 from szhan/count_mutations_all_…
Browse files Browse the repository at this point in the history
…nodes

Count mutations above all nodes
  • Loading branch information
jeromekelleher authored Oct 7, 2024
2 parents 88baf9a + 68fb7c5 commit 6051725
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 0 deletions.
37 changes: 37 additions & 0 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,43 @@ def get_path_mrca(path1, path2, node_time):
)


@numba.njit
def _get_num_muts(
    ts_num_nodes,
    tree_nodes_preorder,
    tree_parent_array,
    tree_nodes_num_mutations,
):
    """
    Accumulate per-node mutation counts down a single tree.

    Nodes are visited in the order given by ``tree_nodes_preorder``. Each
    visited node receives the running total already stored for its parent
    (when ``tree_parent_array`` gives a non-negative parent) plus its own
    entry in ``tree_nodes_num_mutations``. Nodes that are never visited keep
    a count of zero.

    Returns an int32 array of length ``ts_num_nodes``.
    """
    cumulative = np.zeros(ts_num_nodes, dtype=np.int32)
    for u in tree_nodes_preorder:
        parent = tree_parent_array[u]
        own = tree_nodes_num_mutations[u]
        if parent >= 0:
            # Start from the parent's running total; this relies on the
            # parent having been processed before the child.
            cumulative[u] = cumulative[parent] + own
        else:
            # No parent recorded (negative id): only the node's own count.
            cumulative[u] += own
    return cumulative


def get_num_muts(ts):
    """
    Count the mutations above every node, summed over all trees.

    For each tree in ``ts``, the mutations whose site position lies in the
    tree's half-open interval ``[left, right)`` are tallied per node, and
    ``_get_num_muts`` turns those tallies into running totals along the
    tree (a node's total includes its own mutations and those accumulated
    by its ancestors in that tree). The per-tree totals are then added
    together.

    :param ts: The tree sequence to analyse.
    :return: An int32 numpy array of length ``ts.num_nodes`` holding the
        summed count for each node.
    """
    num_muts_all_trees = np.zeros(ts.num_nodes, dtype=np.int32)
    # The position and node of every mutation are properties of the tree
    # sequence, not of an individual tree, so look them up once instead of
    # rebuilding them on every iteration.
    mut_pos = ts.sites_position[ts.mutations_site]
    mut_node = ts.mutations_node
    for tree in ts.trees():
        tree_nodes_preorder = tree.preorder()
        if tree_nodes_preorder.size == 0:
            # No nodes are reachable in this tree, so it contributes nothing.
            # Skipping also avoids np.min raising on a zero-size array.
            continue
        assert np.min(tree_nodes_preorder) >= 0
        interval = tree.interval
        is_mut_in_tree = (interval.left <= mut_pos) & (mut_pos < interval.right)
        # minlength keeps the tally indexable by any node id, even when the
        # highest-numbered nodes carry no mutations in this tree.
        tree_nodes_num_muts = np.bincount(
            mut_node[is_mut_in_tree],
            minlength=ts.num_nodes,
        )
        num_muts_all_trees += _get_num_muts(
            ts_num_nodes=ts.num_nodes,
            tree_nodes_preorder=tree_nodes_preorder,
            tree_parent_array=tree.parent_array,
            tree_nodes_num_mutations=tree_nodes_num_muts,
        )
    return num_muts_all_trees


def get_recombinant_edges(ts):
"""
Return the partial edges from the tree sequence grouped by child (which must
Expand Down
133 changes: 133 additions & 0 deletions tests/test_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest
import numpy as np
import pandas as pd

import msprime
import tskit

from sc2ts import info


Expand All @@ -23,3 +27,132 @@ def test_last_date(self, fx_ts_map, fx_metadata_db):
]
assert list(df["db_count"]) == [26, 15, 4, 4, 1, 3, 1, 1, 1]
assert list(df["arg_count"]) == [24, 15, 4, 3, 1, 1, 0, 0, 0]


class TestCountMutations:
    """
    Tests for ``info.get_num_muts`` on small hand-built and simulated
    tree sequences.

    The single-tree cases all use the balanced binary tree on four samples:

        2.00┊    6    ┊
            ┊  ┏━┻━┓  ┊
        1.00┊  4   5  ┊
            ┊ ┏┻┓ ┏┻┓ ┊
        0.00┊ 0 1 2 3 ┊
            0         1
    """

    # Fixed seed so that the simulated two-tree cases are reproducible from
    # run to run instead of depending on msprime's default random seeding.
    RANDOM_SEED = 42

    def test_1tree_0mut(self):
        # With no mutations every node has a count of zero.
        ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence
        expected = np.zeros(ts.num_nodes, dtype=np.int32)
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

    def test_1tree_1mut_below_root(self):
        # A mutation on leaf 0 is counted for that leaf only.
        ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence
        tables = ts.dump_tables()
        tables.sites.add_row(0, "A")
        tables.mutations.add_row(site=0, node=0, derived_state="T")
        ts = tables.tree_sequence()
        expected = np.zeros(ts.num_nodes, dtype=np.int32)
        expected[0] = 1
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

    def test_1tree_1mut_above_root(self):
        # A mutation on the root (node 6) is counted for every node.
        ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence
        tables = ts.dump_tables()
        tables.sites.add_row(0, "A")
        tables.mutations.add_row(site=0, node=6, derived_state="T")
        ts = tables.tree_sequence()
        expected = np.ones(ts.num_nodes, dtype=np.int32)
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

    def test_1tree_2mut_homoplasies(self):
        # The same derived state arising on two different leaves is
        # counted once on each of them.
        ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence
        tables = ts.dump_tables()
        tables.sites.add_row(0, "A")
        tables.mutations.add_row(site=0, node=0, derived_state="T")
        tables.mutations.add_row(site=0, node=3, derived_state="T")
        ts = tables.tree_sequence()
        expected = np.zeros(ts.num_nodes, dtype=np.int32)
        expected[0] = 1
        expected[3] = 1
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

    def test_1tree_2mut_reversion(self):
        # Node 4 mutates to T and its child 0 reverts to A: leaf 0 sees
        # both mutations, while node 4 and its other child 1 see one.
        ts = tskit.Tree.generate_balanced(4, arity=2).tree_sequence
        tables = ts.dump_tables()
        tables.sites.add_row(0, "A")
        tables.mutations.add_row(site=0, node=0, derived_state="A")
        tables.mutations.add_row(site=0, node=4, derived_state="T")
        ts = tables.tree_sequence()
        expected = np.zeros(ts.num_nodes, dtype=np.int32)
        expected[0] = 2
        expected[1] = 1
        expected[4] = 1
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

    def test_2trees_0mut(self):
        ts = msprime.sim_ancestry(
            2,
            recombination_rate=1e6,  # Nearly guarantee recomb.
            sequence_length=2,
            random_seed=self.RANDOM_SEED,
        )
        assert ts.num_trees == 2
        expected = np.zeros(ts.num_nodes, dtype=np.int32)
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

    def test_2trees_1mut(self):
        ts = msprime.sim_ancestry(
            4,
            ploidy=1,
            recombination_rate=1e6,  # Nearly guarantee recomb.
            sequence_length=2,
            random_seed=self.RANDOM_SEED,
        )
        tables = ts.dump_tables()
        tables.sites.add_row(0, "A")
        tables.mutations.add_row(site=0, node=0, derived_state="T")
        ts = tables.tree_sequence()
        assert ts.num_trees == 2
        expected = np.zeros(ts.num_nodes, dtype=np.int32)
        expected[0] = 1
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

    def test_2trees_2mut_diff_trees(self):
        # Two sites (positions 0 and 1), each with a mutation on node 0;
        # the counts from both sites add up on that node.
        ts = msprime.sim_ancestry(
            4,
            ploidy=1,
            recombination_rate=1e6,  # Nearly guarantee recomb.
            sequence_length=2,
            random_seed=self.RANDOM_SEED,
        )
        tables = ts.dump_tables()
        tables.sites.add_row(0, "A")
        tables.sites.add_row(1, "A")
        tables.mutations.add_row(site=0, node=0, derived_state="T")
        tables.mutations.add_row(site=1, node=0, derived_state="T")
        ts = tables.tree_sequence()
        assert ts.num_trees == 2
        expected = np.zeros(ts.num_nodes, dtype=np.int32)
        expected[0] = 2
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

    def test_2trees_2mut_same_tree(self):
        # NOTE(review): despite the name, the two sites are at positions 0
        # and 1, the same as in test_2trees_2mut_diff_trees; presumably they
        # fall in different trees when the breakpoint is at 1. Only the
        # mutated nodes differ (0 and 3) -- confirm the intended scenario.
        ts = msprime.sim_ancestry(
            4,
            ploidy=1,
            recombination_rate=1e6,  # Nearly guarantee recomb.
            sequence_length=2,
            random_seed=self.RANDOM_SEED,
        )
        tables = ts.dump_tables()
        tables.sites.add_row(0, "A")
        tables.sites.add_row(1, "A")
        tables.mutations.add_row(site=0, node=0, derived_state="T")
        tables.mutations.add_row(site=1, node=3, derived_state="T")
        ts = tables.tree_sequence()
        assert ts.num_trees == 2
        expected = np.zeros(ts.num_nodes, dtype=np.int32)
        expected[0] = 1
        expected[3] = 1
        actual = info.get_num_muts(ts)
        np.testing.assert_equal(expected, actual)

0 comments on commit 6051725

Please sign in to comment.