Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CCPhylo branch lengths #242

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cassiopeia/solver/DistanceSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,14 @@ def _ccphylo_solve(
root = T.get_tree_root()
if midpoint in root.children:
last_split = [root.name,midpoint.name]
length = tree.get_edge_data(root.name,midpoint.name).get("length",None)
else:
last_split = [root.name,root.children[0].name]
length = tree.get_edge_data(root.name,root.children[0].name).get("length",None)
tree.remove_edge(last_split[0],last_split[1])

# root tree
tree = self.root_tree(tree,cassiopeia_tree.root_sample_name,last_split)
tree = self.root_tree(tree,cassiopeia_tree.root_sample_name,last_split,length=length)

# remove root from character matrix before populating tree
if (
Expand Down Expand Up @@ -396,7 +398,7 @@ def setup_dissimilarity_map(

@abc.abstractmethod
def root_tree(
self, tree: nx.Graph, root_sample: str, remaining_samples: List[str]
self, tree: nx.Graph, root_sample: str, remaining_samples: List[str], length: Optional[float] = None
) -> nx.DiGraph:
"""Roots a tree.

Expand All @@ -407,6 +409,7 @@ def root_tree(
tree: an undirected networkx tree topology
root_sample: node name to treat as the root of the tree topology
remaining_samples: samples yet to be added to the tree.
length: length of the edge to the root node.

Returns:
A rooted networkx tree
Expand Down
14 changes: 10 additions & 4 deletions cassiopeia/solver/NeighborJoiningSolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
)

def root_tree(
self, tree: nx.Graph, root_sample: str, remaining_samples: List[str]
self, tree: nx.Graph, root_sample: str, remaining_samples: List[str], length: Optional[float] = None
) -> nx.DiGraph():
"""Roots a tree produced by Neighbor-Joining at the specified root.

Expand All @@ -108,15 +108,21 @@ def root_tree(
tree: Networkx object representing the tree topology
root_sample: Sample to treat as the root
remaining_samples: The last two unjoined nodes in the tree
length: length of the edge to the root node.


Returns:
A rooted tree
"""
tree.add_edge(remaining_samples[0], remaining_samples[1])
if length:
tree.add_edge(remaining_samples[0], remaining_samples[1], length=length)
else:
tree.add_edge(remaining_samples[0], remaining_samples[1])

rooted_tree = nx.DiGraph()
for e in nx.dfs_edges(tree, source=root_sample):
rooted_tree.add_edge(e[0], e[1])
for u, v in nx.dfs_edges(tree, source=root_sample):
edge_data = tree.get_edge_data(u, v)
rooted_tree.add_edge(u, v, **edge_data)

return rooted_tree

Expand Down
20 changes: 13 additions & 7 deletions cassiopeia/solver/UPGMASolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self.__cluster_to_cluster_size = defaultdict(int)

def root_tree(
self, tree: nx.Graph, root_sample: str, remaining_samples: List[str]
self, tree: nx.Graph, root_sample: str, remaining_samples: List[str], length: Optional[float] = None
):
"""Roots a tree produced by UPGMA.

Expand All @@ -99,19 +99,25 @@ def root_tree(
tree: Networkx object representing the tree topology
root_sample: Ignored in this case, the root is known in this case
remaining_samples: The last two unjoined nodes in the tree
length: length of the edge to the root node.

Returns:
A rooted tree.
"""

tree.add_node("root")
tree.add_edges_from(
[("root", remaining_samples[0]), ("root", remaining_samples[1])]
)
tree.add_node(root_sample)
if length is not None:
tree.add_edge(root_sample, remaining_samples[0], length=length/2)
tree.add_edge(root_sample, remaining_samples[1], length=length/2)
else:
tree.add_edges_from(
[(root_sample, remaining_samples[0]), (root_sample, remaining_samples[1])]
)

rooted_tree = nx.DiGraph()
for e in nx.dfs_edges(tree, source="root"):
rooted_tree.add_edge(e[0], e[1])
for u, v in nx.dfs_edges(tree, source=root_sample):
edge_data = tree.get_edge_data(u, v)
rooted_tree.add_edge(u, v, **edge_data)

return rooted_tree

Expand Down
15 changes: 15 additions & 0 deletions test/solver_tests/ccphylo_solver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ def test_ccphylo_dnj_solver(self):
dnj_tree = self.basic_tree.copy()
self.ccphylo_dnj_solver.solve(dnj_tree)

# Test for edge lengths and node times
tree = dnj_tree.get_tree_topology()
self.assertAlmostEqual(tree.nodes["root"]["time"], 0.0)
self.assertAlmostEqual(tree.edges[('root', 'cassiopeia_internal_node1')]["length"], 0.196666667)
self.assertAlmostEqual(tree.nodes["cassiopeia_internal_node1"]["time"], 0.196666667)
for edge in tree.edges:
self.assertTrue("length" in tree.edges[edge])

# test for expected number of edges
self.assertEqual(len(nj_tree.edges), len(dnj_tree.edges))

Expand Down Expand Up @@ -201,6 +209,13 @@ def test_ccphylo_upgma_solver(self):
# test for expected number of edges
self.assertEqual(len(upgma_tree.edges), len(ccphylo_upgma_tree.edges))

# Test for edge lengths and node times
tree = ccphylo_upgma_tree.get_tree_topology()
self.assertAlmostEqual(tree.nodes["root"]["time"], 0.0)
self.assertAlmostEqual(tree.edges[('root', 'cassiopeia_internal_node0')]["length"], 0.4484375)
self.assertAlmostEqual(tree.nodes["cassiopeia_internal_node0"]["time"], 0.4484375)
for edge in tree.edges:
self.assertTrue("length" in tree.edges[edge])

triplets = itertools.combinations(["a", "c", "d", "e"], 3)
for triplet in triplets:
Expand Down
Loading