diff --git a/docs/api.md b/docs/api.md
index b76f6c7..5af4400 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -2,28 +2,8 @@
## Preprocessing
-```{eval-rst}
-.. module:: pycea.pp
-.. currentmodule:: pycea
-
-.. autosummary::
- :toctree: generated
-
- pp.basic_preproc
-```
-
## Tools
-```{eval-rst}
-.. module:: pycea.tl
-.. currentmodule:: pycea
-
-.. autosummary::
- :toctree: generated
-
- tl.basic_tool
-```
-
## Plotting
```{eval-rst}
@@ -33,5 +13,8 @@
.. autosummary::
:toctree: generated
+ pl.tree
pl.branches
+ pl.nodes
+ pl.annotation
```
diff --git a/docs/conf.py b/docs/conf.py
index 6c81c3a..4aa5dd0 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -91,10 +91,16 @@
}
intersphinx_mapping = {
- "python": ("https://docs.python.org/3", None),
"anndata": ("https://anndata.readthedocs.io/en/stable/", None),
+ "cycler": ("https://matplotlib.org/cycler/", None),
+ "matplotlib": ("https://matplotlib.org/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"networkx": ("https://networkx.org/documentation/stable/", None),
+ "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
+ "python": ("https://docs.python.org/3", None),
+ "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
+ "squidpy": ("https://squidpy.readthedocs.io/en/stable/", None),
"treedata": ("https://treedata.readthedocs.io/en/stable/", None),
}
diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb
index a9fa468..ce6d2f2 100644
--- a/docs/notebooks/example.ipynb
+++ b/docs/notebooks/example.ipynb
@@ -12,134 +12,7 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
- "source": [
- "import numpy as np\n",
- "from anndata import AnnData\n",
- "import pandas as pd\n",
- "import pycea"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "adata = AnnData(np.random.normal(size=(20, 10)))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "With myst it is possible to link in the text cell of a notebook such as this one the documentation of a function or a class.\n",
- "\n",
- "Let's take as an example the function {func}`pycea.pp.basic_preproc`. \n",
- "You can see that by clicking on the text, the link redirects to the API documentation of the function. \n",
- "Check the raw markdown of this cell to understand how this is specified.\n",
- "\n",
- "This works also for any package listed by `intersphinx`. Go to `docs/conf.py` and look for the `intersphinx_mapping` variable. \n",
- "There, you will see a list of packages (that this package is dependent on) for which this functionality is supported. \n",
- "\n",
- "For instance, we can link to the class {class}`anndata.AnnData`, to the attribute {attr}`anndata.AnnData.obs` or the method {meth}`anndata.AnnData.write`.\n",
- "\n",
- "Again, check the raw markdown of this cell to see how each of these links are specified.\n",
- "\n",
- "You can read more about this in the [intersphinx page](https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html) and the [myst page](https://myst-parser.readthedocs.io/en/v0.15.1/syntax/syntax.html#roles-an-in-line-extension-point)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Implement a preprocessing function here."
- ]
- },
- {
- "data": {
- "text/plain": [
- "0"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "pycea.pp.basic_preproc(adata)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- "\n",
- "
\n",
- " \n",
- " \n",
- " | \n",
- " A | \n",
- " B | \n",
- "
\n",
- " \n",
- " \n",
- " \n",
- " 0 | \n",
- " a | \n",
- " 1 | \n",
- "
\n",
- " \n",
- " 1 | \n",
- " b | \n",
- " 2 | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " c | \n",
- " 3 | \n",
- "
\n",
- " \n",
- "
\n",
- "
"
- ],
- "text/plain": [
- " A B\n",
- "0 a 1\n",
- "1 b 2\n",
- "2 c 3"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "pd.DataFrame().assign(A=[\"a\", \"b\", \"c\"], B=[1, 2, 3])"
- ]
+ "source": []
}
],
"metadata": {
diff --git a/src/pycea/_utils.py b/src/pycea/_utils.py
deleted file mode 100755
index e249e38..0000000
--- a/src/pycea/_utils.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import networkx as nx
-import pandas as pd
-
-
-def get_root(tree: nx.DiGraph):
- """Finds the root of a tree"""
- if not tree.nodes():
- return None # Handle empty graph case.
- node = next(iter(tree.nodes))
- while True:
- parent = list(tree.predecessors(node))
- if not parent:
- return node # No predecessors, this is the root
- node = parent[0]
-
-
-def _get_keyed_edge_data(tree: nx.DiGraph, key: str) -> pd.Series:
- """Gets edge data for a given key from a tree."""
- edge_data = {
- (parent, child): data.get(key)
- for parent, child, data in tree.edges(data=True)
- if key in data and data[key] is not None
- }
- return pd.Series(edge_data, name=key)
diff --git a/src/pycea/pl/__init__.py b/src/pycea/pl/__init__.py
index e92b347..8b9468f 100644
--- a/src/pycea/pl/__init__.py
+++ b/src/pycea/pl/__init__.py
@@ -1 +1 @@
-from .tree import branches
+from .plot_tree import annotation, branches, nodes, tree
diff --git a/src/pycea/pl/_docs.py b/src/pycea/pl/_docs.py
index 8d1b148..20b8c70 100755
--- a/src/pycea/pl/_docs.py
+++ b/src/pycea/pl/_docs.py
@@ -2,13 +2,27 @@
from __future__ import annotations
+from textwrap import dedent
+
+
+def _doc_params(**kwds):
+    r"""Build a decorator that fills named placeholders in a docstring.
+
+    The decorated object's docstring is dedented and then formatted with
+    ``str.format_map`` using the keyword arguments given here. The unformatted
+    text is kept on ``__orig_doc__``. Docstrings should start with ``\\`` in
+    the first line for proper formatting.
+
+    Objects without a docstring (``__doc__`` is ``None``) are returned with
+    ``__doc__`` untouched instead of raising.
+    """
+
+    def dec(obj):
+        obj.__orig_doc__ = obj.__doc__
+        # dedent(None) raises TypeError, so only format real docstrings
+        if obj.__doc__ is not None:
+            obj.__doc__ = dedent(obj.__doc__).format_map(kwds)
+        return obj
+
+    return dec
+
+
doc_common_plot_args = """\
-color_map
+cmap
Color map to use for continous variables. Can be a name or a
:class:`~matplotlib.colors.Colormap` instance (e.g. `"magma`", `"viridis"`
or `mpl.cm.cividis`), see :func:`~matplotlib.cm.get_cmap`.
If `None`, the value of `mpl.rcParams["image.cmap"]` is used.
- The default `color_map` can be set using :func:`~scanpy.set_figure_params`.
+ The default `cmap` can be set using :func:`~scanpy.set_figure_params`.
palette
Colors to use for plotting categorical annotation groups.
The palette can be a valid :class:`~matplotlib.colors.ListedColormap` name
@@ -18,6 +32,10 @@
If `None`, `mpl.rcParams["axes.prop_cycle"]` is used unless the categorical
variable already has colors stored in `tdata.uns["{var}_colors"]`.
If provided, values of `tdata.uns["{var}_colors"]` will be set.
+vmax
+ The maximum value for the colormap.
+vmin
+ The minimum value for the colormap.
ax
A matplotlib axes object. If `None`, a new figure and axes will be created.
"""
diff --git a/src/pycea/pl/_utils.py b/src/pycea/pl/_utils.py
index 799cf7b..fb60491 100755
--- a/src/pycea/pl/_utils.py
+++ b/src/pycea/pl/_utils.py
@@ -11,7 +11,7 @@
import numpy as np
from scanpy.plotting import palettes
-from pycea._utils import get_root
+from pycea.utils import get_root
def layout_tree(
@@ -62,7 +62,7 @@ def layout_tree(
node_coords = {}
for node in nx.dfs_postorder_nodes(tree, root):
if tree.out_degree(node) == 0:
- lon = (i / n_leaves) * 2 * np.pi
+ lon = (i / (n_leaves)) * 2 * np.pi # + 2 * np.pi / n_leaves
if extend_branches:
node_coords[node] = (max_depth, lon)
else:
@@ -175,3 +175,59 @@ def _get_categorical_colors(tdata, key, data, palette=None):
# store colors in tdata
tdata.uns[key + "_colors"] = colors_list
return dict(zip(categories, colors_list))
+
+
+def _get_categorical_markers(tdata, key, data, markers=None):
+    """Get categorical markers for plotting.
+
+    Parameters
+    ----------
+    tdata
+        Object whose ``uns`` mapping caches the markers under ``key + "_markers"``.
+    key
+        Name of the style variable; used to build the ``uns`` cache key.
+    data
+        Series of values to assign markers to; cast to ``category`` dtype if needed.
+    markers
+        ``None`` to reuse the cached markers (or fall back to the defaults),
+        a mapping from category to marker, or a sequence of markers that is
+        assigned to the categories in order (and repeated if too short).
+
+    Returns
+    -------
+    Dictionary mapping each category to a marker. The marker list is also
+    written back to ``tdata.uns[key + "_markers"]``.
+
+    Raises
+    ------
+    ValueError
+        If ``markers`` is neither a mapping nor a sequence, or is an empty sequence.
+    """
+    default_markers = ["o", "s", "D", "^", "v", "<", ">", "p", "P", "*", "h", "H", "X"]
+    # Ensure data is a category
+    if not data.dtype.name == "category":
+        data = data.astype("category")
+    categories = data.cat.categories
+    n_categories = len(categories)
+    # Use default markers if no markers are provided
+    if markers is None:
+        markers_list = tdata.uns.get(key + "_markers", None)
+        # A cached list is only reusable if it matches the categories one-to-one;
+        # a shorter one would leave some categories without a marker.
+        if markers_list is None or len(markers_list) != n_categories:
+            # Repeat the defaults so that more categories than defaults still get a marker
+            repeats = n_categories // len(default_markers) + 1
+            markers_list = (default_markers * repeats)[:n_categories]
+    # Use provided markers
+    elif isinstance(markers, cabc.Mapping):
+        markers_list = [markers[k] for k in categories]
+    else:
+        if not isinstance(markers, cabc.Sequence):
+            raise ValueError("Please check that the value of 'markers' is a valid " "list of marker names.")
+        if len(markers) == 0:
+            raise ValueError("The value of 'markers' must contain at least one marker.")
+        if len(markers) < n_categories:
+            warnings.warn(
+                "Length of markers is smaller than the number of "
+                f"categories (markers length: {len(markers)}, "
+                f"categories length: {n_categories}). "
+                "Some categories will have the same marker.",
+                stacklevel=2,
+            )
+            markers = markers * (n_categories // len(markers) + 1)
+        # Trim so the cached list always has exactly one marker per category
+        markers_list = markers[:n_categories]
+    # store markers in tdata
+    tdata.uns[key + "_markers"] = markers_list
+    return dict(zip(categories, markers_list))
+
+
+def _series_to_rgb_array(series, colors, vmin=None, vmax=None, na_color="#808080"):
+    """Convert a pandas Series to an N x 3 RGB array.
+
+    Parameters
+    ----------
+    series
+        The values to convert.
+    colors
+        Either a dictionary mapping values to colors, or a matplotlib
+        :class:`~matplotlib.colors.Colormap` applied after normalizing the
+        values to ``[vmin, vmax]``.
+    vmin, vmax
+        Normalization limits. Required when ``colors`` is a colormap.
+    na_color
+        The color used for missing values.
+
+    Returns
+    -------
+    Array with one RGB row per entry of ``series``.
+
+    Raises
+    ------
+    ValueError
+        If ``colors`` is neither a dictionary nor a colormap, or if a colormap
+        is given without both ``vmin`` and ``vmax``.
+    """
+    import copy
+
+    if isinstance(colors, dict):
+        # Map using the dictionary. Cast to object first: mapping a categorical
+        # Series may give a categorical result, which would reject na_color
+        # as an unknown category.
+        color_series = series.map(colors).astype(object)
+        color_series[series.isna()] = na_color
+        rgb_array = np.array([mcolors.to_rgb(color) for color in color_series])
+    elif isinstance(colors, mcolors.Colormap):
+        # Any Colormap subclass works here, not only ListedColormap
+        if vmin is None or vmax is None:
+            raise ValueError("vmin and vmax must be specified when using a Colormap.")
+        norm = mcolors.Normalize(vmin, vmax)
+        # Work on a copy so the caller's colormap keeps its own "bad" color
+        cmap = copy.copy(colors)
+        cmap.set_bad(na_color)
+        # Drop the alpha channel
+        rgb_array = np.asarray(cmap(norm(series)))[:, :3]
+    else:
+        raise ValueError("cmap must be either a dictionary or a Colormap.")
+    return rgb_array
diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py
new file mode 100644
index 0000000..6f66320
--- /dev/null
+++ b/src/pycea/pl/plot_tree.py
@@ -0,0 +1,492 @@
+from __future__ import annotations
+
+from collections.abc import Mapping, Sequence
+
+import cycler
+import matplotlib as mpl
+import matplotlib.colors as mcolors
+import matplotlib.markers as mmarkers
+import matplotlib.pyplot as plt
+import numpy as np
+import treedata as td
+from matplotlib.axes import Axes
+from matplotlib.collections import LineCollection
+
+from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data
+
+from ._docs import _doc_params, doc_common_plot_args
+from ._utils import (
+ _get_categorical_colors,
+ _get_categorical_markers,
+ _series_to_rgb_array,
+ layout_tree,
+)
+
+
+@_doc_params(
+    common_plot_args=doc_common_plot_args,
+)
+def branches(
+    tdata: td.TreeData,
+    key: str = None,
+    polar: bool = False,
+    extend_branches: bool = False,
+    angled_branches: bool = False,
+    color: str = "black",
+    linewidth: int | float | str = 1,
+    cmap: str | mcolors.Colormap = "viridis",
+    palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None,
+    na_color: str = "lightgrey",
+    na_linewidth: int | float = 1,
+    ax: Axes | None = None,
+    vmax: int | float | None = None,
+    vmin: int | float | None = None,
+    **kwargs,
+) -> Axes:
+    """\
+    Plot the branches of a tree.
+
+    Parameters
+    ----------
+    tdata
+        The `treedata.TreeData` object.
+    key
+        The `obst` key of the tree to plot. If not given, the first key in `tdata.obst` is used.
+    polar
+        Whether to plot the tree in polar coordinates.
+    extend_branches
+        Whether to extend branches so the tips are at the same depth.
+    angled_branches
+        Whether to plot branches at an angle.
+    color
+        Either a color name, or a key for an attribute of the edges to color by.
+    linewidth
+        Either an numeric width, or a key for an attribute of the edges to set the linewidth.
+    {common_plot_args}
+    na_color
+        The color to use for edges with missing data.
+    na_linewidth
+        The linewidth to use for edges with missing data.
+    kwargs
+        Additional keyword arguments passed to `matplotlib.collections.LineCollection`.
+
+    Returns
+    -------
+    ax - The axes that the plot was drawn on.
+
+    Raises
+    ------
+    ValueError
+        If `tdata.obst` holds no tree, if `color` or `linewidth` is invalid,
+        or if the projection of `ax` does not match `polar`.
+    """  # noqa: D205
+    if not key:
+        # Default to the first tree, with a clear error instead of a bare StopIteration
+        key = next(iter(tdata.obst.keys()), None)
+        if key is None:
+            raise ValueError("No tree found in tdata.obst. Please add a tree before plotting.")
+    tree = tdata.obst[key]
+    numeric_kinds = ("i", "u", "f")
+    # Get layout
+    node_coords, branch_coords, leaves, depth = layout_tree(
+        tree, polar=polar, extend_branches=extend_branches, angled_branches=angled_branches
+    )
+    segments = []
+    edges = []
+    for edge, (lat, lon) in branch_coords.items():
+        # Polar axes take (angle, radius); cartesian axes take (depth, position)
+        coords = np.array([lon, lat] if polar else [lat, lon]).T
+        segments.append(coords)
+        edges.append(edge)
+    kwargs.update({"segments": segments})
+    # Get colors
+    if mcolors.is_color_like(color):
+        kwargs.update({"color": color})
+    elif isinstance(color, str):
+        color_data = get_keyed_edge_data(tree, color)
+        if color_data.dtype.kind in numeric_kinds:
+            # Explicit limits win; otherwise span the observed data range
+            norm = plt.Normalize(
+                vmin=color_data.min() if vmin is None else vmin,
+                vmax=color_data.max() if vmax is None else vmax,
+            )
+            color_map = plt.get_cmap(cmap)
+            colors = [color_map(norm(color_data[edge])) if edge in color_data.index else na_color for edge in edges]
+        else:
+            color_lookup = _get_categorical_colors(tdata, color, color_data, palette)
+            colors = [color_lookup[color_data[edge]] if edge in color_data.index else na_color for edge in edges]
+        kwargs.update({"color": colors})
+    else:
+        raise ValueError("Invalid color value. Must be a color name, or an str specifying an attribute of the edges.")
+    # Get linewidths
+    if isinstance(linewidth, (int, float)):
+        kwargs.update({"linewidth": linewidth})
+    elif isinstance(linewidth, str):
+        linewidth_data = get_keyed_edge_data(tree, linewidth)
+        if linewidth_data.dtype.kind not in numeric_kinds:
+            raise ValueError("Invalid linewidth data type. Edge attribute must be int or float")
+        linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges]
+        kwargs.update({"linewidth": linewidths})
+    else:
+        raise ValueError("Invalid linewidth value. Must be int, float, or an str specifying an attribute of the edges.")
+    # Plot
+    if ax is None:
+        subplot_kw = {"projection": "polar"} if polar else None
+        _, ax = plt.subplots(subplot_kw=subplot_kw)
+    elif (ax.name == "polar") != polar:
+        raise ValueError("Provided axis does not match the requested 'polar' setting.")
+    ax.add_collection(LineCollection(zorder=1, **kwargs))
+    # Configure plot
+    lat_lim = (-0.2, depth)
+    lon_lim = (0, 2 * np.pi)
+    ax.set_xlim(lon_lim if polar else lat_lim)
+    ax.set_ylim(lat_lim if polar else lon_lim)
+    ax.axis("off")
+    # Layout information stored on the axes; read back by the node and annotation plots
+    ax._attrs = {
+        "node_coords": node_coords,
+        "leaves": leaves,
+        "depth": depth,
+        "offset": depth,
+        "polar": polar,
+        "tree_key": key,
+    }
+    return ax
+
+
+# For internal use
+_branches = branches
+
+
+@_doc_params(
+    common_plot_args=doc_common_plot_args,
+)
+def nodes(
+    tdata: td.TreeData,
+    nodes: str | Sequence[str] = "internal",
+    color: str = "black",
+    style: str = "o",
+    size: int | float | str = 10,
+    cmap: str | mcolors.Colormap = None,
+    palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None,
+    markers: Sequence[str] | Mapping[str] = None,
+    vmax: int | float | None = None,
+    vmin: int | float | None = None,
+    na_color: str = "#FFFFFF00",
+    na_style: str = "none",
+    na_size: int | float = 0,
+    ax: Axes | None = None,
+    **kwargs,
+) -> Axes:
+    """\
+    Plot the nodes of a tree.
+
+    The branches must already have been drawn on the axes with `pycea.pl.branches`,
+    which stores the node coordinates this function reads.
+
+    Parameters
+    ----------
+    tdata
+        The TreeData object.
+    nodes
+        Either "all", "leaves", "internal", or a list of nodes to plot.
+    color
+        Either a color name, or a key for an attribute of the nodes to color by.
+    style
+        Either a marker name, or a key for an attribute of the nodes to set the marker.
+        Can be numeric but will always be treated as a categorical variable.
+    size
+        Either an numeric size, or a key for an attribute of the nodes to set the size.
+    {common_plot_args}
+    markers
+        Object determining how to draw the markers for different levels of the style variable.
+        You can pass a list of markers or a dictionary mapping levels of the style variable to markers.
+    na_color
+        The color to use for nodes with missing data.
+    na_style
+        The marker to use for nodes with missing data.
+    na_size
+        The size to use for nodes with missing data.
+    kwargs
+        Additional keyword arguments passed to `matplotlib.pyplot.scatter`.
+
+    Returns
+    -------
+    ax - The axes that the plot was drawn on.
+
+    Raises
+    ------
+    ValueError
+        If the branches have not been plotted on `ax`, or if `nodes`, `color`,
+        `size` or `style` is invalid.
+    """  # noqa: D205
+    # Setup
+    if ax is None:
+        ax = plt.gca()
+    attrs = getattr(ax, "_attrs", None)
+    if not attrs:
+        raise ValueError("Branches must be plotted with pycea.pl.branches before nodes can be plotted.")
+    if not cmap:
+        cmap = mpl.rcParams["image.cmap"]
+    cmap = plt.get_cmap(cmap)
+    tree = tdata.obst[attrs["tree_key"]]
+    # Get nodes
+    all_nodes = list(attrs["node_coords"].keys())
+    leaves = list(attrs["leaves"])
+    if nodes == "all":
+        selected = all_nodes
+    elif nodes == "leaves":
+        selected = leaves
+    elif nodes == "internal":
+        leaf_set = set(leaves)  # constant-time membership tests
+        selected = [node for node in all_nodes if node not in leaf_set]
+    elif isinstance(nodes, Sequence):
+        if not set(nodes).issubset(all_nodes):
+            raise ValueError("Nodes must be a list of nodes in the tree.")
+        selected = list(nodes)
+    else:
+        raise ValueError("Invalid nodes value. Must be 'all', 'leaves', 'internal', or a list of nodes.")
+    if len(selected) == 0:
+        # Nothing to draw (e.g. no internal nodes); np.vstack would fail on an empty list
+        return ax
+    # Get coordinates
+    coords = np.vstack([attrs["node_coords"][node] for node in selected])
+    if attrs["polar"]:
+        kwargs.update({"x": coords[:, 1], "y": coords[:, 0]})
+    else:
+        kwargs.update({"x": coords[:, 0], "y": coords[:, 1]})
+    # Get colors
+    if mcolors.is_color_like(color):
+        kwargs.update({"color": color})
+    elif isinstance(color, str):
+        color_data = get_keyed_node_data(tree, color)
+        if color_data.dtype.kind in ("i", "u", "f"):
+            # Compare with None so that an explicit limit of 0 is respected
+            if vmin is None:
+                vmin = color_data.min()
+            if vmax is None:
+                vmax = color_data.max()
+            norm = plt.Normalize(vmin=vmin, vmax=vmax)
+            colors = [cmap(norm(color_data[node])) if node in color_data.index else na_color for node in selected]
+        else:
+            color_lookup = _get_categorical_colors(tdata, color, color_data, palette)
+            colors = [color_lookup[color_data[node]] if node in color_data.index else na_color for node in selected]
+        kwargs.update({"color": colors})
+    else:
+        raise ValueError("Invalid color value. Must be a color name, or an str specifying an attribute of the nodes.")
+    # Get sizes
+    if isinstance(size, (int, float)):
+        kwargs.update({"s": size})
+    elif isinstance(size, str):
+        size_data = get_keyed_node_data(tree, size)
+        sizes = [size_data[node] if node in size_data.index else na_size for node in selected]
+        kwargs.update({"s": sizes})
+    else:
+        raise ValueError("Invalid size value. Must be int, float, or an str specifying an attribute of the nodes.")
+    # Get markers
+    kwargs_list = []
+    if style in mmarkers.MarkerStyle.markers:
+        kwargs.update({"marker": style})
+    elif isinstance(style, str):
+        style_data = get_keyed_node_data(tree, style)
+        marker_lookup = _get_categorical_markers(tdata, style, style_data, markers)
+        node_markers = [marker_lookup[style_data[node]] if node in style_data.index else na_style for node in selected]
+        # One scatter call per marker; dict.fromkeys keeps first-seen order so draw order is deterministic
+        for marker in dict.fromkeys(node_markers):
+            idx = [i for i, node_marker in enumerate(node_markers) if node_marker == marker]
+            # Per-node arguments are subset to the nodes with this marker; scalars pass through
+            marker_kwargs = {
+                name: [value[i] for i in idx] if isinstance(value, (list, np.ndarray)) else value
+                for name, value in kwargs.items()
+            }
+            marker_kwargs["marker"] = marker
+            kwargs_list.append(marker_kwargs)
+    else:
+        raise ValueError("Invalid style value. Must be a marker name, or an str specifying an attribute of the nodes.")
+    # Plot
+    for scatter_kwargs in kwargs_list or [kwargs]:
+        ax.scatter(**scatter_kwargs)
+    return ax
+
+
+# For internal use
+_nodes = nodes
+
+
+@_doc_params(
+    common_plot_args=doc_common_plot_args,
+)
+def annotation(
+    tdata: td.TreeData,
+    keys: str | Sequence[str] = None,
+    width: int | float = 0.05,
+    gap: int | float = 0.03,
+    label: bool | str | Sequence[str] = True,
+    cmap: str | mcolors.Colormap = None,
+    palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None,
+    vmax: int | float | None = None,
+    vmin: int | float | None = None,
+    na_color: str = "white",
+    ax: Axes | None = None,
+    **kwargs,
+) -> Axes:
+    """\
+    Plot leaf annotations.
+
+    The branches must already have been drawn on the axes with `pycea.pl.branches`,
+    which stores the leaf order and tree depth this function reads.
+
+    Parameters
+    ----------
+    tdata
+        The TreeData object.
+    keys
+        One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to plot.
+    width
+        The width of the annotation bar relative to the tree.
+    gap
+        The gap between the annotation bar and the tree relative to the tree.
+    label
+        Annotation labels. If `True`, the keys are used as labels.
+        If `False`, no labels are drawn.
+        If a string or a sequence of strings, the strings are used as labels.
+    {common_plot_args}
+    na_color
+        The color to use for annotations with missing data.
+    kwargs
+        Additional keyword arguments passed to `matplotlib.pyplot.pcolormesh`.
+
+    Returns
+    -------
+    ax - The axes that the plot was drawn on.
+
+    Raises
+    ------
+    ValueError
+        If the branches have not been plotted on `ax`, if no keys are given,
+        or if `label` is invalid.
+    """  # noqa: D205
+    # Setup
+    if ax is None:
+        ax = plt.gca()
+    attrs = getattr(ax, "_attrs", None)
+    if not attrs:
+        raise ValueError("Branches must be plotted with pycea.pl.branches before annotations can be plotted.")
+    if not keys:
+        raise ValueError("No keys provided. Please provide one or more keys to plot.")
+    keys = [keys] if isinstance(keys, str) else keys
+    if not cmap:
+        cmap = mpl.rcParams["image.cmap"]
+    cmap = plt.get_cmap(cmap)
+    leaves = list(attrs["leaves"])
+    # Get data
+    data, is_array = get_keyed_obs_data(tdata, keys)
+    data = data.loc[leaves]
+    numeric_data = data.select_dtypes(exclude="category")
+    # Only derive limits when there is at least one non-categorical column.
+    # Compare with None so that an explicit limit of 0 is respected.
+    if numeric_data.shape[1] > 0:
+        if vmin is None:
+            vmin = numeric_data.min().min()
+        if vmax is None:
+            vmax = numeric_data.max().max()
+    # Get labels
+    if label is True:
+        labels = keys
+    elif label is False or label is None:
+        labels = []
+    elif isinstance(label, str):
+        labels = [label]
+    elif isinstance(label, Sequence):
+        labels = label
+    else:
+        raise ValueError("Invalid label value. Must be a bool, str, or a sequence of strings.")
+    # Compute coordinates for annotations
+    start_lat = attrs["offset"] + attrs["depth"] * gap
+    end_lat = start_lat + attrs["depth"] * width * data.shape[1]
+    lats = np.linspace(start_lat, end_lat, data.shape[1] + 1)
+    lons = np.linspace(0, 2 * np.pi, data.shape[0] + 1)
+    # Shift cell edges by half a cell
+    lons = lons - np.pi / len(leaves)
+    # Convert to RGB array
+    rgb_columns = []
+    if is_array:
+        if data.shape[0] == data.shape[1]:
+            # Square array: order the columns by reversed leaf order and resize the block
+            data = data.loc[leaves, list(reversed(leaves))]
+            end_lat = start_lat + attrs["depth"] + 2 * np.pi
+            lats = np.linspace(start_lat, end_lat, data.shape[1] + 1)
+        for col in data.columns:
+            rgb_columns.append(_series_to_rgb_array(data[col], cmap, vmin=vmin, vmax=vmax, na_color=na_color))
+    else:
+        for key in keys:
+            if data[key].dtype == "category":
+                colors = _get_categorical_colors(tdata, key, data[key], palette)
+                rgb_columns.append(_series_to_rgb_array(data[key], colors, na_color=na_color))
+            else:
+                rgb_columns.append(_series_to_rgb_array(data[key], cmap, vmin=vmin, vmax=vmax, na_color=na_color))
+    rgb_array = np.stack(rgb_columns, axis=1)
+    # Plot
+    if attrs["polar"]:
+        ax.pcolormesh(lons, lats, rgb_array.swapaxes(0, 1), zorder=2, **kwargs)
+        ax.set_ylim(-0.2, end_lat)
+    else:
+        ax.pcolormesh(lats, lons, rgb_array, zorder=2, **kwargs)
+        ax.set_xlim(-0.2, end_lat)
+        # NOTE(review): labels are drawn for the cartesian layout only, because the
+        # code below uses the y axis as the angular axis - confirm against upstream.
+        if len(labels) > 0:
+            single_array_label = is_array and len(labels) == 1
+            # Center each label within its share of the annotation block
+            labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1)
+            labels_lats = labels_lats + (end_lat - start_lat) / (len(labels) * 2)
+            for idx, text in enumerate(labels):
+                if single_array_label:
+                    ax.text(labels_lats[idx], -0.1, text, ha="center", va="top")
+                else:
+                    ax.text(labels_lats[idx], -0.1, text, ha="center", va="top", rotation=90)
+            # Leave room below the tree for the label text
+            ax.set_ylim(-0.5 if single_array_label else -1, 2 * np.pi)
+    # Later annotations start where this one ended
+    ax._attrs.update({"offset": end_lat})
+    return ax
+
+
+# For internal use
+_annotation = annotation
+
+
+@_doc_params(
+    common_plot_args=doc_common_plot_args,
+)
+def tree(
+    tdata: td.TreeData,
+    key: str = None,
+    nodes: str | Sequence[str] = None,
+    annotation_keys: str | Sequence[str] = None,
+    polar: bool = False,
+    extend_branches: bool = False,
+    angled_branches: bool = False,
+    branch_color: str = "black",
+    branch_linewidth: int | float | str = 1,
+    node_color: str = "black",
+    node_style: str = "o",
+    node_size: int | float = 10,
+    annotation_width: int | float = 0.05,
+    cmap: str | mcolors.Colormap = "viridis",
+    palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None,
+    ax: Axes | None = None,
+    **kwargs,
+) -> Axes:
+    """\
+    Plot a tree with branches, nodes, and annotations.
+
+    Branches are always drawn. Nodes are drawn only when `nodes` is given and
+    annotations only when `annotation_keys` is given. The same `cmap` and
+    `palette` are used for all three layers.
+
+    Parameters
+    ----------
+    tdata
+        The TreeData object.
+    key
+        The `obst` key of the tree to plot.
+    nodes
+        Either "all", "leaves", "internal", or a list of nodes to plot.
+    annotation_keys
+        One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to plot.
+    polar
+        Whether to plot the tree in polar coordinates.
+    extend_branches
+        Whether to extend branches so the tips are at the same depth.
+    angled_branches
+        Whether to plot branches at an angle.
+    branch_color
+        Either a color name, or a key for an attribute of the edges to color by.
+    branch_linewidth
+        Either an numeric width, or a key for an attribute of the edges to set the linewidth.
+    node_color
+        Either a color name, or a key for an attribute of the nodes to color by.
+    node_style
+        Either a marker name, or a key for an attribute of the nodes to set the marker.
+    node_size
+        Either an numeric size, or a key for an attribute of the nodes to set the size.
+    annotation_width
+        The width of the annotation bar relative to the tree.
+    {common_plot_args}
+    kwargs
+        Accepted but not forwarded to any of the plotting layers.
+
+    Returns
+    -------
+    ax - The axes that the plot was drawn on.
+    """  # noqa: D205
+    # Color settings shared by every layer
+    shared = dict(cmap=cmap, palette=palette)
+    # Branches come first: they create the axes layout the other layers read
+    ax = _branches(
+        tdata,
+        key=key,
+        polar=polar,
+        extend_branches=extend_branches,
+        angled_branches=angled_branches,
+        color=branch_color,
+        linewidth=branch_linewidth,
+        ax=ax,
+        **shared,
+    )
+    # Optional node layer
+    if nodes:
+        ax = _nodes(tdata, nodes=nodes, color=node_color, style=node_style, size=node_size, ax=ax, **shared)
+    # Optional annotation layer
+    if annotation_keys:
+        ax = _annotation(tdata, keys=annotation_keys, width=annotation_width, ax=ax, **shared)
+    return ax
diff --git a/src/pycea/pl/tree.py b/src/pycea/pl/tree.py
deleted file mode 100644
index 99a9051..0000000
--- a/src/pycea/pl/tree.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from __future__ import annotations
-
-from collections.abc import Mapping, Sequence
-
-import cycler
-import matplotlib.colors as mcolors
-import matplotlib.pyplot as plt
-import numpy as np
-import treedata as td
-from matplotlib.axes import Axes
-from matplotlib.collections import LineCollection
-
-from pycea._utils import _get_keyed_edge_data
-
-from ._utils import (
- _get_categorical_colors,
- layout_tree,
-)
-
-
-def branches(
- tdata: td.TreeData,
- key: str = None,
- polar: bool = False,
- extend_branches: bool = False,
- angled_branches: bool = False,
- color: str = "black",
- color_na: str = "lightgrey",
- linewidth: int | float | str = 1,
- linewidth_na: int | float = 1,
- cmap: str | mcolors.Colormap = "viridis",
- palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None,
- ax: Axes | None = None,
- **kwargs,
-):
- """Plot the branches of a tree.
-
- Parameters
- ----------
- tdata
- The `td.TreeData` object.
- key
- The `obst` key of the tree to plot.
- polar
- Whether to plot the tree in polar coordinates.
- extend_branches
- Whether to extend branches so the tips are at the same depth.
- angled_branches
- Whether to plot branches at an angle.
- color
- Either a color name, or a key for an attribute of the edges to color by.
- color_na
- The color to use for edges with missing data.
- linewidth
- Either an numeric width, or a key for an attribute of the edges to set the linewidth.
- linewidth_na
- The linewidth to use for edges with missing data.
- {doc_common_plot_args}
- kwargs
- Additional keyword arguments passed to `matplotlib.collections.LineCollection`.
-
- Returns
- -------
- `matplotlib.axes.Axes`
- """
- kwargs = kwargs if kwargs else {}
- tree = tdata.obst[key]
-
- # Get layout
- node_coords, branch_coords, leaves, depth = layout_tree(
- tree, polar=polar, extend_branches=extend_branches, angled_branches=angled_branches
- )
- segments = []
- edges = []
- for edge, (lat, lon) in branch_coords.items():
- coords = np.array([lon, lat] if polar else [lat, lon]).T
- segments.append(coords)
- edges.append(edge)
- kwargs.update({"segments": segments})
- # Get colors
- if mcolors.is_color_like(color):
- kwargs.update({"color": color})
- elif isinstance(color, str):
- color_data = _get_keyed_edge_data(tree, color)
- if color_data.dtype.kind in ["i", "f"]:
- norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max())
- cmap = plt.get_cmap(cmap)
- colors = [cmap(norm(color_data[edge])) if edge in color_data.index else color_na for edge in edges]
- kwargs.update({"color": colors})
- else:
- cmap = _get_categorical_colors(tdata, color, color_data, palette)
- colors = [cmap[color_data[edge]] if edge in color_data.index else color_na for edge in edges]
- kwargs.update({"color": colors})
- else:
- raise ValueError("Invalid color value. Must be a color name, or an str specifying an attribute of the edges.")
- # Get linewidths
- if isinstance(linewidth, (int, float)):
- kwargs.update({"linewidth": linewidth})
- elif isinstance(linewidth, str):
- linewidth_data = _get_keyed_edge_data(tree, linewidth)
- if linewidth_data.dtype.kind in ["i", "f"]:
- linewidths = [linewidth_data[edge] if edge in linewidth_data.index else linewidth_na for edge in edges]
- kwargs.update({"linewidth": linewidths})
- else:
- raise ValueError("Invalid linewidth data type. Edge attribute must be int or float")
- else:
- raise ValueError("Invalid linewidth value. Must be int, float, or an str specifying an attribute of the edges.")
- # Plot
- if not ax:
- subplot_kw = {"projection": "polar"} if polar else None
- fig, ax = plt.subplots(subplot_kw=subplot_kw)
- elif (ax.name == "polar") != polar:
- raise ValueError("Provided axis does not match the requested 'polar' setting.")
- ax.add_collection(LineCollection(**kwargs))
- # Configure plot
- lat_lim = (-0.1, depth)
- lon_lim = (0, 2 * np.pi)
- ax.set_xlim(lon_lim if polar else lat_lim)
- ax.set_ylim(lat_lim if polar else lon_lim)
- ax.axis("off")
- ax._attrs = {"node_coords": node_coords, "leaves": leaves, "depth": depth, "offset": depth, "polar": polar}
- return ax
diff --git a/src/pycea/pp/__init__.py b/src/pycea/pp/__init__.py
index 5e7e293..e69de29 100644
--- a/src/pycea/pp/__init__.py
+++ b/src/pycea/pp/__init__.py
@@ -1 +0,0 @@
-from .basic import basic_preproc
diff --git a/src/pycea/pp/basic.py b/src/pycea/pp/basic.py
deleted file mode 100644
index 5db1ec0..0000000
--- a/src/pycea/pp/basic.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from anndata import AnnData
-
-
-def basic_preproc(adata: AnnData) -> int:
- """Run a basic preprocessing on the AnnData object.
-
- Parameters
- ----------
- adata
- The AnnData object to preprocess.
-
- Returns
- -------
- Some integer value.
- """
- print("Implement a preprocessing function here.")
- return 0
diff --git a/src/pycea/tl/__init__.py b/src/pycea/tl/__init__.py
index 95a32cd..e69de29 100644
--- a/src/pycea/tl/__init__.py
+++ b/src/pycea/tl/__init__.py
@@ -1 +0,0 @@
-from .basic import basic_tool
diff --git a/src/pycea/tl/basic.py b/src/pycea/tl/basic.py
deleted file mode 100644
index d215ade..0000000
--- a/src/pycea/tl/basic.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from anndata import AnnData
-
-
-def basic_tool(adata: AnnData) -> int:
- """Run a tool on the AnnData object.
-
- Parameters
- ----------
- adata
- The AnnData object to preprocess.
-
- Returns
- -------
- Some integer value.
- """
- print("Implement a tool to run on the AnnData object.")
- return 0
diff --git a/src/pycea/utils.py b/src/pycea/utils.py
new file mode 100755
index 0000000..2ffa0e5
--- /dev/null
+++ b/src/pycea/utils.py
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import networkx as nx
+import pandas as pd
+import treedata as td
+
+
+def get_root(tree: nx.DiGraph):
+    """Return the root of ``tree`` by climbing first predecessors; ``None`` for an empty graph."""
+    if not tree.nodes():
+        return None  # nothing to climb in an empty graph
+    current = next(iter(tree.nodes))
+    ancestors = list(tree.predecessors(current))
+    while ancestors:
+        current = ancestors[0]
+        ancestors = list(tree.predecessors(current))
+    return current
+
+
+def get_keyed_edge_data(tree: nx.DiGraph, key: str) -> pd.Series:
+    """Collect the non-``None`` ``key`` attribute of every edge into a Series indexed by (parent, child)."""
+    values = {}
+    for source, target, attrs in tree.edges(data=True):
+        # Edges lacking the attribute, or storing None for it, are left out.
+        if attrs.get(key) is not None:
+            values[(source, target)] = attrs[key]
+    if not values:
+        raise ValueError(f"Key {key!r} is not present in any edge.")
+    return pd.Series(values, name=key)
+
+
+def get_keyed_node_data(tree: nx.DiGraph, key: str) -> pd.Series:
+    """Collect the non-``None`` ``key`` attribute of every node into a Series indexed by node."""
+    values = {name: attrs[key] for name, attrs in tree.nodes(data=True) if attrs.get(key) is not None}
+    if not values:
+        raise ValueError(f"Key {key!r} is not present in any node.")
+    return pd.Series(values, name=key)
+
+
+def get_keyed_obs_data(tdata: td.TreeData, keys: Sequence[str], layer: str | None = None) -> tuple[pd.DataFrame, bool]:
+    """Return ``(data, is_array)``: a DataFrame indexed by observation and a flag set for ``obsm``/``obsp`` keys.
+
+    Keys are tried against ``obs``, ``var_names`` (read with ``layer``), ``obsm`` and ``obsp`` in that order; boolean,
+    object and bytes ``obs`` columns are cast to categorical on ``tdata`` itself. Raises ``ValueError`` for an unknown
+    key, an empty ``keys``, column keys mixed with a matrix key, or more than one matrix key.
+    """
+    data = []
+    column_keys = array_keys = False
+    for key in keys:
+        if key in tdata.obs_keys():
+            if tdata.obs[key].dtype.kind in ["b", "O", "S"]:
+                tdata.obs[key] = tdata.obs[key].astype("category")
+            data.append(tdata.obs[key])
+            column_keys = True
+        elif key in tdata.var_names:
+            data.append(pd.Series(tdata.obs_vector(key, layer=layer), index=tdata.obs_names))
+            column_keys = True
+        elif hasattr(tdata, "obsm") and key in tdata.obsm.keys():
+            data.append(tdata.obsm[key])
+            array_keys = True
+        elif hasattr(tdata, "obsp") and key in tdata.obsp.keys():
+            data.append(tdata.obsp[key])
+            array_keys = True
+        else:
+            message = f"Key {key!r} is invalid! You must pass a valid observation annotation. "
+            raise ValueError(message + "One of obs_keys, var_names, obsm_keys, obsp_keys.")
+    if column_keys and array_keys:
+        raise ValueError("Cannot mix column and matrix keys.")
+    if array_keys and len(keys) > 1:
+        raise ValueError("Cannot request multiple matrix keys.")
+    if not column_keys and not array_keys:
+        raise ValueError("No valid keys found.")
+    if column_keys:
+        data = pd.concat(data, axis=1)
+        data.columns = keys
+    else:
+        data = pd.DataFrame(data[0], index=tdata.obs_names)
+        if data.shape[0] == data.shape[1]:  # square matrix: label its columns by observation as well
+            data.columns = tdata.obs_names
+    return data, array_keys
diff --git a/tests/conftest.py b/tests/conftest.py
index ba9a282..3aaf61b 100755
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,3 @@
-from pathlib import Path
-
import pytest
import treedata as td
@@ -9,6 +7,3 @@
@pytest.fixture(scope="session")
def tdata() -> td.TreeData:
return _tdata
-
-
-PLOT_PATH = Path("tests/plots/")
diff --git a/tests/test_basic.py b/tests/test_basic.py
deleted file mode 100644
index c01c90a..0000000
--- a/tests/test_basic.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import pytest
-
-import pycea
-
-
-def test_package_has_version():
- assert pycea.__version__ is not None
-
-
-@pytest.mark.skip(reason="This decorator should be removed when test passes.")
-def test_example():
- assert 1 == 0 # This test is designed to fail.
diff --git a/tests/test_plot_tree.py b/tests/test_plot_tree.py
index d295938..c1c9b1d 100755
--- a/tests/test_plot_tree.py
+++ b/tests/test_plot_tree.py
@@ -1,20 +1,71 @@
from pathlib import Path
import matplotlib.pyplot as plt
+import pytest
import pycea
plot_path = Path(__file__).parent / "plots"
-def test_plot_branches(tdata):
- # Polar categorical with missing
- pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1")
- plt.savefig(plot_path / "polar_categorical_branches.png")
+def test_polar_with_clades(tdata):
+ fig, ax = plt.subplots(dpi=600, subplot_kw={"polar": True})
+ pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax)
+ pycea.pl.nodes(tdata, color="clade", palette="Set1", style="clade", ax=ax)
+ pycea.pl.annotation(tdata, keys="clade", ax=ax)
+ plt.savefig(plot_path / "polar_clades.png")
plt.close()
- # Numeric with line width
+
+
+def test_angled_numeric_annotations(tdata):
+ fig, ax = plt.subplots(dpi=600)
pycea.pl.branches(
- tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True
+ tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True, ax=ax
+ )
+ pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20, ax=ax)
+ pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05, ax=ax)
+ pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes", ax=ax)
+ plt.savefig(plot_path / "angled_numeric.png")
+ plt.close()
+
+
+def test_matrix_annotation(tdata):
+ fig, ax = plt.subplots(dpi=600)
+ pycea.pl.tree(
+ tdata,
+ key="tree",
+ nodes="internal",
+ node_color="clade",
+ node_size="time",
+ annotation_keys=["spatial_distance"],
+ ax=ax,
)
- plt.savefig(plot_path / "angled_numeric_branches.png")
+ plt.savefig(plot_path / "matrix_annotation.png")
+ plt.close()
+
+
+def test_branches_invalid_input(tdata):
+    _, cartesian_ax = plt.subplots()
+    # Invalid color and linewidth lists must raise
+    for bad_kwargs in ({"color": ["bad"] * 5}, {"linewidth": ["bad"] * 5}):
+        with pytest.raises(ValueError):
+            pycea.pl.branches(tdata, key="tree", **bad_kwargs)
+    # Can't plot polar with non-polar axis
+    with pytest.raises(ValueError):
+        pycea.pl.branches(tdata, key="tree", polar=True, ax=cartesian_ax)
+    plt.close()
+
+
+def test_annotation_invalid_input(tdata):
+    """``pl.annotation`` raises ``ValueError`` before branches are drawn and for invalid ``keys``/``label``."""
+    _, axes = plt.subplots()
+    # Need to plot branches first
+    with pytest.raises(ValueError):
+        pycea.pl.annotation(tdata, keys="clade")
+    pycea.pl.branches(tdata, key="tree", ax=axes)
+    # With branches drawn, invalid keys or label values are still rejected
+    for bad_kwargs in ({"keys": None}, {"keys": False}, {"keys": "clade", "label": {}}):
+        with pytest.raises(ValueError):
+            pycea.pl.annotation(tdata, ax=axes, **bad_kwargs)
+    # Release the figure created above
+    plt.close()
diff --git a/tests/test_plot_utils.py b/tests/test_plot_utils.py
index ad3ce6c..b18faa2 100755
--- a/tests/test_plot_utils.py
+++ b/tests/test_plot_utils.py
@@ -5,7 +5,13 @@
import pytest
import treedata as td
-from pycea.pl._utils import _get_categorical_colors, _get_default_categorical_colors, layout_tree
+from pycea.pl._utils import (
+ _get_categorical_colors,
+ _get_categorical_markers,
+ _get_default_categorical_colors,
+ _series_to_rgb_array,
+ layout_tree,
+)
# Test layout_tree
@@ -92,22 +98,17 @@ def test_palette_types(empty_tdata, category_data):
palette = {"apple": "red", "banana": "yellow", "cherry": "pink"}
colors = _get_categorical_colors(empty_tdata, "fruit", category_data, palette)
assert colors["apple"] == "#ff0000ff"
- # List
- palette = ["red", "yellow", "pink"]
- colors = _get_categorical_colors(empty_tdata, "fruit", category_data, palette)
- assert colors["apple"] == "#ff0000ff"
-
-
-def test_not_enough_colors(empty_tdata, category_data):
+ # List not enough
palette = ["red", "yellow"]
with pytest.warns(Warning, match="palette colors is smaller"):
colors = _get_categorical_colors(empty_tdata, "fruit", category_data, palette)
- assert colors["apple"] == "#ff0000ff"
+ assert colors["apple"] == "#ff0000ff"
def test_invalid_palette(empty_tdata, category_data):
- with pytest.raises(ValueError):
- _get_categorical_colors(empty_tdata, "fruit", category_data, ["bad"])
+ with pytest.warns(Warning, match="palette colors is smaller"):
+ with pytest.raises(ValueError):
+ _get_categorical_colors(empty_tdata, "fruit", category_data, ["bad"])
def test_pallete_in_uns(empty_tdata, category_data):
@@ -117,3 +118,54 @@ def test_pallete_in_uns(empty_tdata, category_data):
assert empty_tdata.uns["fruit_colors"] == list(palette_hex.values())
colors = _get_categorical_colors(empty_tdata, "fruit", category_data)
assert colors == palette_hex
+
+
+# Test _get_categorical_markers
+def test_markers_types(empty_tdata, category_data):
+    # None
+    default_markers = _get_categorical_markers(empty_tdata, "fruit", category_data)
+    assert default_markers["apple"] == "o"
+    # Dict
+    marker_dict = {"apple": "s", "banana": "o", "cherry": "o"}
+    dict_markers = _get_categorical_markers(empty_tdata, "fruit", category_data, marker_dict)
+    assert dict_markers["apple"] == "s"
+    # List not enough
+    marker_list = ["s", "o"]
+    with pytest.warns(Warning, match="Length of markers"):
+        list_markers = _get_categorical_markers(empty_tdata, "fruit", category_data, marker_list)
+    assert list_markers["apple"] == "s"
+
+
+def test_markers_in_uns(empty_tdata, category_data):
+    """Markers passed explicitly are stored in ``uns`` and reused when none are given."""
+    expected = {"apple": "s", "banana": "o", "cherry": "o"}
+    _get_categorical_markers(empty_tdata, "fruit", category_data, expected)
+    assert "fruit_markers" in empty_tdata.uns
+    assert empty_tdata.uns["fruit_markers"] == list(expected.values())
+    assert _get_categorical_markers(empty_tdata, "fruit", category_data) == expected
+
+
+# Test _series_to_rgb_array
+def test_series_to_rgb_discrete(category_data):
+    palette = {"apple": "#ff0000ff", "banana": "#ffff00ff", "cherry": "#ff69b4ff"}
+    # Fixture data: one palette color per row
+    fixture_rgb = _series_to_rgb_array(category_data, palette)
+    fixture_expected = np.array([[1, 0, 0], [1, 1, 0], [1, 0.41176471, 0.70588235]])
+    np.testing.assert_almost_equal(fixture_rgb, fixture_expected, decimal=2)
+    # Test with missing data
+    with_missing = pd.Series(["apple", pd.NA])
+    missing_rgb = _series_to_rgb_array(with_missing, palette)
+    np.testing.assert_almost_equal(missing_rgb, np.array([[1, 0, 0], [0.5, 0.5, 0.5]]), decimal=2)
+
+
+def test_series_to_rgb_numeric():
+    """Numeric values map through the colormap; NaN maps to mid-gray."""
+    cmap = mcolors.ListedColormap(["red", "yellow", "blue"])
+    red, yellow, blue, gray = [1, 0, 0], [1, 1, 0], [0, 0, 1], [0.5, 0.5, 0.5]
+    # Complete data
+    complete = _series_to_rgb_array(pd.Series([0, 1, 2]), cmap, vmin=0, vmax=2)
+    np.testing.assert_almost_equal(complete, np.array([red, yellow, blue]), decimal=2)
+    # Test with missing data
+    gapped = pd.Series([0, np.nan, 2])
+    with_nan = _series_to_rgb_array(gapped, cmap, vmin=0, vmax=2)
+    np.testing.assert_almost_equal(with_nan, np.array([red, gray, blue]), decimal=2)
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100755
index 0000000..d1f936c
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,66 @@
+import networkx as nx
+import pandas as pd
+import pytest
+
+from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_root
+
+
+@pytest.fixture
+def tree():
+    """Five-node tree rooted at "A"; node "C" and edge ("A", "C") carry ``None`` attribute values."""
+    graph = nx.DiGraph([("A", "B"), ("A", "C"), ("B", "D"), ("C", "E")])
+    nx.set_node_attributes(graph, {"A": 1, "B": 2, "C": None, "D": 4, "E": 5}, "value")
+    nx.set_edge_attributes(graph, {("A", "B"): 5, ("A", "C"): None, ("B", "D"): 3, ("C", "E"): 4}, "weight")
+    return graph
+
+
+def test_get_root(tree):
+    # An empty graph has no root
+    assert get_root(nx.DiGraph()) is None
+    # The fixture tree is rooted at "A"
+    assert get_root(tree) == "A"
+    # A lone node is its own root
+    lone = nx.DiGraph()
+    lone.add_node("A")
+    assert get_root(lone) == "A"
+
+
+def test_get_keyed_edge_data(tree):
+    """Edges with a weight are returned; the ``None``-weighted ("A", "C") edge is dropped."""
+    weights = get_keyed_edge_data(tree, "weight")
+    expected = {("A", "B"): 5, ("B", "D"): 3, ("C", "E"): 4}
+    assert all(weights[edge] == weight for edge, weight in expected.items())
+    assert ("A", "C") not in weights
+
+
+def test_get_keyed_node_data(tree):
+    """Nodes with a value are returned; node "C", whose value is ``None``, is dropped."""
+    values = get_keyed_node_data(tree, "value")
+    expected = {"A": 1, "B": 2, "D": 4, "E": 5}
+    assert all(values[node] == value for node, value in expected.items())
+    assert "C" not in values
+
+
+def test_get_keyed_obs_data_valid_keys(tdata):
+    requested = ["clade", "x", "0"]
+    frame, from_matrix = get_keyed_obs_data(tdata, requested)
+    assert not from_matrix
+    assert list(frame.columns) == requested
+    # Automatically converts object columns to category
+    assert all(column.dtype == "category" for column in (frame["clade"], tdata.obs["clade"]))
+
+
+def test_get_keyed_obs_data_array(tdata):
+    coords, from_matrix = get_keyed_obs_data(tdata, ["spatial"])
+    assert from_matrix and isinstance(coords, pd.DataFrame)
+    assert coords.shape[1] == 2
+    # "spatial_distance" yields a square obs-by-obs frame
+    distances, _ = get_keyed_obs_data(tdata, ["spatial_distance"])
+    assert distances.shape == (tdata.n_obs, tdata.n_obs)
+
+
+def test_get_keyed_obs_data_invalid_keys(tdata):
+    # An unknown key, then a column key mixed with a matrix key
+    for bad_keys in (["clade", "x", "0", "invalid_key"], ["clade", "spatial_distance"]):
+        with pytest.raises(ValueError):
+            get_keyed_obs_data(tdata, bad_keys)