From ddcedf237f648b02fe583171159c51ba00c8487a Mon Sep 17 00:00:00 2001 From: colganwi Date: Mon, 27 May 2024 20:17:07 -0400 Subject: [PATCH] updated utils --- src/pycea/pl/plot_tree.py | 26 +++++++++--- src/pycea/tl/ancestral_states.py | 13 +----- src/pycea/utils.py | 70 ++++++++++++++++---------------- tests/test_ancestral_states.py | 9 ++-- tests/test_utils.py | 24 +++++------ 5 files changed, 72 insertions(+), 70 deletions(-) diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index a35b3e6..ef9172d 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -101,7 +101,11 @@ def branches( if mcolors.is_color_like(color): kwargs.update({"color": color}) elif isinstance(color, str): - color_data = get_keyed_edge_data(trees, color) + color_data = get_keyed_edge_data(tdata, color, tree_keys)[color] + print(color_data) + if len(color_data) == 0: + raise ValueError(f"Key {color!r} is not present in any edge.") + color_data.index = color_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}") if color_data.dtype.kind in ["i", "f"]: norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max()) cmap = plt.get_cmap(cmap) @@ -125,7 +129,10 @@ def branches( if isinstance(linewidth, (int, float)): kwargs.update({"linewidth": linewidth}) elif isinstance(linewidth, str): - linewidth_data = get_keyed_edge_data(trees, linewidth) + linewidth_data = get_keyed_edge_data(tdata, linewidth, tree_keys)[linewidth] + if len(linewidth_data) == 0: + raise ValueError(f"Key {linewidth!r} is not present in any edge.") + linewidth_data.index = linewidth_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}") if linewidth_data.dtype.kind in ["i", "f"]: linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges] kwargs.update({"linewidth": linewidths}) @@ -271,7 +278,10 @@ def nodes( if mcolors.is_color_like(color): kwargs.update({"color": color}) elif isinstance(color, str): - color_data = get_keyed_node_data(trees, color) + color_data = get_keyed_node_data(tdata, color, tree_keys)[color] + if len(color_data) == 0: + raise ValueError(f"Key {color!r} is not present in any node.") + color_data.index = color_data.index.map("-".join) if color_data.dtype.kind in ["i", "f"]: if not vmin: vmin = color_data.min() @@ -298,7 +308,10 @@ def nodes( if isinstance(size, (int, float)): kwargs.update({"s": size}) elif isinstance(size, str): - size_data = get_keyed_node_data(trees, size) + size_data = get_keyed_node_data(tdata, size, tree_keys)[size] + if len(size_data) == 0: + raise ValueError(f"Key {size!r} is not present in any node.") + size_data.index = size_data.index.map("-".join) sizes = [size_data[node] if node in size_data.index else na_size for node in nodes] kwargs.update({"s": sizes}) else: @@ -307,7 +320,10 @@ def nodes( if style in mmarkers.MarkerStyle.markers: kwargs.update({"marker": style}) elif isinstance(style, str): - style_data = get_keyed_node_data(trees, style) + style_data = get_keyed_node_data(tdata, style, tree_keys)[style] + if len(style_data) == 0: + raise ValueError(f"Key {style!r} is not present in any node.") + style_data.index = style_data.index.map("-".join) mmap = _get_categorical_markers(tdata, style, style_data, markers) styles = [mmap[style_data[node]] if node in style_data.index else na_style for node in nodes] for style in set(styles): diff --git a/src/pycea/tl/ancestral_states.py b/src/pycea/tl/ancestral_states.py index de57ed1..bb5fc2d 100755 --- a/src/pycea/tl/ancestral_states.py +++ b/src/pycea/tl/ancestral_states.py @@ -257,15 +257,4 @@ def ancestral_states( nx.set_node_attributes(tree, data[key].to_dict(), key) _ancestral_states(tree, key, method, missing_state, default_state) if copy: - states = [] - for name, tree in trees.items(): - tree_states = [] - for key in keys: - data = get_keyed_node_data(tree, key) - tree_states.append(data) - tree_states = pd.concat(tree_states, axis=1) - tree_states["tree"] = name - states.append(tree_states) - states = pd.concat(states) - states["node"] = states.index - return states.reset_index(drop=True) + return get_keyed_node_data(tdata, keys, tree_keys) diff --git a/src/pycea/utils.py b/src/pycea/utils.py index 3e77a49..7648411 100755 --- a/src/pycea/utils.py +++ b/src/pycea/utils.py @@ -24,48 +24,46 @@ def get_leaves(tree: nx.DiGraph): return [node for node in nx.dfs_postorder_nodes(tree, get_root(tree)) if tree.out_degree(node) == 0] -def get_keyed_edge_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) -> pd.Series: +def get_keyed_edge_data( + tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None +) -> pd.DataFrame: """Gets edge data for a given key from a tree or set of trees.""" - if isinstance(tree, nx.DiGraph): - trees = {"": tree} - sep = "" - else: - trees = tree - sep = "-" - edge_data = {} + if isinstance(tree_keys, str): + tree_keys = [tree_keys] + if isinstance(keys, str): + keys = [keys] + trees = get_trees(tdata, tree_keys) + data = [] for name, tree in trees.items(): - edge_data.update( - { - (f"{name}{sep}{parent}", f"{name}{sep}{child}"): data.get(key) - for parent, child, data in tree.edges(data=True) - if key in data and data[key] is not None - } - ) - if len(edge_data) == 0: - raise ValueError(f"Key {key!r} is not present in any edge.") - return pd.Series(edge_data, name=key) + edge_data = {key: nx.get_edge_attributes(tree, key) for key in keys} + edge_data = pd.DataFrame(edge_data) + edge_data["tree"] = name + edge_data["edge"] = edge_data.index + data.append(edge_data) + data = pd.concat(data) + data = data.set_index(["tree", "edge"]) + return data -def get_keyed_node_data(tree: nx.DiGraph | Mapping[str, nx.DiGraph], key: str) -> pd.Series: +def get_keyed_node_data( + tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None +) -> pd.DataFrame: """Gets node data for a given key a tree or set of trees.""" - if isinstance(tree, nx.DiGraph): - trees = {"": tree} - sep = "" - else: - trees = tree - sep = "-" - node_data = {} + if isinstance(tree_keys, str): + tree_keys = [tree_keys] + if isinstance(keys, str): + keys = [keys] + trees = get_trees(tdata, tree_keys) + data = [] for name, tree in trees.items(): - node_data.update( - { - f"{name}{sep}{node}": data.get(key) - for node, data in tree.nodes(data=True) - if key in data and data[key] is not None - } - ) - if len(node_data) == 0: - raise ValueError(f"Key {key!r} is not present in any node.") - return pd.Series(node_data, name=key) + tree_data = {key: nx.get_node_attributes(tree, key) for key in keys} + tree_data = pd.DataFrame(tree_data) + tree_data["tree"] = name + data.append(tree_data) + data = pd.concat(data) + data["node"] = data.index + data = data.set_index(["tree", "node"]) + return data def get_keyed_obs_data(tdata: td.TreeData, keys: Sequence[str], layer: str = None) -> pd.DataFrame: diff --git a/tests/test_ancestral_states.py b/tests/test_ancestral_states.py index e541a05..8ea8662 100755 --- a/tests/test_ancestral_states.py +++ b/tests/test_ancestral_states.py @@ -46,7 +46,7 @@ def test_ancestral_states_array(tdata): print(states) assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 2.0] assert tdata.obst["tree1"].nodes["C"]["spatial"] == [1.5, 1.0] - assert states["spatial"][0] == [1.0, 2.0] + assert states.loc[("tree1", "root"), "spatial"] == [1.0, 2.0] # Median states = ancestral_states(tdata, "spatial", method=np.median, copy=True) assert tdata.obst["tree1"].nodes["root"]["spatial"] == [1.0, 1.0] @@ -58,14 +58,15 @@ def test_ancestral_states_missing(tdata): print(states) assert tdata.obst["tree1"].nodes["root"]["with_missing"] == 1.5 assert tdata.obst["tree1"].nodes["C"]["with_missing"] == 3 - assert states["with_missing"][0] == 1.5 + assert states.loc[("tree1", "root"), "with_missing"] == 1.5 def test_ancestral_state_fitch(tdata): states = ancestral_states(tdata, "characters", method="fitch_hartigan", copy=True) assert tdata.obst["tree1"].nodes["root"]["characters"] == ["1", "0"] assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] - assert states["characters"][0] == ["1", "0"] + print(states) + assert states.loc[("tree1", "root"), "characters"] == ["1", "0"] def test_ancestral_states_sankoff(tdata): @@ -77,7 +78,7 @@ def test_ancestral_states_sankoff(tdata): states = ancestral_states(tdata, "characters", method="sankoff", costs=costs, copy=True) assert tdata.obst["tree1"].nodes["root"]["characters"] == ["0", "0"] assert tdata.obst["tree2"].nodes["F"]["characters"] == ["1", "2"] - assert states["characters"][0] == ["0", "0"] + assert states.loc[("tree1", "root"), "characters"] == ["0", "0"] costs = pd.DataFrame( [[0, 10, 10], [1, 0, 2], [2, 1, 0]], index=["0", "1", "2"], diff --git a/tests/test_utils.py b/tests/test_utils.py index 3f401d5..22f8491 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import pandas as pd import pytest -from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_root, get_leaves +from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_leaves, get_root @pytest.fixture @@ -31,20 +31,14 @@ def test_get_leaves(tree): assert get_leaves(nx.DiGraph()) == [] -def test_get_keyed_edge_data(tree): - result = get_keyed_edge_data(tree, "weight") - expected_keys = [("A", "B"), ("B", "D"), ("C", "E")] - expected_values = [5, 3, 4] - assert all(result[key] == value for key, value in zip(expected_keys, expected_values)) - assert ("A", "C") not in result +def test_get_keyed_edge_data(tdata): + data = get_keyed_edge_data(tdata, ["length", "clade"]) + assert data.columns.tolist() == ["length", "clade"] -def test_get_keyed_node_data(tree): - result = get_keyed_node_data(tree, "value") - expected_keys = ["A", "B", "D", "E"] - expected_values = [1, 2, 4, 5] - assert all(result[key] == value for key, value in zip(expected_keys, expected_values)) - assert "C" not in result +def test_get_keyed_node_data(tdata): + data = get_keyed_node_data(tdata, ["x", "y", "clade"]) + assert data.columns.tolist() == ["x", "y", "clade"] def test_get_keyed_obs_data_valid_keys(tdata): @@ -70,3 +64,7 @@ def test_get_keyed_obs_data_invalid_keys(tdata): get_keyed_obs_data(tdata, ["clade", "x", "0", "invalid_key"]) with pytest.raises(ValueError): get_keyed_obs_data(tdata, ["clade", "spatial_distance"]) + + +if __name__ == "__main__": + pytest.main(["-v", __file__])