Skip to content

Commit

Permalink
updated utils
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed May 28, 2024
1 parent 5216b5e commit ddcedf2
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 70 deletions.
26 changes: 21 additions & 5 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def branches(
if mcolors.is_color_like(color):
kwargs.update({"color": color})
elif isinstance(color, str):
color_data = get_keyed_edge_data(trees, color)
color_data = get_keyed_edge_data(tdata, color, tree_keys)[color]
print(color_data)
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any edge.")
color_data.index = color_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}")
if color_data.dtype.kind in ["i", "f"]:
norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max())
cmap = plt.get_cmap(cmap)
Expand All @@ -125,7 +129,10 @@ def branches(
if isinstance(linewidth, (int, float)):
kwargs.update({"linewidth": linewidth})
elif isinstance(linewidth, str):
linewidth_data = get_keyed_edge_data(trees, linewidth)
linewidth_data = get_keyed_edge_data(tdata, linewidth, tree_keys)[linewidth]
if len(linewidth_data) == 0:
raise ValueError(f"Key {linewidth!r} is not present in any edge.")
linewidth_data.index = linewidth_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}")
if linewidth_data.dtype.kind in ["i", "f"]:
linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges]
kwargs.update({"linewidth": linewidths})
Expand Down Expand Up @@ -271,7 +278,10 @@ def nodes(
if mcolors.is_color_like(color):
kwargs.update({"color": color})
elif isinstance(color, str):
color_data = get_keyed_node_data(trees, color)
color_data = get_keyed_node_data(tdata, color, tree_keys)[color]
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any node.")
color_data.index = color_data.index.map("-".join)
if color_data.dtype.kind in ["i", "f"]:
if not vmin:
vmin = color_data.min()
Expand All @@ -298,7 +308,10 @@ def nodes(
if isinstance(size, (int, float)):
kwargs.update({"s": size})
elif isinstance(size, str):
size_data = get_keyed_node_data(trees, size)
size_data = get_keyed_node_data(tdata, size, tree_keys)[size]
if len(size_data) == 0:
raise ValueError(f"Key {size!r} is not present in any node.")
size_data.index = size_data.index.map("-".join)
sizes = [size_data[node] if node in size_data.index else na_size for node in nodes]
kwargs.update({"s": sizes})
else:
Expand All @@ -307,7 +320,10 @@ def nodes(
if style in mmarkers.MarkerStyle.markers:
kwargs.update({"marker": style})
elif isinstance(style, str):
style_data = get_keyed_node_data(trees, style)
style_data = get_keyed_node_data(tdata, style, tree_keys)[style]
if len(style_data) == 0:
raise ValueError(f"Key {style!r} is not present in any node.")
style_data.index = style_data.index.map("-".join)
mmap = _get_categorical_markers(tdata, style, style_data, markers)
styles = [mmap[style_data[node]] if node in style_data.index else na_style for node in nodes]
for style in set(styles):
Expand Down
13 changes: 1 addition & 12 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,4 @@ def ancestral_states(
nx.set_node_attributes(tree, data[key].to_dict(), key)
_ancestral_states(tree, key, method, missing_state, default_state)
if copy:
states = []
for name, tree in trees.items():
tree_states = []
for key in keys:
data = get_keyed_node_data(tree, key)
tree_states.append(data)
tree_states = pd.concat(tree_states, axis=1)
tree_states["tree"] = name
states.append(tree_states)
states = pd.concat(states)
states["node"] = states.index
return states.reset_index(drop=True)
return get_keyed_node_data(tdata, keys, tree_keys)
70 changes: 34 additions & 36 deletions src/pycea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,48 +24,46 @@ def get_leaves(tree: nx.DiGraph):
return [node for node in nx.dfs_postorder_nodes(tree, get_root(tree)) if tree.out_degree(node) == 0]


def get_keyed_edge_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) -> pd.Series:
def get_keyed_edge_data(
tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None
) -> pd.DataFrame:
"""Gets edge data for a given key from a tree or set of trees."""
if isinstance(tree, nx.DiGraph):
trees = {"": tree}
sep = ""
else:
trees = tree
sep = "-"
edge_data = {}
if isinstance(tree_keys, str):
tree_keys = [tree_keys]
if isinstance(keys, str):
keys = [keys]
trees = get_trees(tdata, tree_keys)
data = []
for name, tree in trees.items():
edge_data.update(
{
(f"{name}{sep}{parent}", f"{name}{sep}{child}"): data.get(key)
for parent, child, data in tree.edges(data=True)
if key in data and data[key] is not None
}
)
if len(edge_data) == 0:
raise ValueError(f"Key {key!r} is not present in any edge.")
return pd.Series(edge_data, name=key)
edge_data = {key: nx.get_edge_attributes(tree, key) for key in keys}
edge_data = pd.DataFrame(edge_data)
edge_data["tree"] = name
edge_data["edge"] = edge_data.index
data.append(edge_data)
data = pd.concat(data)
data = data.set_index(["tree", "edge"])
return data


def get_keyed_node_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) -> pd.Series:
def get_keyed_node_data(
tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None
) -> pd.DataFrame:
"""Gets node data for a given key from a tree or set of trees."""
if isinstance(tree, nx.DiGraph):
trees = {"": tree}
sep = ""
else:
trees = tree
sep = "-"
node_data = {}
if isinstance(tree_keys, str):
tree_keys = [tree_keys]
if isinstance(keys, str):
keys = [keys]
trees = get_trees(tdata, tree_keys)
data = []
for name, tree in trees.items():
node_data.update(
{
f"{name}{sep}{node}": data.get(key)
for node, data in tree.nodes(data=True)
if key in data and data[key] is not None
}
)
if len(node_data) == 0:
raise ValueError(f"Key {key!r} is not present in any node.")
return pd.Series(node_data, name=key)
tree_data = {key: nx.get_node_attributes(tree, key) for key in keys}
tree_data = pd.DataFrame(tree_data)
tree_data["tree"] = name
data.append(tree_data)
data = pd.concat(data)
data["node"] = data.index
data = data.set_index(["tree", "node"])
return data


def get_keyed_obs_data(tdata: td.TreeData, keys: Sequence[str], layer: str = None) -> pd.DataFrame:
Expand Down
9 changes: 5 additions & 4 deletions tests/test_ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_ancestral_states_array(tdata):
print(states)
assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 2.0]
assert tdata.obst["tree1"].nodes["C"]["spatial"] == [1.5, 1.0]
assert states["spatial"][0] == [1.0, 2.0]
assert states.loc[("tree1", "root"), "spatial"] == [1.0, 2.0]
# Median
states = ancestral_states(tdata, "spatial", method=np.median, copy=True)
assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 1.0]
Expand All @@ -58,14 +58,15 @@ def test_ancestral_states_missing(tdata):
print(states)
assert tdata.obst["tree1"].nodes["root"]["with_missing"] == 1.5
assert tdata.obst["tree1"].nodes["C"]["with_missing"] == 3
assert states["with_missing"][0] == 1.5
assert states.loc[("tree1", "root"), "with_missing"] == 1.5


def test_ancestral_state_fitch(tdata):
states = ancestral_states(tdata, "characters", method="fitch_hartigan", copy=True)
assert tdata.obst["tree1"].nodes["root"]["characters"] == ["1", "0"]
assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"]
assert states["characters"][0] == ["1", "0"]
print(states)
assert states.loc[("tree1", "root"), "characters"] == ["1", "0"]


def test_ancestral_states_sankoff(tdata):
Expand All @@ -77,7 +78,7 @@ def test_ancestral_states_sankoff(tdata):
states = ancestral_states(tdata, "characters", method="sankoff", costs=costs, copy=True)
assert tdata.obst["tree1"].nodes["root"]["characters"] == ["0", "0"]
assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"]
assert states["characters"][0] == ["0", "0"]
assert states.loc[("tree1", "root"), "characters"] == ["0", "0"]
costs = pd.DataFrame(
[[0, 10, 10], [1, 0, 2], [2, 1, 0]],
index=["0", "1", "2"],
Expand Down
24 changes: 11 additions & 13 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pandas as pd
import pytest

from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_root, get_leaves
from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_leaves, get_root


@pytest.fixture
Expand Down Expand Up @@ -31,20 +31,14 @@ def test_get_leaves(tree):
assert get_leaves(nx.DiGraph()) == []


def test_get_keyed_edge_data(tree):
result = get_keyed_edge_data(tree, "weight")
expected_keys = [("A", "B"), ("B", "D"), ("C", "E")]
expected_values = [5, 3, 4]
assert all(result[key] == value for key, value in zip(expected_keys, expected_values))
assert ("A", "C") not in result
def test_get_keyed_edge_data(tdata):
data = get_keyed_edge_data(tdata, ["length", "clade"])
assert data.columns.tolist() == ["length", "clade"]


def test_get_keyed_node_data(tree):
result = get_keyed_node_data(tree, "value")
expected_keys = ["A", "B", "D", "E"]
expected_values = [1, 2, 4, 5]
assert all(result[key] == value for key, value in zip(expected_keys, expected_values))
assert "C" not in result
def test_get_keyed_node_data(tdata):
data = get_keyed_node_data(tdata, ["x", "y", "clade"])
assert data.columns.tolist() == ["x", "y", "clade"]


def test_get_keyed_obs_data_valid_keys(tdata):
Expand All @@ -70,3 +64,7 @@ def test_get_keyed_obs_data_invalid_keys(tdata):
get_keyed_obs_data(tdata, ["clade", "x", "0", "invalid_key"])
with pytest.raises(ValueError):
get_keyed_obs_data(tdata, ["clade", "spatial_distance"])


if __name__ == "__main__":
pytest.main(["-v", __file__])

0 comments on commit ddcedf2

Please sign in to comment.