From da10097f7fd6dfc4aebb4f394a20a94bd4adfa77 Mon Sep 17 00:00:00 2001 From: colganwi Date: Fri, 17 May 2024 09:26:21 -0400 Subject: [PATCH] CCPhylo branch lengths --- cassiopeia/solver/DistanceSolver.py | 7 +++++-- cassiopeia/solver/NeighborJoiningSolver.py | 14 ++++++++++---- cassiopeia/solver/UPGMASolver.py | 20 +++++++++++++------- test/solver_tests/ccphylo_solver_test.py | 15 +++++++++++++++ 4 files changed, 43 insertions(+), 13 deletions(-) diff --git a/cassiopeia/solver/DistanceSolver.py b/cassiopeia/solver/DistanceSolver.py index fe9ea3c9..2053d5aa 100644 --- a/cassiopeia/solver/DistanceSolver.py +++ b/cassiopeia/solver/DistanceSolver.py @@ -287,12 +287,14 @@ def _ccphylo_solve( root = T.get_tree_root() if midpoint in root.children: last_split = [root.name,midpoint.name] + length = tree.get_edge_data(root.name,midpoint.name).get("length",None) else: last_split = [root.name,root.children[0].name] + length = tree.get_edge_data(root.name,root.children[0].name).get("length",None) tree.remove_edge(last_split[0],last_split[1]) # root tree - tree = self.root_tree(tree,cassiopeia_tree.root_sample_name,last_split) + tree = self.root_tree(tree,cassiopeia_tree.root_sample_name,last_split,length=length) # remove root from character matrix before populating tree if ( @@ -396,7 +398,7 @@ def setup_dissimilarity_map( @abc.abstractmethod def root_tree( - self, tree: nx.Graph, root_sample: str, remaining_samples: List[str] + self, tree: nx.Graph, root_sample: str, remaining_samples: List[str], length: Optional[float] = None ) -> nx.DiGraph: """Roots a tree. @@ -407,6 +409,7 @@ def root_tree( tree: an undirected networkx tree topology root_sample: node name to treat as the root of the tree topology remaining_samples: samples yet to be added to the tree. + length: length of the edge to the root node. Returns: A rooted networkx tree diff --git a/cassiopeia/solver/NeighborJoiningSolver.py b/cassiopeia/solver/NeighborJoiningSolver.py index 81a236d4..3dde0e6d 100755 --- a/cassiopeia/solver/NeighborJoiningSolver.py +++ b/cassiopeia/solver/NeighborJoiningSolver.py @@ -98,7 +98,7 @@ def __init__( ) def root_tree( - self, tree: nx.Graph, root_sample: str, remaining_samples: List[str] + self, tree: nx.Graph, root_sample: str, remaining_samples: List[str], length: Optional[float] = None ) -> nx.DiGraph(): """Roots a tree produced by Neighbor-Joining at the specified root. @@ -108,15 +108,21 @@ def root_tree( tree: Networkx object representing the tree topology root_sample: Sample to treat as the root remaining_samples: The last two unjoined nodes in the tree + length: length of the edge to the root node. + Returns: A rooted tree """ - tree.add_edge(remaining_samples[0], remaining_samples[1]) + if length: + tree.add_edge(remaining_samples[0], remaining_samples[1], length=length) + else: + tree.add_edge(remaining_samples[0], remaining_samples[1]) rooted_tree = nx.DiGraph() - for e in nx.dfs_edges(tree, source=root_sample): - rooted_tree.add_edge(e[0], e[1]) + for u, v in nx.dfs_edges(tree, source=root_sample): + edge_data = tree.get_edge_data(u, v) + rooted_tree.add_edge(u, v, **edge_data) return rooted_tree diff --git a/cassiopeia/solver/UPGMASolver.py b/cassiopeia/solver/UPGMASolver.py index 12995eec..fa196044 100644 --- a/cassiopeia/solver/UPGMASolver.py +++ b/cassiopeia/solver/UPGMASolver.py @@ -87,7 +87,7 @@ def __init__( self.__cluster_to_cluster_size = defaultdict(int) def root_tree( - self, tree: nx.Graph, root_sample: str, remaining_samples: List[str] + self, tree: nx.Graph, root_sample: str, remaining_samples: List[str], length: Optional[float] = None ): """Roots a tree produced by UPGMA. @@ -99,19 +99,25 @@ def root_tree( tree: Networkx object representing the tree topology root_sample: Ignored in this case, the root is known in this case remaining_samples: The last two unjoined nodes in the tree + length: length of the edge to the root node. Returns: A rooted tree. """ - tree.add_node("root") - tree.add_edges_from( - [("root", remaining_samples[0]), ("root", remaining_samples[1])] - ) + tree.add_node(root_sample) + if length is not None: + tree.add_edge(root_sample, remaining_samples[0], length=length/2) + tree.add_edge(root_sample, remaining_samples[1], length=length/2) + else: + tree.add_edges_from( + [(root_sample, remaining_samples[0]), (root_sample, remaining_samples[1])] + ) rooted_tree = nx.DiGraph() - for e in nx.dfs_edges(tree, source="root"): - rooted_tree.add_edge(e[0], e[1]) + for u, v in nx.dfs_edges(tree, source=root_sample): + edge_data = tree.get_edge_data(u, v) + rooted_tree.add_edge(u, v, **edge_data) return rooted_tree diff --git a/test/solver_tests/ccphylo_solver_test.py b/test/solver_tests/ccphylo_solver_test.py index a70a0289..12e36e4b 100755 --- a/test/solver_tests/ccphylo_solver_test.py +++ b/test/solver_tests/ccphylo_solver_test.py @@ -155,6 +155,14 @@ def test_ccphylo_dnj_solver(self): dnj_tree = self.basic_tree.copy() self.ccphylo_dnj_solver.solve(dnj_tree) + # Test for edge lengths and node times + tree = dnj_tree.get_tree_topology() + self.assertAlmostEqual(tree.nodes["root"]["time"], 0.0) + self.assertAlmostEqual(tree.edges[('root', 'cassiopeia_internal_node1')]["length"], 0.196666667) + self.assertAlmostEqual(tree.nodes["cassiopeia_internal_node1"]["time"], 0.196666667) + for edge in tree.edges: + self.assertTrue("length" in tree.edges[edge]) + # test for expected number of edges self.assertEqual(len(nj_tree.edges), len(dnj_tree.edges)) @@ -201,6 +209,13 @@ def test_ccphylo_upgma_solver(self): # test for expected number of edges self.assertEqual(len(upgma_tree.edges), len(ccphylo_upgma_tree.edges)) + # Test for edge lengths and node times + tree = ccphylo_upgma_tree.get_tree_topology() + self.assertAlmostEqual(tree.nodes["root"]["time"], 0.0) + self.assertAlmostEqual(tree.edges[('root', 'cassiopeia_internal_node0')]["length"], 0.4484375) + self.assertAlmostEqual(tree.nodes["cassiopeia_internal_node0"]["time"], 0.4484375) + for edge in tree.edges: + self.assertTrue("length" in tree.edges[edge]) triplets = itertools.combinations(["a", "c", "d", "e"], 3) for triplet in triplets: