Skip to content

Commit

Permalink
Merge pull request #6 from YosefLab/add-depth
Browse files Browse the repository at this point in the history
add depth
  • Loading branch information
colganwi authored May 28, 2024
2 parents e5596a2 + 22aed5a commit 8fffb2c
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 5 deletions.
10 changes: 10 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

## Preprocessing

```{eval-rst}
.. module:: pycea.pp
.. currentmodule:: pycea
.. autosummary::
:toctree: generated
pp.add_depth
```

## Tools

```{eval-rst}
Expand Down
2 changes: 1 addition & 1 deletion src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def annotation(
**kwargs,
) -> Axes:
"""\
Plot leaf annotations.
Plot leaf annotations for a tree.
Parameters
----------
Expand Down
1 change: 1 addition & 0 deletions src/pycea/pp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .setup_tree import add_depth
39 changes: 39 additions & 0 deletions src/pycea/pp/setup_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

from collections.abc import Sequence

import networkx as nx
import treedata as td

from pycea.utils import get_keyed_node_data, get_root, get_trees


def _add_depth(tree, depth_key):
"""Adds a depth attribute to the nodes of a tree."""
root = get_root(tree)
depths = nx.single_source_shortest_path_length(tree, root)
nx.set_node_attributes(tree, depths, depth_key)


def add_depth(
tdata: td.TreeData, depth_key: str = "depth", tree: str | Sequence[str] | None = None, copy: bool = False
):
"""Adds a depth attribute to the nodes of a tree.
Parameters
----------
tdata
TreeData object.
depth_key
Node attribute key to store the depth.
tree
The `obst` key or keys of the trees to use. If `None`, all trees are used.
copy
If True, returns a pd.DataFrame node depths.
"""
tree_keys = tree
trees = get_trees(tdata, tree_keys)
for _, tree in trees.items():
_add_depth(tree, depth_key)
if copy:
return get_keyed_node_data(tdata, depth_key)
4 changes: 2 additions & 2 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _clade_name_generator():


def _clades(tree, depth, depth_key, clades, clade_key, name_generator):
"""Identifies clades in a tree."""
"""Marks clades in a tree."""
if (depth is not None) and (clades is None):
nodes = _nodes_at_depth(tree, get_root(tree), [], depth, depth_key)
clades = dict(zip(nodes, name_generator))
Expand Down Expand Up @@ -61,7 +61,7 @@ def clades(
tree: str | Sequence[str] | None = None,
copy: bool = False,
) -> None | Mapping:
"""Identifies clades in a tree.
"""Marks clades in a tree.
Parameters
----------
Expand Down
7 changes: 5 additions & 2 deletions src/pycea/tl/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ def _sort_tree(tree, key, reverse=False):
try:
sorted_children = sorted(tree.successors(node), key=lambda x: tree.nodes[x][key], reverse=reverse)
except KeyError as err:
raise KeyError(f"Node {next(tree.successors(node))} does not have a {key} attribute.") from err
raise KeyError(
f"Node {next(tree.successors(node))} does not have a {key} attribute.",
"You may need to call `ancestral_states` to infer internal node values",
) from err
tree.remove_edges_from([(node, child) for child in tree.successors(node)])
tree.add_edges_from([(node, child) for child in sorted_children])
return tree


def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None:
"""Reorders branches based on a given key.
"""Reorders branches based on a node attribute.
Parameters
----------
Expand Down
26 changes: 26 additions & 0 deletions tests/test_setup_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import networkx as nx
import pandas as pd
import pytest
import treedata as td

from pycea.pp.setup_tree import add_depth


@pytest.fixture
def tdata():
tree1 = nx.DiGraph([("root", "A"), ("root", "B"), ("B", "C"), ("B", "D")])
tree2 = nx.DiGraph([("root", "E"), ("root", "F")])
tdata = td.TreeData(obs=pd.DataFrame(index=["A", "C", "D", "E", "F"]), obst={"tree1": tree1, "tree2": tree2})
yield tdata


def test_add_depth(tdata):
depths = add_depth(tdata, depth_key="depth", copy=True)
assert depths.loc[("tree1", "root"), "depth"] == 0
assert depths.loc[("tree1", "C"), "depth"] == 2
assert tdata.obst["tree1"].nodes["root"]["depth"] == 0
assert tdata.obst["tree1"].nodes["C"]["depth"] == 2


if __name__ == "__main__":
pytest.main(["-v", __file__])

0 comments on commit 8fffb2c

Please sign in to comment.