Skip to content

Commit

Permalink
neighbor distance and doc changes
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed Aug 17, 2024
1 parent c3f8329 commit 2811e3a
Show file tree
Hide file tree
Showing 17 changed files with 841 additions and 294 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ jobs:
python: "3.9"
- os: ubuntu-latest
python: "3.11"
- os: ubuntu-latest
python: "3.11"
pip-flags: "--pre"
name: PRE-RELEASE DEPENDENCIES

name: ${{ matrix.name }} Python ${{ matrix.python }}

Expand Down
4 changes: 4 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
tl.ancestral_states
tl.clades
tl.compare_distance
tl.distance
tl.sort
tl.tree_distance
tl.tree_neighbors
```

## Plotting
Expand Down
16 changes: 9 additions & 7 deletions src/pycea/pp/setup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Sequence

import networkx as nx
import pandas as pd
import treedata as td

from pycea.utils import get_keyed_leaf_data, get_keyed_node_data, get_root, get_trees
Expand All @@ -17,7 +18,7 @@ def _add_depth(tree, depth_key):

def add_depth(
tdata: td.TreeData, key_added: str = "depth", tree: str | Sequence[str] | None = None, copy: bool = False
):
) -> None | pd.DataFrame:
"""Adds a depth attribute to the tree.
Parameters
Expand All @@ -29,17 +30,18 @@ def add_depth(
tree
The `obst` key or keys of the trees to use. If `None`, all trees are used.
copy
If True, returns a :class:`pandas.DataFrame` with node depths.
If True, returns a :class:`DataFrame <pandas.DataFrame>` with node depths.
Returns
-------
Returns `None` if `copy=False`, else returns a :class:`pandas.DataFrame`. Sets the following fields:
Returns `None` if `copy=False`, else returns node depths.
`tdata.obs[key_added]` : :class:`pandas.Series` (dtype `float`)
Distance from the root node.
`tdata.obst[tree].nodes[key_added]` : `float`
Distance from the root node.
Sets the following fields:
* `tdata.obs[key_added]` : :class:`Series <pandas.Series>` (dtype `float`)
- Distance from the root node.
* `tdata.obst[tree].nodes[key_added]` : `float`
- Distance from the root node.
"""
tree_keys = tree
trees = get_trees(tdata, tree_keys)
Expand Down
1 change: 1 addition & 0 deletions src/pycea/tl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .distance import compare_distance, distance
from .sort import sort
from .tree_distance import tree_distance
from .tree_neighbors import tree_neighbors
45 changes: 40 additions & 5 deletions src/pycea/tl/_metrics.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
from collections.abc import Callable
from typing import Literal, Union
from typing import Literal

import numpy as np
import treedata as td

_MetricFn = Callable[[np.ndarray, np.ndarray], float]
# from sklearn.metrics.pairwise_distances.__doc__:
_MetricSparseCapable = Literal["cityblock", "cosine", "euclidean", "l1", "l2", "manhattan"]
_MetricScipySpatial = Literal[

_Metric = Literal[
"braycurtis",
"canberra",
"chebyshev",
"cityblock",
"cosine",
"correlation",
"dice",
"euclidean",
"hamming",
"jaccard",
"kulsinski",
"l1",
"l2",
"mahalanobis",
"minkowski",
"manhattan",
"rogerstanimoto",
"russellrao",
"seuclidean",
Expand All @@ -25,4 +31,33 @@
"sqeuclidean",
"yule",
]
_Metric = Union[_MetricSparseCapable, _MetricScipySpatial]


def _lca_distance(tree, depth_key, node1, node2, lca):
"""Compute the lca distance between two nodes in a tree."""
if node1 == node2:
return tree.nodes[node1][depth_key]
else:
return tree.nodes[lca][depth_key]


def _path_distance(tree, depth_key, node1, node2, lca):
"""Compute the path distance between two nodes in a tree."""
if node1 == node2:
return 0
else:
return abs(tree.nodes[node1][depth_key] + tree.nodes[node2][depth_key] - 2 * tree.nodes[lca][depth_key])


_TreeMetricFn = Callable[[td.TreeData, str, str, str, str], np.ndarray]

_TreeMetric = Literal["lca", "path"]


def _get_tree_metric(metric: str) -> _TreeMetricFn:
if metric == "lca":
return _lca_distance
elif metric == "path":
return _path_distance
else:
raise ValueError(f"Unknown metric: {metric}. Valid metrics are 'lca' and 'path'.")
104 changes: 104 additions & 0 deletions src/pycea/tl/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Tool utilities"""

from __future__ import annotations

import random
import warnings
from collections.abc import Mapping, Sequence

import numpy as np
import scipy as sp
import treedata as td


def _check_previous_params(tdata: td.TreeData, params: Mapping, key: str, suffixes: Sequence[str]) -> None:
"""When a function is updating previous results, check that the parameters are the same."""
for suffix in suffixes:
if f"{key}_{suffix}" in tdata.uns:
prev_params = tdata.uns[f"{key}_{suffix}"]["params"]
for param, value in params.items():
if param not in prev_params or prev_params[param] != value:
raise ValueError(
f"{param} value does not match previous call. "
f"Previous: {prev_params}. Current: {params}. "
f"Set `update=False` to avoid this error."
)
return None


def _csr_data_mask(csr):
"""Boolean mask of explicit data in a csr matrix including zeros"""
return sp.sparse.csr_matrix((np.ones(len(csr.data), dtype=bool), csr.indices, csr.indptr), shape=csr.shape)


def _set_random_state(random_state):
"""Set random state"""
if random_state is not None:
random.seed(random_state)
np.random.seed(random_state)
return


def _format_keys(keys, suffix):
"""Ensures that keys are formatted correctly"""
if keys is None:
pass
elif isinstance(keys, str):
if not keys.endswith(suffix):
keys = f"{keys}_{suffix}"
elif isinstance(keys, Sequence):
keys = [f"{key}_{suffix}" if not key.endswith(suffix) else key for key in keys]
else:
raise ValueError("keys must be a string or a sequence of strings.")
return keys


def _format_as_list(obj):
"""Ensures that obj is a list"""
if obj is None:
pass
elif not isinstance(obj, Sequence):
obj = [obj]
return obj


def _check_tree_overlap(tdata, tree_keys):
"""If overlap is allowed there can only be one tree"""
n_trees = len(tdata.obst.keys())
if (n_trees > 1) and tdata.allow_overlap and len(tree_keys) != 1:
raise ValueError("Must specify a singe tree if tdata.allow_overlap is True.")
return


def _set_distances_and_connectivities(tdata, key_added, dist, connect, update):
"""Set distances and connectivities in tdata"""
dist_key = f"{key_added}_distances"
connect_key = f"{key_added}_connectivities"
if update and (dist_key in tdata.obsp.keys()):
if isinstance(dist, np.ndarray):
tdata.obsp[dist_key] = dist
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
mask = _csr_data_mask(dist)
tdata.obsp[dist_key][mask] = dist[mask]
else:
if dist_key in tdata.obsp.keys():
del tdata.obsp[dist_key]
if f"{key_added}_neighbors" in tdata.uns.keys():
del tdata.uns[f"{key_added}_neighbors"]
tdata.obsp[dist_key] = dist
if connect is not None:
tdata.obsp[connect_key] = connect
return None


def _assert_param_xor(params):
"""Assert that only one of the parameters is set"""
n_set = sum([value is not None for key, value in params.items()])
param_text = ", ".join(params.keys())
if n_set > 1:
raise ValueError(f"Only one of {param_text} can be set.")
if n_set == 0:
raise ValueError(f"At least one of {param_text} must be set.")
return None
23 changes: 15 additions & 8 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def ancestral_states(
keys_added: str | Sequence[str] = None,
tree: str | Sequence[str] | None = None,
copy: bool = False,
) -> None:
) -> None | pd.DataFrame:
"""Reconstructs ancestral states for an attribute.
Parameters
Expand All @@ -218,27 +218,34 @@ def ancestral_states(
keys
One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct.
method
Method to reconstruct ancestral states. One of "mean", "mode", "fitch_hartigan", "sankoff",
or any function that takes a list of values and returns a single value.
Method to reconstruct ancestral states:
* 'mean' : The mean of leaves in subtree.
* 'mode' : The most common value in the subtree.
* 'fitch_hartigan' : The Fitch-Hartigan algorithm.
* 'sankoff' : The Sankoff algorithm with specified costs.
* Any function that takes a list of values and returns a single value.
missing_state
The state to consider as missing data.
default_state
The expected state for the root node.
costs
A pd.DataFrame with the costs of changing states (from rows to columns).
keys_added
The keys to store the ancestral states. If None, the same keys are used.
Attribute keys of `tdata.obst[tree].nodes` where ancestral states will be stored. If `None`, `keys` are used.
tree
The `obst` key or keys of the trees to use. If `None`, all trees are used.
copy
If True, returns a :class:`pandas.DataFrame` with ancestral states.
If True, returns a :class:`DataFrame <pandas.DataFrame>` with ancestral states.
Returns
-------
Returns `None` if `copy=False`, else returns a :class:`pandas.DataFrame`. Sets the following fields for each key:
Returns `None` if `copy=False`, else return :class:`DataFrame <pandas.DataFrame>` with ancestral states.
Sets the following fields for each key:
`tdata.obst[tree].nodes[key_added]` : `float` | `Object` | `List[Object]`
Inferred ancestral states. List of states if data was an array.
* `tdata.obst[tree].nodes[key_added]` : `float` | `Object` | `List[Object]`
- Inferred ancestral states. List of states if data was an array.
"""
if isinstance(keys, str):
keys = [keys]
Expand Down
22 changes: 10 additions & 12 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
import treedata as td

from pycea.utils import get_keyed_leaf_data, get_root, get_trees
from pycea.utils import check_tree_has_key, get_keyed_leaf_data, get_root, get_trees


def _nodes_at_depth(tree, parent, nodes, depth, depth_key):
Expand All @@ -32,10 +32,7 @@ def _clades(tree, depth, depth_key, clades, clade_key, name_generator, update):
# Check that root has depth key
root = get_root(tree)
if (depth is not None) and (clades is None):
if depth_key not in tree.nodes[root]:
raise ValueError(
f"Tree does not have {depth_key} attribute. You can run `pycea.pp.add_depth` to add depth attribute."
)
check_tree_has_key(tree, depth_key)
nodes = _nodes_at_depth(tree, root, [], depth, depth_key)
clades = dict(zip(nodes, name_generator))
elif (clades is not None) and (depth is None):
Expand Down Expand Up @@ -78,7 +75,7 @@ def clades(
depth
Depth to cut tree at. Must be specified if clades is None.
depth_key
Key where depth is stored.
Attribute of `tdata.obst[tree].nodes` where depth is stored.
clades
A dictionary mapping nodes to clades.
key_added
Expand All @@ -88,17 +85,18 @@ def clades(
tree
The `obst` key or keys of the trees to use. If `None`, all trees are used.
copy
If True, returns a :class:`pandas.DataFrame` with clades.
If True, returns a :class:`DataFrame <pandas.DataFrame>` with clades.
Returns
-------
Returns `None` if `copy=False`, else returns a :class:`pandas.DataFrame`. Sets the following fields:
Returns `None` if `copy=False`, else returns a :class:`DataFrame <pandas.DataFrame>`.
`tdata.obs[key_added]` : :class:`pandas.Series` (dtype `Object`)
Clade.
`tdata.obst[tree].nodes[key_added]` : `Object`
Clade.
Sets the following fields:
* `tdata.obs[key_added]` : :class:`Series <pandas.Series>` (dtype `Object`)
- Clade assignment for each observation.
* `tdata.obst[tree].nodes[key_added]` : `Object`
- Clade assignment for each node.
"""
# Setup
tree_keys = tree
Expand Down
Loading

0 comments on commit 2811e3a

Please sign in to comment.