Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ancestral states #5

Merged
merged 4 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
.. autosummary::
:toctree: generated

tl.ancestral_states
tl.clades
tl.sort
```
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"numpy",
"pandas",
"session-info",
"scipy",
]

[project.optional-dependencies]
Expand Down
26 changes: 21 additions & 5 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def branches(
if mcolors.is_color_like(color):
kwargs.update({"color": color})
elif isinstance(color, str):
color_data = get_keyed_edge_data(trees, color)
color_data = get_keyed_edge_data(tdata, color, tree_keys)[color]
print(color_data)
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any edge.")
color_data.index = color_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}")
if color_data.dtype.kind in ["i", "f"]:
norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max())
cmap = plt.get_cmap(cmap)
Expand All @@ -125,7 +129,10 @@ def branches(
if isinstance(linewidth, (int, float)):
kwargs.update({"linewidth": linewidth})
elif isinstance(linewidth, str):
linewidth_data = get_keyed_edge_data(trees, linewidth)
linewidth_data = get_keyed_edge_data(tdata, linewidth, tree_keys)[linewidth]
if len(linewidth_data) == 0:
raise ValueError(f"Key {linewidth!r} is not present in any edge.")
linewidth_data.index = linewidth_data.index.map(lambda x: f"{x[0]}-{x[1][0]}-{x[1][1]}")
if linewidth_data.dtype.kind in ["i", "f"]:
linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges]
kwargs.update({"linewidth": linewidths})
Expand Down Expand Up @@ -271,7 +278,10 @@ def nodes(
if mcolors.is_color_like(color):
kwargs.update({"color": color})
elif isinstance(color, str):
color_data = get_keyed_node_data(trees, color)
color_data = get_keyed_node_data(tdata, color, tree_keys)[color]
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any node.")
color_data.index = color_data.index.map("-".join)
if color_data.dtype.kind in ["i", "f"]:
if not vmin:
vmin = color_data.min()
Expand All @@ -298,7 +308,10 @@ def nodes(
if isinstance(size, (int, float)):
kwargs.update({"s": size})
elif isinstance(size, str):
size_data = get_keyed_node_data(trees, size)
size_data = get_keyed_node_data(tdata, size, tree_keys)[size]
if len(size_data) == 0:
raise ValueError(f"Key {size!r} is not present in any node.")
size_data.index = size_data.index.map("-".join)
sizes = [size_data[node] if node in size_data.index else na_size for node in nodes]
kwargs.update({"s": sizes})
else:
Expand All @@ -307,7 +320,10 @@ def nodes(
if style in mmarkers.MarkerStyle.markers:
kwargs.update({"marker": style})
elif isinstance(style, str):
style_data = get_keyed_node_data(trees, style)
style_data = get_keyed_node_data(tdata, style, tree_keys)[style]
if len(style_data) == 0:
raise ValueError(f"Key {style!r} is not present in any node.")
style_data.index = style_data.index.map("-".join)
mmap = _get_categorical_markers(tdata, style, style_data, markers)
styles = [mmap[style_data[node]] if node in style_data.index else na_style for node in nodes]
for style in set(styles):
Expand Down
1 change: 1 addition & 0 deletions src/pycea/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .clades import clades
from .sort import sort
from .ancestral_states import ancestral_states
260 changes: 260 additions & 0 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
from __future__ import annotations

from collections.abc import Sequence

import networkx as nx
import numpy as np
import pandas as pd
import treedata as td

from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_root, get_trees


def _most_common(arr):
    """Return the value that occurs most often in ``arr``.

    ``np.unique`` returns its values sorted and ``argmax`` returns the first
    maximum, so ties are resolved in favour of the smallest value.
    """
    candidates, frequencies = np.unique(arr, return_counts=True)
    return candidates[frequencies.argmax()]


def _get_node_value(tree, node, key, index):
    """Look up attribute ``key`` on ``node``.

    Returns ``None`` when the node does not carry the attribute. When ``index``
    is given, the attribute is treated as a sequence and only that position is
    returned; otherwise the stored object is returned as is.
    """
    attrs = tree.nodes[node]
    if key not in attrs:
        return None
    stored = attrs[key]
    return stored if index is None else stored[index]


def _set_node_value(tree, node, key, value, index):
    """Store ``value`` under attribute ``key`` of ``node``.

    With ``index=None`` the attribute itself is replaced. Otherwise the
    attribute must already exist as a mutable sequence and only the element at
    ``index`` is overwritten.
    """
    attrs = tree.nodes[node]
    if index is None:
        attrs[key] = value
        return
    attrs[key][index] = value


def _reconstruct_fitch_hartigan(tree, key, missing="-1", index=None):
    """Reconstructs ancestral states using the Fitch-Hartigan algorithm.

    The tree is modified in place: nodes whose value is ``None`` receive a
    reconstructed state and nodes whose value equals ``missing`` inherit the
    state of their parent.

    Parameters
    ----------
    tree
        Directed tree whose node attributes hold the states.
    key
        Node attribute to reconstruct.
    missing
        Value marking missing data.
    index
        Position within a list-valued attribute, or ``None`` for scalar attributes.
    """
    # Scratch attribute holding each node's candidate states. ``None`` means
    # "no informative leaf below this node"; a dedicated marker avoids mixing
    # up the ``missing`` value (often a string) with a real set of states.
    set_key = "_value_set"

    # Recursive function to calculate the downpass
    def downpass(node):
        # Base case: leaf
        if tree.out_degree(node) == 0:
            value = _get_node_value(tree, node, key, index)
            tree.nodes[node][set_key] = None if value == missing else {value}
            return
        # Recursive case: internal node
        child_sets = []
        for child in tree.successors(node):
            downpass(child)
            child_set = tree.nodes[child][set_key]
            if child_set is not None:
                child_sets.append(child_set)
        if not child_sets:
            tree.nodes[node][set_key] = None
            return
        shared = set.intersection(*child_sets)
        # Fall back to the union when the children have no state in common
        tree.nodes[node][set_key] = shared if shared else set.union(*child_sets)

    # Recursive function to calculate the uppass
    def uppass(node, parent_state=None):
        value = _get_node_value(tree, node, key, index)
        if value is None:
            candidates = tree.nodes[node][set_key]
            if candidates is None:
                # Nothing informative below: follow the parent, or stay
                # missing at the root where there is no parent to follow.
                value = missing if parent_state is None else parent_state
            elif parent_state is not None and parent_state in candidates:
                value = parent_state
            else:
                value = min(candidates)
            _set_node_value(tree, node, key, value, index)
        elif value == missing:
            # Keep the missing marker when there is no parent state to inherit
            if parent_state is not None:
                value = parent_state
                _set_node_value(tree, node, key, value, index)
        for child in tree.successors(node):
            uppass(child, value)

    # Run the algorithm
    root = get_root(tree)
    downpass(root)
    uppass(root)
    # Clean up
    for node in tree.nodes:
        tree.nodes[node].pop(set_key, None)


def _reconstruct_sankoff(tree, key, costs, missing="-1", index=None):
    """Reconstructs ancestral states using the Sankoff algorithm.

    The tree is modified in place. The root receives the state with the lowest
    total score and every other node, leaves included, receives the state
    selected for it by the traceback.

    Parameters
    ----------
    tree
        Directed tree whose node attributes hold the states.
    key
        Node attribute to reconstruct.
    costs
        Table of transition costs, read as ``costs.loc[parent_state, child_state]``.
        Its index defines the alphabet of states.
    missing
        Leaf value treated as missing; such leaves score zero for every state.
    index
        Position within a list-valued attribute, or ``None`` for scalar attributes.
    """

    # Recursive function to calculate the Sankoff scores
    def sankoff_scores(node):
        # Base case: leaf
        if tree.out_degree(node) == 0:
            leaf_value = _get_node_value(tree, node, key, index)
            if leaf_value == missing:
                return dict.fromkeys(alphabet, 0)
            return {value: 0 if value == leaf_value else float("inf") for value in alphabet}
        # Recursive case: internal node
        scores = dict.fromkeys(alphabet, 0)
        pointers = {value: {} for value in alphabet}
        for child in tree.successors(node):
            child_scores = sankoff_scores(child)
            for value in alphabet:
                min_cost, min_value = float("inf"), None
                for child_value in alphabet:
                    cost = child_scores[child_value] + costs.loc[value, child_value]
                    if cost < min_cost:
                        min_cost, min_value = cost, child_value
                scores[value] += min_cost
                # Remember which child state achieved the minimum for the traceback
                pointers[value][child] = min_value
        tree.nodes[node]["_pointers"] = pointers
        return scores

    # Recursive function to traceback the Sankoff scores
    def traceback(node, parent_value=None):
        for child in tree.successors(node):
            child_value = tree.nodes[node]["_pointers"][parent_value][child]
            _set_node_value(tree, child, key, child_value, index)
            traceback(child, child_value)

    # Get scores
    root = get_root(tree)
    # Ordered, de-duplicated alphabet: iterating a set of strings depends on
    # hash randomisation, which made ties between equally cheap states resolve
    # differently from run to run. The index order of ``costs`` is stable.
    alphabet = list(dict.fromkeys(costs.index))
    root_scores = sankoff_scores(root)
    # Reconstruct ancestral states
    root_value = min(root_scores, key=root_scores.get)
    _set_node_value(tree, root, key, root_value, index)
    traceback(root, root_value)
    # Clean up
    for node in tree.nodes:
        tree.nodes[node].pop("_pointers", None)


def _reconstruct_mean(tree, key, index):
    """Reconstructs ancestral states by averaging the values of the children.

    Each child's mean is weighted by the number of leaves beneath it. Internal
    nodes are updated in place; leaves are only read.
    """

    def subtree_mean(node):
        # A leaf contributes its own value with a weight of one
        if tree.out_degree(node) == 0:
            return _get_node_value(tree, node, key, index), 1
        child_stats = [subtree_mean(child) for child in tree.successors(node)]
        child_means = [mean for mean, _ in child_stats]
        leaf_counts = [count for _, count in child_stats]
        node_mean = np.average(child_means, weights=leaf_counts)
        _set_node_value(tree, node, key, node_mean, index)
        return node_mean, sum(leaf_counts)

    subtree_mean(get_root(tree))


def _reconstruct_list(tree, key, sum_func, index):
    """Reconstructs ancestral states by summarising the leaf values below each node.

    For every internal node the values of all descendant leaves are gathered
    into one list, and ``sum_func`` of that list is stored on the node.
    """

    def collect(node):
        # A leaf contributes just its own value
        if tree.out_degree(node) == 0:
            return [_get_node_value(tree, node, key, index)]
        pooled = []
        for child in tree.successors(node):
            pooled += collect(child)
        _set_node_value(tree, node, key, sum_func(pooled), index)
        return pooled

    collect(get_root(tree))


def _ancestral_states(tree, key, method="mean", costs=None, missing=None, default=None, index=None):
    """Dispatch ancestral state reconstruction of ``key`` on ``tree`` to ``method``.

    ``method`` is one of "sankoff", "fitch_hartigan", "mean", "mode", or a
    callable that reduces a list of leaf values to a single value. A
    ``ValueError`` is raised for any other method, and for "sankoff" when no
    ``costs`` table is supplied.
    """
    if method == "sankoff":
        if costs is None:
            raise ValueError("Costs matrix must be provided for Sankoff algorithm.")
        _reconstruct_sankoff(tree, key, costs, missing, index)
        return
    if method == "fitch_hartigan":
        _reconstruct_fitch_hartigan(tree, key, missing, index)
        return
    if method == "mean":
        _reconstruct_mean(tree, key, index)
        return
    if method == "mode":
        _reconstruct_list(tree, key, _most_common, index)
        return
    if callable(method):
        _reconstruct_list(tree, key, method, index)
        return
    raise ValueError(f"Method {method} not recognized.")


def ancestral_states(
    tdata: td.TreeData,
    keys: str | Sequence[str],
    method: str = "mean",
    missing_state: str = "-1",
    default_state: str = "0",
    costs: pd.DataFrame | None = None,
    tree: str | Sequence[str] | None = None,
    copy: bool = False,
) -> pd.DataFrame | None:
    """Reconstructs ancestral states for an attribute.

    Parameters
    ----------
    tdata
        TreeData object.
    keys
        One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct.
    method
        Method to reconstruct ancestral states. One of "mean", "mode", "fitch_hartigan", "sankoff",
        or any function that takes a list of values and returns a single value.
    missing_state
        The state to consider as missing data.
    default_state
        The expected state for the root node.
    costs
        A pd.DataFrame with the costs of changing states (from rows to columns).
        Required when `method` is "sankoff".
    tree
        The `obst` key or keys of the trees to use. If `None`, all trees are used.
    copy
        If True, returns a pd.DataFrame with ancestral states.

    Returns
    -------
    The keyed node data as a pd.DataFrame if `copy` is True, otherwise `None`.
    The node attributes of the selected trees are updated in place either way.

    Raises
    ------
    ValueError
        If "fitch_hartigan" or "sankoff" is requested for numerical data, or
        "mean" is requested for non-numerical data.
    """
    if isinstance(keys, str):
        keys = [keys]
    tree_keys = tree
    trees = get_trees(tdata, tree_keys)
    # The observation data does not depend on the tree, so fetch and validate it once
    data, is_array = get_keyed_obs_data(tdata, keys)
    dtypes = {dtype.kind for dtype in data.dtypes}
    # Check data type
    if dtypes.intersection({"i", "f"}):
        if method in ["fitch_hartigan", "sankoff"]:
            raise ValueError(f"Method {method} requires categorical data.")
    if dtypes.intersection({"O", "S"}):
        if method in ["mean"]:
            raise ValueError(f"Method {method} requires numerical data.")
    for t in trees.values():
        # If array add to tree as list
        if is_array:
            length = data.shape[1]
            # Build fresh lists for every tree: reconstruction writes into them in place
            node_attrs = data.apply(lambda row: list(row), axis=1).to_dict()
            for node in t.nodes:
                if node not in node_attrs:
                    node_attrs[node] = [None] * length
            nx.set_node_attributes(t, node_attrs, keys[0])
            for index in range(length):
                _ancestral_states(
                    t,
                    keys[0],
                    method=method,
                    costs=costs,
                    missing=missing_state,
                    default=default_state,
                    index=index,
                )
        # If column add to tree as scalar
        else:
            for key in keys:
                nx.set_node_attributes(t, data[key].to_dict(), key)
                # Keyword arguments keep costs/missing/default from being
                # shifted into the wrong positional slots of the helper.
                _ancestral_states(
                    t,
                    key,
                    method=method,
                    costs=costs,
                    missing=missing_state,
                    default=default_state,
                )
    if copy:
        return get_keyed_node_data(tdata, keys, tree_keys)
2 changes: 1 addition & 1 deletion src/pycea/tl/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _sort_tree(tree, key, reverse=False):


def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None:
"""Sorts the children of each internal node in a tree based on a given key.
"""Reorders branches based on a given key.

Parameters
----------
Expand Down
Loading
Loading