diff --git a/docs/api.md b/docs/api.md index b76f6c7..5af4400 100644 --- a/docs/api.md +++ b/docs/api.md @@ -2,28 +2,8 @@ ## Preprocessing -```{eval-rst} -.. module:: pycea.pp -.. currentmodule:: pycea - -.. autosummary:: - :toctree: generated - - pp.basic_preproc -``` - ## Tools -```{eval-rst} -.. module:: pycea.tl -.. currentmodule:: pycea - -.. autosummary:: - :toctree: generated - - tl.basic_tool -``` - ## Plotting ```{eval-rst} @@ -33,5 +13,8 @@ .. autosummary:: :toctree: generated + pl.tree pl.branches + pl.nodes + pl.annotation ``` diff --git a/docs/conf.py b/docs/conf.py index 6c81c3a..4aa5dd0 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -91,10 +91,16 @@ } intersphinx_mapping = { - "python": ("https://docs.python.org/3", None), "anndata": ("https://anndata.readthedocs.io/en/stable/", None), + "cycler": ("https://matplotlib.org/cycler/", None), + "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "networkx": ("https://networkx.org/documentation/stable/", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), + "python": ("https://docs.python.org/3", None), + "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), + "squidpy": ("https://squidpy.readthedocs.io/en/stable/", None), "treedata": ("https://treedata.readthedocs.io/en/stable/", None), } diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index a9fa468..ce6d2f2 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -12,134 +12,7 @@ "execution_count": 1, "metadata": {}, "outputs": [], - "source": [ - "import numpy as np\n", - "from anndata import AnnData\n", - "import pandas as pd\n", - "import pycea" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "adata = AnnData(np.random.normal(size=(20, 10)))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With myst it is possible to link in the text cell of a notebook such as this one the documentation of a function or a class.\n", - "\n", - "Let's take as an example the function {func}`pycea.pp.basic_preproc`. \n", - "You can see that by clicking on the text, the link redirects to the API documentation of the function. \n", - "Check the raw markdown of this cell to understand how this is specified.\n", - "\n", - "This works also for any package listed by `intersphinx`. Go to `docs/conf.py` and look for the `intersphinx_mapping` variable. \n", - "There, you will see a list of packages (that this package is dependent on) for which this functionality is supported. \n", - "\n", - "For instance, we can link to the class {class}`anndata.AnnData`, to the attribute {attr}`anndata.AnnData.obs` or the method {meth}`anndata.AnnData.write`.\n", - "\n", - "Again, check the raw markdown of this cell to see how each of these links are specified.\n", - "\n", - "You can read more about this in the [intersphinx page](https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html) and the [myst page](https://myst-parser.readthedocs.io/en/v0.15.1/syntax/syntax.html#roles-an-in-line-extension-point)." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Implement a preprocessing function here." - ] - }, - { - "data": { - "text/plain": [ - "0" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pycea.pp.basic_preproc(adata)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
AB
0a1
1b2
2c3
\n", - "
" - ], - "text/plain": [ - " A B\n", - "0 a 1\n", - "1 b 2\n", - "2 c 3" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pd.DataFrame().assign(A=[\"a\", \"b\", \"c\"], B=[1, 2, 3])" - ] + "source": [] } ], "metadata": { diff --git a/src/pycea/_utils.py b/src/pycea/_utils.py deleted file mode 100755 index e249e38..0000000 --- a/src/pycea/_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -import networkx as nx -import pandas as pd - - -def get_root(tree: nx.DiGraph): - """Finds the root of a tree""" - if not tree.nodes(): - return None # Handle empty graph case. - node = next(iter(tree.nodes)) - while True: - parent = list(tree.predecessors(node)) - if not parent: - return node # No predecessors, this is the root - node = parent[0] - - -def _get_keyed_edge_data(tree: nx.DiGraph, key: str) -> pd.Series: - """Gets edge data for a given key from a tree.""" - edge_data = { - (parent, child): data.get(key) - for parent, child, data in tree.edges(data=True) - if key in data and data[key] is not None - } - return pd.Series(edge_data, name=key) diff --git a/src/pycea/pl/__init__.py b/src/pycea/pl/__init__.py index e92b347..8b9468f 100644 --- a/src/pycea/pl/__init__.py +++ b/src/pycea/pl/__init__.py @@ -1 +1 @@ -from .tree import branches +from .plot_tree import annotation, branches, nodes, tree diff --git a/src/pycea/pl/_docs.py b/src/pycea/pl/_docs.py index 8d1b148..20b8c70 100755 --- a/src/pycea/pl/_docs.py +++ b/src/pycea/pl/_docs.py @@ -2,13 +2,27 @@ from __future__ import annotations +from textwrap import dedent + + +def _doc_params(**kwds): + r"""Docstrings should start with ``\\`` in the first line for proper formatting""" + + def dec(obj): + obj.__orig_doc__ = obj.__doc__ + obj.__doc__ = dedent(obj.__doc__).format_map(kwds) + return obj + + return dec + + doc_common_plot_args = """\ -color_map +cmap Color map to use for continous variables. Can be a name or a :class:`~matplotlib.colors.Colormap` instance (e.g. `"magma`", `"viridis"` or `mpl.cm.cividis`), see :func:`~matplotlib.cm.get_cmap`. If `None`, the value of `mpl.rcParams["image.cmap"]` is used. - The default `color_map` can be set using :func:`~scanpy.set_figure_params`. + The default `cmap` can be set using :func:`~scanpy.set_figure_params`. palette Colors to use for plotting categorical annotation groups. The palette can be a valid :class:`~matplotlib.colors.ListedColormap` name @@ -18,6 +32,10 @@ If `None`, `mpl.rcParams["axes.prop_cycle"]` is used unless the categorical variable already has colors stored in `tdata.uns["{var}_colors"]`. If provided, values of `tdata.uns["{var}_colors"]` will be set. +vmax + The maximum value for the colormap. +vmin + The minimum value for the colormap. ax A matplotlib axes object. If `None`, a new figure and axes will be created. """ diff --git a/src/pycea/pl/_utils.py b/src/pycea/pl/_utils.py index 799cf7b..fb60491 100755 --- a/src/pycea/pl/_utils.py +++ b/src/pycea/pl/_utils.py @@ -11,7 +11,7 @@ import numpy as np from scanpy.plotting import palettes -from pycea._utils import get_root +from pycea.utils import get_root def layout_tree( @@ -62,7 +62,7 @@ def layout_tree( node_coords = {} for node in nx.dfs_postorder_nodes(tree, root): if tree.out_degree(node) == 0: - lon = (i / n_leaves) * 2 * np.pi + lon = (i / (n_leaves)) * 2 * np.pi # + 2 * np.pi / n_leaves if extend_branches: node_coords[node] = (max_depth, lon) else: @@ -175,3 +175,59 @@ def _get_categorical_colors(tdata, key, data, palette=None): # store colors in tdata tdata.uns[key + "_colors"] = colors_list return dict(zip(categories, colors_list)) + + +def _get_categorical_markers(tdata, key, data, markers=None): + """Get categorical markers for plotting.""" + default_markers = ["o", "s", "D", "^", "v", "<", ">", "p", "P", "*", "h", "H", "X"] + # Ensure data is a category + if not data.dtype.name == "category": + data = data.astype("category") + categories = data.cat.categories + # Use default markers if no markers are provided + if markers is None: + markers_list = tdata.uns.get(key + "_markers", None) + if markers_list is None or len(markers_list) > len(categories): + markers_list = default_markers[: len(categories)] + # Use provided markers + else: + if isinstance(markers, cabc.Mapping): + markers_list = [markers[k] for k in categories] + else: + if not isinstance(markers, cabc.Sequence): + raise ValueError("Please check that the value of 'markers' is a valid " "list of marker names.") + if len(markers) < len(categories): + warnings.warn( + "Length of markers is smaller than the number of " + f"categories (markers length: {len(markers)}, " + f"categories length: {len(categories)}. " + "Some categories will have the same marker.", + stacklevel=2, + ) + markers_list = markers * (len(categories) // len(markers) + 1) + else: + markers_list = markers[: len(categories)] + # store markers in tdata + tdata.uns[key + "_markers"] = markers_list + return dict(zip(categories, markers_list)) + + +def _series_to_rgb_array(series, colors, vmin=None, vmax=None, na_color="#808080"): + """Converts a pandas Series to an N x 3 numpy array based using a color map.""" + if isinstance(colors, dict): + # Map using the dictionary + color_series = series.map(colors) + color_series[series.isna()] = na_color + rgb_array = np.array([mcolors.to_rgb(color) for color in color_series]) + elif isinstance(colors, mcolors.ListedColormap): + # Normalize and map values if cmap is a ListedColormap + if vmin is not None and vmax is not None: + norm = mcolors.Normalize(vmin, vmax) + colors.set_bad(na_color) + color_series = colors(norm(series)) + rgb_array = np.vstack(color_series[:, :3]) + else: + raise ValueError("vmin and vmax must be specified when using a ListedColormap.") + else: + raise ValueError("cmap must be either a dictionary or a ListedColormap.") + return rgb_array diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py new file mode 100644 index 0000000..6f66320 --- /dev/null +++ b/src/pycea/pl/plot_tree.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +from collections.abc import Mapping, Sequence + +import cycler +import matplotlib as mpl +import matplotlib.colors as mcolors +import matplotlib.markers as mmarkers +import matplotlib.pyplot as plt +import numpy as np +import treedata as td +from matplotlib.axes import Axes +from matplotlib.collections import LineCollection + +from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data + +from ._docs import _doc_params, doc_common_plot_args +from ._utils import ( + _get_categorical_colors, + _get_categorical_markers, + _series_to_rgb_array, + layout_tree, +) + + +@_doc_params( + common_plot_args=doc_common_plot_args, +) +def branches( + tdata: td.TreeData, + key: str = None, + polar: bool = False, + extend_branches: bool = False, + angled_branches: bool = False, + color: str = "black", + linewidth: int | float | str = 1, + cmap: str | mcolors.Colormap = "viridis", + palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None, + na_color: str = "lightgrey", + na_linewidth: int | float = 1, + ax: Axes | None = None, + **kwargs, +) -> Axes: + """\ + Plot the branches of a tree. + + Parameters + ---------- + tdata + The `treedata.TreeData` object. + key + The `obst` key of the tree to plot. + polar + Whether to plot the tree in polar coordinates. + extend_branches + Whether to extend branches so the tips are at the same depth. + angled_branches + Whether to plot branches at an angle. + color + Either a color name, or a key for an attribute of the edges to color by. + linewidth + Either an numeric width, or a key for an attribute of the edges to set the linewidth. + {common_plot_args} + na_color + The color to use for edges with missing data. + na_linewidth + The linewidth to use for edges with missing data. + kwargs + Additional keyword arguments passed to `matplotlib.collections.LineCollection`. + + Returns + ------- + ax - The axes that the plot was drawn on. + """ # noqa: D205 + kwargs = kwargs if kwargs else {} + if not key: + key = next(iter(tdata.obst.keys())) + tree = tdata.obst[key] + # Get layout + node_coords, branch_coords, leaves, depth = layout_tree( + tree, polar=polar, extend_branches=extend_branches, angled_branches=angled_branches + ) + segments = [] + edges = [] + for edge, (lat, lon) in branch_coords.items(): + coords = np.array([lon, lat] if polar else [lat, lon]).T + segments.append(coords) + edges.append(edge) + kwargs.update({"segments": segments}) + # Get colors + if mcolors.is_color_like(color): + kwargs.update({"color": color}) + elif isinstance(color, str): + color_data = get_keyed_edge_data(tree, color) + if color_data.dtype.kind in ["i", "f"]: + norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max()) + cmap = plt.get_cmap(cmap) + colors = [cmap(norm(color_data[edge])) if edge in color_data.index else na_color for edge in edges] + kwargs.update({"color": colors}) + else: + cmap = _get_categorical_colors(tdata, color, color_data, palette) + colors = [cmap[color_data[edge]] if edge in color_data.index else na_color for edge in edges] + kwargs.update({"color": colors}) + else: + raise ValueError("Invalid color value. Must be a color name, or an str specifying an attribute of the edges.") + # Get linewidths + if isinstance(linewidth, (int, float)): + kwargs.update({"linewidth": linewidth}) + elif isinstance(linewidth, str): + linewidth_data = get_keyed_edge_data(tree, linewidth) + if linewidth_data.dtype.kind in ["i", "f"]: + linewidths = [linewidth_data[edge] if edge in linewidth_data.index else na_linewidth for edge in edges] + kwargs.update({"linewidth": linewidths}) + else: + raise ValueError("Invalid linewidth data type. Edge attribute must be int or float") + else: + raise ValueError("Invalid linewidth value. Must be int, float, or an str specifying an attribute of the edges.") + # Plot + if not ax: + subplot_kw = {"projection": "polar"} if polar else None + fig, ax = plt.subplots(subplot_kw=subplot_kw) + elif (ax.name == "polar") != polar: + raise ValueError("Provided axis does not match the requested 'polar' setting.") + ax.add_collection(LineCollection(zorder=1, **kwargs)) + # Configure plot + lat_lim = (-0.2, depth) + lon_lim = (0, 2 * np.pi) + ax.set_xlim(lon_lim if polar else lat_lim) + ax.set_ylim(lat_lim if polar else lon_lim) + ax.axis("off") + ax._attrs = { + "node_coords": node_coords, + "leaves": leaves, + "depth": depth, + "offset": depth, + "polar": polar, + "tree_key": key, + } + return ax + + +# For internal use +_branches = branches + + +@_doc_params( + common_plot_args=doc_common_plot_args, +) +def nodes( + tdata: td.TreeData, + nodes: str | Sequence[str] = "internal", + color: str = "black", + style: str = "o", + size: int | float | str = 10, + cmap: str | mcolors.Colormap = None, + palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None, + markers: Sequence[str] | Mapping[str] = None, + vmax: int | float | None = None, + vmin: int | float | None = None, + na_color: str = "#FFFFFF00", + na_style: str = "none", + na_size: int | float = 0, + ax: Axes | None = None, + **kwargs, +) -> Axes: + """\ + Plot the nodes of a tree. + + Parameters + ---------- + tdata + The TreeData object. + nodes + Either "all", "leaves", "internal", or a list of nodes to plot. + color + Either a color name, or a key for an attribute of the nodes to color by. + style + Either a marker name, or a key for an attribute of the nodes to set the marker. + Can be numeric but will always be treated as a categorical variable. + size + Either an numeric size, or a key for an attribute of the nodes to set the size. + {common_plot_args} + markers + Object determining how to draw the markers for different levels of the style variable. + You can pass a list of markers or a dictionary mapping levels of the style variable to markers. + na_color + The color to use for annotations with missing data. + na_style + The marker to use for annotations with missing data. + na_size + The size to use for annotations with missing data. + kwargs + Additional keyword arguments passed to `matplotlib.pyplot.scatter`. + + Returns + ------- + ax - The axes that the plot was drawn on. + """ # noqa: D205 + # Setup + kwargs = kwargs if kwargs else {} + if not ax: + ax = plt.gca() + attrs = ax._attrs if hasattr(ax, "_attrs") else None + if not attrs: + raise ValueError("Branches most be plotted with pycea.pl.branches before annotations can be plotted.") + if not cmap: + cmap = mpl.rcParams["image.cmap"] + cmap = plt.get_cmap(cmap) + tree = tdata.obst[attrs["tree_key"]] + # Get nodes + all_nodes = list(attrs["node_coords"].keys()) + leaves = list(attrs["leaves"]) + if nodes == "all": + nodes = all_nodes + elif nodes == "leaves": + nodes = leaves + elif nodes == "internal": + nodes = [node for node in all_nodes if node not in leaves] + elif isinstance(nodes, Sequence): + if set(nodes).issubset(all_nodes): + nodes = list(nodes) + else: + raise ValueError("Nodes must be a list of nodes in the tree.") + else: + raise ValueError("Invalid nodes value. Must be 'all', 'leaves', 'no_leaves', or a list of nodes.") + # Get coordinates + coords = np.vstack([attrs["node_coords"][node] for node in nodes]) + if attrs["polar"]: + kwargs.update({"x": coords[:, 1], "y": coords[:, 0]}) + else: + kwargs.update({"x": coords[:, 0], "y": coords[:, 1]}) + kwargs_list = [] + # Get colors + if mcolors.is_color_like(color): + kwargs.update({"color": color}) + elif isinstance(color, str): + color_data = get_keyed_node_data(tree, color) + if color_data.dtype.kind in ["i", "f"]: + if not vmin: + vmin = color_data.min() + if not vmax: + vmax = color_data.max() + norm = plt.Normalize(vmin=vmin, vmax=vmax) + colors = [cmap(norm(color_data[node])) if node in color_data.index else na_color for node in nodes] + kwargs.update({"color": colors}) + else: + cmap = _get_categorical_colors(tdata, color, color_data, palette) + colors = [cmap[color_data[node]] if node in color_data.index else na_color for node in nodes] + kwargs.update({"color": colors}) + else: + raise ValueError("Invalid color value. Must be a color name, or an str specifying an attribute of the nodes.") + # Get sizes + if isinstance(size, (int, float)): + kwargs.update({"s": size}) + elif isinstance(size, str): + size_data = get_keyed_node_data(tree, size) + sizes = [size_data[node] if node in size_data.index else na_size for node in nodes] + kwargs.update({"s": sizes}) + else: + raise ValueError("Invalid size value. Must be int, float, or an str specifying an attribute of the nodes.") + # Get markers + if style in mmarkers.MarkerStyle.markers: + kwargs.update({"marker": style}) + elif isinstance(style, str): + style_data = get_keyed_node_data(tree, style) + mmap = _get_categorical_markers(tdata, style, style_data, markers) + styles = [mmap[style_data[node]] if node in style_data.index else na_style for node in nodes] + for style in set(styles): + style_kwargs = {} + idx = [i for i, x in enumerate(styles) if x == style] + for key, value in kwargs.items(): + if isinstance(value, (list, np.ndarray)): + style_kwargs[key] = [value[i] for i in idx] + else: + style_kwargs[key] = value + style_kwargs.update({"marker": style}) + kwargs_list.append(style_kwargs) + else: + raise ValueError("Invalid style value. Must be a marker name, or an str specifying an attribute of the nodes.") + # Plot + if len(kwargs_list) > 0: + for kwargs in kwargs_list: + ax.scatter(**kwargs) + else: + ax.scatter(**kwargs) + return ax + + +# For internal use +_nodes = nodes + + +@_doc_params( + common_plot_args=doc_common_plot_args, +) +def annotation( + tdata: td.TreeData, + keys: str | Sequence[str] = None, + width: int | float = 0.05, + gap: int | float = 0.03, + label: bool | str | Sequence[str] = True, + cmap: str | mcolors.Colormap = None, + palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None, + vmax: int | float | None = None, + vmin: int | float | None = None, + na_color: str = "white", + ax: Axes | None = None, + **kwargs, +) -> Axes: + """\ + Plot leaf annotations. + + Parameters + ---------- + tdata + The TreeData object. + keys + One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to plot. + width + The width of the annotation bar relative to the tree. + gap + The gap between the annotation bar and the tree relative to the tree. + label + Annotation labels. If `True`, the keys are used as labels. + If a string or a sequence of strings, the strings are used as labels. + {common_plot_args} + na_color + The color to use for annotations with missing data. + kwargs + Additional keyword arguments passed to `matplotlib.pyplot.pcolormesh`. + + Returns + ------- + ax - The axes that the plot was drawn on. + """ # noqa: D205 + # Setup + if not ax: + ax = plt.gca() + attrs = ax._attrs if hasattr(ax, "_attrs") else None + if not attrs: + raise ValueError("Branches most be plotted with pycea.pl.branches before annotations can be plotted.") + if not keys: + raise ValueError("No keys provided. Please provide one or more keys to plot.") + keys = [keys] if isinstance(keys, str) else keys + if not cmap: + cmap = mpl.rcParams["image.cmap"] + cmap = plt.get_cmap(cmap) + # Get data + data, is_array = get_keyed_obs_data(tdata, keys) + data = data.loc[attrs["leaves"]] + numeric_data = data.select_dtypes(exclude="category") + if len(numeric_data) > 0 and not vmin: + vmin = numeric_data.min().min() + if len(numeric_data) > 0 and not vmax: + vmax = numeric_data.max().max() + # Get labels + if label is True: + labels = keys + elif isinstance(label, str): + labels = [label] + elif isinstance(label, Sequence): + labels = label + else: + raise ValueError("Invalid label value. Must be a bool, str, or a sequence of strings.") + # Compute coordinates for annotations + start_lat = attrs["offset"] + attrs["depth"] * gap + end_lat = start_lat + attrs["depth"] * width * data.shape[1] + lats = np.linspace(start_lat, end_lat, data.shape[1] + 1) + lons = np.linspace(0, 2 * np.pi, data.shape[0] + 1) + lons = lons - np.pi / len(attrs["leaves"]) + # Covert to RGB array + rgb_array = [] + if is_array: + if data.shape[0] == data.shape[1]: + data = data.loc[attrs["leaves"], reversed(attrs["leaves"])] + end_lat = start_lat + attrs["depth"] + 2 * np.pi + lats = np.linspace(start_lat, end_lat, data.shape[1] + 1) + for col in data.columns: + rgb_array.append(_series_to_rgb_array(data[col], cmap, vmin=vmin, vmax=vmax, na_color=na_color)) + else: + for key in keys: + if data[key].dtype == "category": + colors = _get_categorical_colors(tdata, key, data[key], palette) + rgb_array.append(_series_to_rgb_array(data[key], colors, na_color=na_color)) + else: + rgb_array.append(_series_to_rgb_array(data[key], cmap, vmin=vmin, vmax=vmax, na_color=na_color)) + rgb_array = np.stack(rgb_array, axis=1) + # Plot + if attrs["polar"]: + ax.pcolormesh(lons, lats, rgb_array.swapaxes(0, 1), zorder=2, **kwargs) + ax.set_ylim(-0.2, end_lat) + else: + ax.pcolormesh(lats, lons, rgb_array, zorder=2, **kwargs) + ax.set_xlim(-0.2, end_lat) + labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1) + labels_lats = labels_lats + (end_lat - start_lat) / (len(labels) * 2) + for idx, label in enumerate(labels): + if is_array and len(labels) == 1: + ax.text(labels_lats[idx], -0.1, label, ha="center", va="top") + ax.set_ylim(-0.5, 2 * np.pi) + else: + ax.text(labels_lats[idx], -0.1, label, ha="center", va="top", rotation=90) + ax.set_ylim(-1, 2 * np.pi) + ax._attrs.update({"offset": end_lat}) + return ax + + +# For internal use +_annotation = annotation + + +@_doc_params( + common_plot_args=doc_common_plot_args, +) +def tree( + tdata: td.TreeData, + key: str = None, + nodes: str | Sequence[str] = None, + annotation_keys: str | Sequence[str] = None, + polar: bool = False, + extend_branches: bool = False, + angled_branches: bool = False, + branch_color: str = "black", + branch_linewidth: int | float | str = 1, + node_color: str = "black", + node_style: str = "o", + node_size: int | float = 10, + annotation_width: int | float = 0.05, + cmap: str | mcolors.Colormap = "viridis", + palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None, + ax: Axes | None = None, + **kwargs, +) -> Axes: + """\ + Plot a tree with branches, nodes, and annotations. + + Parameters + ---------- + tdata + The TreeData object. + key + The `obst` key of the tree to plot. + nodes + Either "all", "leaves", "internal", or a list of nodes to plot. + annotation_keys + One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to plot. + polar + Whether to plot the tree in polar coordinates. + extend_branches + Whether to extend branches so the tips are at the same depth. + angled_branches + Whether to plot branches at an angle. + branch_color + Either a color name, or a key for an attribute of the edges to color by. + branch_linewidth + Either an numeric width, or a key for an attribute of the edges to set the linewidth. + node_color + Either a color name, or a key for an attribute of the nodes to color by. + node_style + Either a marker name, or a key for an attribute of the nodes to set the marker. + node_size + Either an numeric size, or a key for an attribute of the nodes to set the size. + annotation_width + The width of the annotation bar relative to the tree. + {common_plot_args} + + Returns + ------- + ax - The axes that the plot was drawn on. + """ # noqa: D205 + # Plot branches + ax = _branches( + tdata, + key=key, + polar=polar, + extend_branches=extend_branches, + angled_branches=angled_branches, + color=branch_color, + linewidth=branch_linewidth, + cmap=cmap, + palette=palette, + ax=ax, + ) + # Plot nodes + if nodes: + ax = _nodes( + tdata, nodes=nodes, color=node_color, style=node_style, size=node_size, cmap=cmap, palette=palette, ax=ax + ) + # Plot annotations + if annotation_keys: + ax = _annotation(tdata, keys=annotation_keys, width=annotation_width, cmap=cmap, palette=palette, ax=ax) + return ax diff --git a/src/pycea/pl/tree.py b/src/pycea/pl/tree.py deleted file mode 100644 index 99a9051..0000000 --- a/src/pycea/pl/tree.py +++ /dev/null @@ -1,122 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence - -import cycler -import matplotlib.colors as mcolors -import matplotlib.pyplot as plt -import numpy as np -import treedata as td -from matplotlib.axes import Axes -from matplotlib.collections import LineCollection - -from pycea._utils import _get_keyed_edge_data - -from ._utils import ( - _get_categorical_colors, - layout_tree, -) - - -def branches( - tdata: td.TreeData, - key: str = None, - polar: bool = False, - extend_branches: bool = False, - angled_branches: bool = False, - color: str = "black", - color_na: str = "lightgrey", - linewidth: int | float | str = 1, - linewidth_na: int | float = 1, - cmap: str | mcolors.Colormap = "viridis", - palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None, - ax: Axes | None = None, - **kwargs, -): - """Plot the branches of a tree. - - Parameters - ---------- - tdata - The `td.TreeData` object. - key - The `obst` key of the tree to plot. - polar - Whether to plot the tree in polar coordinates. - extend_branches - Whether to extend branches so the tips are at the same depth. - angled_branches - Whether to plot branches at an angle. - color - Either a color name, or a key for an attribute of the edges to color by. - color_na - The color to use for edges with missing data. - linewidth - Either an numeric width, or a key for an attribute of the edges to set the linewidth. - linewidth_na - The linewidth to use for edges with missing data. - {doc_common_plot_args} - kwargs - Additional keyword arguments passed to `matplotlib.collections.LineCollection`. - - Returns - ------- - `matplotlib.axes.Axes` - """ - kwargs = kwargs if kwargs else {} - tree = tdata.obst[key] - - # Get layout - node_coords, branch_coords, leaves, depth = layout_tree( - tree, polar=polar, extend_branches=extend_branches, angled_branches=angled_branches - ) - segments = [] - edges = [] - for edge, (lat, lon) in branch_coords.items(): - coords = np.array([lon, lat] if polar else [lat, lon]).T - segments.append(coords) - edges.append(edge) - kwargs.update({"segments": segments}) - # Get colors - if mcolors.is_color_like(color): - kwargs.update({"color": color}) - elif isinstance(color, str): - color_data = _get_keyed_edge_data(tree, color) - if color_data.dtype.kind in ["i", "f"]: - norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max()) - cmap = plt.get_cmap(cmap) - colors = [cmap(norm(color_data[edge])) if edge in color_data.index else color_na for edge in edges] - kwargs.update({"color": colors}) - else: - cmap = _get_categorical_colors(tdata, color, color_data, palette) - colors = [cmap[color_data[edge]] if edge in color_data.index else color_na for edge in edges] - kwargs.update({"color": colors}) - else: - raise ValueError("Invalid color value. Must be a color name, or an str specifying an attribute of the edges.") - # Get linewidths - if isinstance(linewidth, (int, float)): - kwargs.update({"linewidth": linewidth}) - elif isinstance(linewidth, str): - linewidth_data = _get_keyed_edge_data(tree, linewidth) - if linewidth_data.dtype.kind in ["i", "f"]: - linewidths = [linewidth_data[edge] if edge in linewidth_data.index else linewidth_na for edge in edges] - kwargs.update({"linewidth": linewidths}) - else: - raise ValueError("Invalid linewidth data type. Edge attribute must be int or float") - else: - raise ValueError("Invalid linewidth value. Must be int, float, or an str specifying an attribute of the edges.") - # Plot - if not ax: - subplot_kw = {"projection": "polar"} if polar else None - fig, ax = plt.subplots(subplot_kw=subplot_kw) - elif (ax.name == "polar") != polar: - raise ValueError("Provided axis does not match the requested 'polar' setting.") - ax.add_collection(LineCollection(**kwargs)) - # Configure plot - lat_lim = (-0.1, depth) - lon_lim = (0, 2 * np.pi) - ax.set_xlim(lon_lim if polar else lat_lim) - ax.set_ylim(lat_lim if polar else lon_lim) - ax.axis("off") - ax._attrs = {"node_coords": node_coords, "leaves": leaves, "depth": depth, "offset": depth, "polar": polar} - return ax diff --git a/src/pycea/pp/__init__.py b/src/pycea/pp/__init__.py index 5e7e293..e69de29 100644 --- a/src/pycea/pp/__init__.py +++ b/src/pycea/pp/__init__.py @@ -1 +0,0 @@ -from .basic import basic_preproc diff --git a/src/pycea/pp/basic.py b/src/pycea/pp/basic.py deleted file mode 100644 index 5db1ec0..0000000 --- a/src/pycea/pp/basic.py +++ /dev/null @@ -1,17 +0,0 @@ -from anndata import AnnData - - -def basic_preproc(adata: AnnData) -> int: - """Run a basic preprocessing on the AnnData object. - - Parameters - ---------- - adata - The AnnData object to preprocess. - - Returns - ------- - Some integer value. - """ - print("Implement a preprocessing function here.") - return 0 diff --git a/src/pycea/tl/__init__.py b/src/pycea/tl/__init__.py index 95a32cd..e69de29 100644 --- a/src/pycea/tl/__init__.py +++ b/src/pycea/tl/__init__.py @@ -1 +0,0 @@ -from .basic import basic_tool diff --git a/src/pycea/tl/basic.py b/src/pycea/tl/basic.py deleted file mode 100644 index d215ade..0000000 --- a/src/pycea/tl/basic.py +++ /dev/null @@ -1,17 +0,0 @@ -from anndata import AnnData - - -def basic_tool(adata: AnnData) -> int: - """Run a tool on the AnnData object. - - Parameters - ---------- - adata - The AnnData object to preprocess. - - Returns - ------- - Some integer value. - """ - print("Implement a tool to run on the AnnData object.") - return 0 diff --git a/src/pycea/utils.py b/src/pycea/utils.py new file mode 100755 index 0000000..2ffa0e5 --- /dev/null +++ b/src/pycea/utils.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from collections.abc import Sequence + +import networkx as nx +import pandas as pd +import treedata as td + + +def get_root(tree: nx.DiGraph): + """Finds the root of a tree""" + if not tree.nodes(): + return None # Handle empty graph case. + node = next(iter(tree.nodes)) + while True: + parent = list(tree.predecessors(node)) + if not parent: + return node # No predecessors, this is the root + node = parent[0] + + +def get_keyed_edge_data(tree: nx.DiGraph, key: str) -> pd.Series: + """Gets edge data for a given key from a tree.""" + edge_data = { + (parent, child): data.get(key) + for parent, child, data in tree.edges(data=True) + if key in data and data[key] is not None + } + if len(edge_data) == 0: + raise ValueError(f"Key {key!r} is not present in any edge.") + return pd.Series(edge_data, name=key) + + +def get_keyed_node_data(tree: nx.DiGraph, key: str) -> pd.Series: + """Gets node data for a given key from a tree.""" + node_data = {node: data.get(key) for node, data in tree.nodes(data=True) if key in data and data[key] is not None} + if len(node_data) == 0: + raise ValueError(f"Key {key!r} is not present in any node.") + return pd.Series(node_data, name=key) + + +def get_keyed_obs_data(tdata: td.TreeData, keys: Sequence[str], layer: str = None) -> pd.DataFrame: + """Gets observation data for a given key from a tree.""" + data = [] + column_keys = False + array_keys = False + for key in keys: + if key in tdata.obs_keys(): + if tdata.obs[key].dtype.kind in ["b", "O", "S"]: + tdata.obs[key] = tdata.obs[key].astype("category") + data.append(tdata.obs[key]) + column_keys = True + elif key in tdata.var_names: + data.append(pd.Series(tdata.obs_vector(key, layer=layer), index=tdata.obs_names)) + column_keys = True + elif "obsm" in dir(tdata) and key in tdata.obsm.keys(): + data.append(tdata.obsm[key]) + array_keys = True + elif "obsp" in dir(tdata) and key in tdata.obsp.keys(): + data.append(tdata.obsp[key]) + array_keys = True + else: + raise ValueError( + f"Key {key!r} is invalid! You must pass a valid observation annotation. " + f"One of obs_keys, var_names, obsm_keys, obsp_keys." + ) + if column_keys and array_keys: + raise ValueError("Cannot mix column and matrix keys.") + if array_keys and len(keys) > 1: + raise ValueError("Cannot request multiple matrix keys.") + if not column_keys and not array_keys: + raise ValueError("No valid keys found.") + # Convert to DataFrame + if column_keys: + data = pd.concat(data, axis=1) + data.columns = keys + elif array_keys: + data = pd.DataFrame(data[0], index=tdata.obs_names) + + if data.shape[0] == data.shape[1]: + data.columns = tdata.obs_names + return data, array_keys diff --git a/tests/conftest.py b/tests/conftest.py index ba9a282..3aaf61b 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -from pathlib import Path - import pytest import treedata as td @@ -9,6 +7,3 @@ @pytest.fixture(scope="session") def tdata() -> td.TreeData: return _tdata - - -PLOT_PATH = Path("tests/plots/") diff --git a/tests/test_basic.py b/tests/test_basic.py deleted file mode 100644 index c01c90a..0000000 --- a/tests/test_basic.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest - -import pycea - - -def test_package_has_version(): - assert pycea.__version__ is not None - - -@pytest.mark.skip(reason="This decorator should be removed when test passes.") -def test_example(): - assert 1 == 0 # This test is designed to fail. diff --git a/tests/test_plot_tree.py b/tests/test_plot_tree.py index d295938..c1c9b1d 100755 --- a/tests/test_plot_tree.py +++ b/tests/test_plot_tree.py @@ -1,20 +1,71 @@ from pathlib import Path import matplotlib.pyplot as plt +import pytest import pycea plot_path = Path(__file__).parent / "plots" -def test_plot_branches(tdata): - # Polar categorical with missing - pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1") - plt.savefig(plot_path / "polar_categorical_branches.png") +def test_polar_with_clades(tdata): + fig, ax = plt.subplots(dpi=600, subplot_kw={"polar": True}) + pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax) + pycea.pl.nodes(tdata, color="clade", palette="Set1", style="clade", ax=ax) + pycea.pl.annotation(tdata, keys="clade", ax=ax) + plt.savefig(plot_path / "polar_clades.png") plt.close() - # Numeric with line width + + +def test_angled_numeric_annotations(tdata): + fig, ax = plt.subplots(dpi=600) pycea.pl.branches( - tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True + tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True, ax=ax + ) + pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20, ax=ax) + pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05, ax=ax) + pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes", ax=ax) + plt.savefig(plot_path / "angled_numeric.png") + plt.close() + + +def test_matrix_annotation(tdata): + fig, ax = plt.subplots(dpi=600) + pycea.pl.tree( + tdata, + key="tree", + nodes="internal", + node_color="clade", + node_size="time", + annotation_keys=["spatial_distance"], + ax=ax, ) - plt.savefig(plot_path / "angled_numeric_branches.png") + plt.savefig(plot_path / "matrix_annotation.png") + plt.close() + + +def test_branches_invalid_input(tdata): + fig, ax = plt.subplots() + with pytest.raises(ValueError): + pycea.pl.branches(tdata, key="tree", color=["bad"] * 5) + with pytest.raises(ValueError): + pycea.pl.branches(tdata, key="tree", linewidth=["bad"] * 5) + # Can't plot polar with non-polar axis + with pytest.raises(ValueError): + pycea.pl.branches(tdata, key="tree", polar=True, ax=ax) + plt.close() + + +def test_annotation_invalid_input(tdata): + # Need to plot branches first + fig, ax = plt.subplots() + with pytest.raises(ValueError): + pycea.pl.annotation(tdata, keys="clade") + pycea.pl.branches(tdata, key="tree", ax=ax) + with pytest.raises(ValueError): + pycea.pl.annotation(tdata, keys=None, ax=ax) + with pytest.raises(ValueError): + pycea.pl.annotation(tdata, keys=False, ax=ax) + with pytest.raises(ValueError): + pycea.pl.annotation(tdata, keys="clade", label={}, ax=ax) plt.close() diff --git a/tests/test_plot_utils.py b/tests/test_plot_utils.py index ad3ce6c..b18faa2 100755 --- a/tests/test_plot_utils.py +++ b/tests/test_plot_utils.py @@ -5,7 +5,13 @@ import pytest import treedata as td -from pycea.pl._utils import _get_categorical_colors, _get_default_categorical_colors, layout_tree +from pycea.pl._utils import ( + _get_categorical_colors, + _get_categorical_markers, + _get_default_categorical_colors, + _series_to_rgb_array, + layout_tree, +) # Test layout_tree @@ -92,22 +98,17 @@ def test_palette_types(empty_tdata, category_data): palette = {"apple": "red", "banana": "yellow", "cherry": "pink"} colors = _get_categorical_colors(empty_tdata, "fruit", category_data, palette) assert colors["apple"] == "#ff0000ff" - # List - palette = ["red", "yellow", "pink"] - colors = _get_categorical_colors(empty_tdata, "fruit", category_data, palette) - assert colors["apple"] == "#ff0000ff" - - -def test_not_enough_colors(empty_tdata, category_data): + # List not enough palette = ["red", "yellow"] with pytest.warns(Warning, match="palette colors is smaller"): colors = _get_categorical_colors(empty_tdata, "fruit", category_data, palette) - assert colors["apple"] == "#ff0000ff" + assert colors["apple"] == "#ff0000ff" def test_invalid_palette(empty_tdata, category_data): - with pytest.raises(ValueError): - _get_categorical_colors(empty_tdata, "fruit", category_data, ["bad"]) + with pytest.warns(Warning, match="palette colors is smaller"): + with pytest.raises(ValueError): + _get_categorical_colors(empty_tdata, "fruit", category_data, ["bad"]) def test_pallete_in_uns(empty_tdata, category_data): @@ -117,3 +118,54 @@ def test_pallete_in_uns(empty_tdata, category_data): assert empty_tdata.uns["fruit_colors"] == list(palette_hex.values()) colors = _get_categorical_colors(empty_tdata, "fruit", category_data) assert colors == palette_hex + + +# Test _get_categorical_markers +def test_markers_types(empty_tdata, category_data): + # None + markers = _get_categorical_markers(empty_tdata, "fruit", category_data) + assert markers["apple"] == "o" + # Dict + marker_dict = {"apple": "s", "banana": "o", "cherry": "o"} + colors = _get_categorical_markers(empty_tdata, "fruit", category_data, marker_dict) + assert colors["apple"] == "s" + # List not enough + marker_list = ["s", "o"] + with pytest.warns(Warning, match="Length of markers"): + markers = _get_categorical_markers(empty_tdata, "fruit", category_data, marker_list) + assert markers["apple"] == "s" + + +def test_markers_in_uns(empty_tdata, category_data): + marker_dict = {"apple": "s", "banana": "o", "cherry": "o"} + markers = _get_categorical_markers(empty_tdata, "fruit", category_data, marker_dict) + assert "fruit_markers" in empty_tdata.uns + assert empty_tdata.uns["fruit_markers"] == list(marker_dict.values()) + markers = _get_categorical_markers(empty_tdata, "fruit", category_data) + assert markers == marker_dict + + +# Test _series_to_rgb_array +def test_series_to_rgb_discrete(category_data): + colors = {"apple": "#ff0000ff", "banana": "#ffff00ff", "cherry": "#ff69b4ff"} + rgb_array = _series_to_rgb_array(category_data, colors) + expected = np.array([[1, 0, 0], [1, 1, 0], [1, 0.41176471, 0.70588235]]) + np.testing.assert_almost_equal(rgb_array, expected, decimal=2) + # Test with missing data + category_data = pd.Series(["apple", pd.NA]) + rgb_array = _series_to_rgb_array(category_data, colors) + expected = np.array([[1, 0, 0], [0.5, 0.5, 0.5]]) + np.testing.assert_almost_equal(rgb_array, expected, decimal=2) + + +def test_series_to_rgb_numeric(): + numeric_data = pd.Series([0, 1, 2]) + colors = mcolors.ListedColormap(["red", "yellow", "blue"]) + rgb_array = _series_to_rgb_array(numeric_data, colors, vmin=0, vmax=2) + expected = np.array([[1, 0, 0], [1, 1, 0], [0, 0, 1]]) + np.testing.assert_almost_equal(rgb_array, expected, decimal=2) + # Test with missing data + numeric_data = pd.Series([0, np.nan, 2]) + rgb_array = _series_to_rgb_array(numeric_data, colors, vmin=0, vmax=2) + expected = np.array([[1, 0, 0], [0.5, 0.5, 0.5], [0, 0, 1]]) + np.testing.assert_almost_equal(rgb_array, expected, decimal=2) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100755 index 0000000..d1f936c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,66 @@ +import networkx as nx +import pandas as pd +import pytest + +from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_root + + +@pytest.fixture +def tree(): + t = nx.DiGraph() + t.add_edges_from([("A", "B"), ("A", "C"), ("B", "D"), ("C", "E")]) + nx.set_node_attributes(t, {"A": 1, "B": 2, "C": None, "D": 4, "E": 5}, "value") + nx.set_edge_attributes(t, {("A", "B"): 5, ("A", "C"): None, ("B", "D"): 3, ("C", "E"): 4}, "weight") + yield t + + +def test_get_root(tree): + # Test with an empty graph + assert get_root(nx.DiGraph()) is None + # Test with a non-empty graph + assert get_root(tree) == "A" + # Test with a single node + single_node_tree = nx.DiGraph() + single_node_tree.add_node("A") + assert get_root(single_node_tree) == "A" + + +def test_get_keyed_edge_data(tree): + result = get_keyed_edge_data(tree, "weight") + expected_keys = [("A", "B"), ("B", "D"), ("C", "E")] + expected_values = [5, 3, 4] + assert all(result[key] == value for key, value in zip(expected_keys, expected_values)) + assert ("A", "C") not in result + + +def test_get_keyed_node_data(tree): + result = get_keyed_node_data(tree, "value") + expected_keys = ["A", "B", "D", "E"] + expected_values = [1, 2, 4, 5] + assert all(result[key] == value for key, value in zip(expected_keys, expected_values)) + assert "C" not in result + + +def test_get_keyed_obs_data_valid_keys(tdata): + data, is_array = get_keyed_obs_data(tdata, ["clade", "x", "0"]) + assert not is_array + assert data.columns.tolist() == ["clade", "x", "0"] + # Automatically converts object columns to category + assert data["clade"].dtype == "category" + assert tdata.obs["clade"].dtype == "category" + + +def test_get_keyed_obs_data_array(tdata): + data, is_array = get_keyed_obs_data(tdata, ["spatial"]) + assert is_array + assert isinstance(data, pd.DataFrame) + assert data.shape[1] == 2 + data, is_array = get_keyed_obs_data(tdata, ["spatial_distance"]) + assert data.shape == (tdata.n_obs, tdata.n_obs) + + +def test_get_keyed_obs_data_invalid_keys(tdata): + with pytest.raises(ValueError): + get_keyed_obs_data(tdata, ["clade", "x", "0", "invalid_key"]) + with pytest.raises(ValueError): + get_keyed_obs_data(tdata, ["clade", "spatial_distance"])