From 22aed5a3e825efd0c5fe8c9ec85898ee6bfaa3a6 Mon Sep 17 00:00:00 2001 From: colganwi Date: Tue, 28 May 2024 18:27:30 -0400 Subject: [PATCH] add depth --- docs/api.md | 10 ++++++++++ src/pycea/pl/plot_tree.py | 2 +- src/pycea/pp/__init__.py | 1 + src/pycea/pp/setup_tree.py | 39 ++++++++++++++++++++++++++++++++++++++ src/pycea/tl/clades.py | 4 ++-- src/pycea/tl/sort.py | 7 +++++-- tests/test_setup_tree.py | 26 +++++++++++++++++++++++++ 7 files changed, 84 insertions(+), 5 deletions(-) create mode 100755 src/pycea/pp/setup_tree.py create mode 100755 tests/test_setup_tree.py diff --git a/docs/api.md b/docs/api.md index 85e286c..a499485 100644 --- a/docs/api.md +++ b/docs/api.md @@ -2,6 +2,16 @@ ## Preprocessing +```{eval-rst} +.. module:: pycea.pp +.. currentmodule:: pycea + +.. autosummary:: + :toctree: generated + + pp.add_depth +``` + ## Tools ```{eval-rst} diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index ef9172d..95d627e 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -371,7 +371,7 @@ def annotation( **kwargs, ) -> Axes: """\ - Plot leaf annotations. + Plot leaf annotations for a tree. Parameters ---------- diff --git a/src/pycea/pp/__init__.py b/src/pycea/pp/__init__.py index e69de29..cd365e1 100644 --- a/src/pycea/pp/__init__.py +++ b/src/pycea/pp/__init__.py @@ -0,0 +1 @@ +from .setup_tree import add_depth diff --git a/src/pycea/pp/setup_tree.py b/src/pycea/pp/setup_tree.py new file mode 100755 index 0000000..6f72cae --- /dev/null +++ b/src/pycea/pp/setup_tree.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from collections.abc import Sequence + +import networkx as nx +import treedata as td + +from pycea.utils import get_keyed_node_data, get_root, get_trees + + +def _add_depth(tree, depth_key): + """Adds a depth attribute to the nodes of a tree.""" + root = get_root(tree) + depths = nx.single_source_shortest_path_length(tree, root) + nx.set_node_attributes(tree, depths, depth_key) + + +def add_depth( + tdata: td.TreeData, depth_key: str = "depth", tree: str | Sequence[str] | None = None, copy: bool = False +): + """Adds a depth attribute to the nodes of a tree. + + Parameters + ---------- + tdata + TreeData object. + depth_key + Node attribute key to store the depth. + tree + The `obst` key or keys of the trees to use. If `None`, all trees are used. + copy + If True, returns a pd.DataFrame node depths. + """ + tree_keys = tree + trees = get_trees(tdata, tree_keys) + for _, tree in trees.items(): + _add_depth(tree, depth_key) + if copy: + return get_keyed_node_data(tdata, depth_key) diff --git a/src/pycea/tl/clades.py b/src/pycea/tl/clades.py index 806ab97..0a61c04 100755 --- a/src/pycea/tl/clades.py +++ b/src/pycea/tl/clades.py @@ -28,7 +28,7 @@ def _clade_name_generator(): def _clades(tree, depth, depth_key, clades, clade_key, name_generator): - """Identifies clades in a tree.""" + """Marks clades in a tree.""" if (depth is not None) and (clades is None): nodes = _nodes_at_depth(tree, get_root(tree), [], depth, depth_key) clades = dict(zip(nodes, name_generator)) @@ -61,7 +61,7 @@ def clades( tree: str | Sequence[str] | None = None, copy: bool = False, ) -> None | Mapping: - """Identifies clades in a tree. + """Marks clades in a tree. Parameters ---------- diff --git a/src/pycea/tl/sort.py b/src/pycea/tl/sort.py index 6930f65..cca7798 100755 --- a/src/pycea/tl/sort.py +++ b/src/pycea/tl/sort.py @@ -14,14 +14,17 @@ def _sort_tree(tree, key, reverse=False): try: sorted_children = sorted(tree.successors(node), key=lambda x: tree.nodes[x][key], reverse=reverse) except KeyError as err: - raise KeyError(f"Node {next(tree.successors(node))} does not have a {key} attribute.") from err + raise KeyError( + f"Node {next(tree.successors(node))} does not have a {key} attribute.", + "You may need to call `ancestral_states` to infer internal node values", + ) from err tree.remove_edges_from([(node, child) for child in tree.successors(node)]) tree.add_edges_from([(node, child) for child in sorted_children]) return tree def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None: - """Reorders branches based on a given key. + """Reorders branches based on a node attribute. Parameters ---------- diff --git a/tests/test_setup_tree.py b/tests/test_setup_tree.py new file mode 100755 index 0000000..1b0b732 --- /dev/null +++ b/tests/test_setup_tree.py @@ -0,0 +1,26 @@ +import networkx as nx +import pandas as pd +import pytest +import treedata as td + +from pycea.pp.setup_tree import add_depth + + +@pytest.fixture +def tdata(): + tree1 = nx.DiGraph([("root", "A"), ("root", "B"), ("B", "C"), ("B", "D")]) + tree2 = nx.DiGraph([("root", "E"), ("root", "F")]) + tdata = td.TreeData(obs=pd.DataFrame(index=["A", "C", "D", "E", "F"]), obst={"tree1": tree1, "tree2": tree2}) + yield tdata + + +def test_add_depth(tdata): + depths = add_depth(tdata, depth_key="depth", copy=True) + assert depths.loc[("tree1", "root"), "depth"] == 0 + assert depths.loc[("tree1", "C"), "depth"] == 2 + assert tdata.obst["tree1"].nodes["root"]["depth"] == 0 + assert tdata.obst["tree1"].nodes["C"]["depth"] == 2 + + +if __name__ == "__main__": + pytest.main(["-v", __file__])