From a6415d09cbc655efae4dcfe9005cee4c957ae7ef Mon Sep 17 00:00:00 2001 From: colganwi Date: Mon, 20 May 2024 09:56:08 -0400 Subject: [PATCH] annotation borders --- src/pycea/pl/plot_tree.py | 33 ++++++++++++++++++++++++--------- tests/test_plot_tree.py | 18 +++++++++--------- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index 606d897..bd75a97 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -29,7 +29,7 @@ ) def branches( tdata: td.TreeData, - key: str = None, + keys: str | Sequence[str] = None, polar: bool = False, extend_branches: bool = False, angled_branches: bool = False, @@ -49,8 +49,8 @@ def branches( ---------- tdata The `treedata.TreeData` object. - key - The `obst` key of the tree to plot. + keys + The `obst` key or keys of the trees to plot. polar Whether to plot the tree in polar coordinates. extend_branches @@ -80,8 +80,12 @@ def branches( warnings.warn("Polar setting of axes does not match requested type. Creating new axes.", stacklevel=2) fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None) kwargs = kwargs if kwargs else {} - if not key: + if not keys: key = next(iter(tdata.obst.keys())) + elif isinstance(keys, str): + key = keys + else: + raise ValueError("Passing a list of keys not implemented. Please pass a single key.") tree = tdata.obst[key] # Get layout node_coords, branch_coords, leaves, depth = layout_tree( @@ -308,6 +312,7 @@ def annotation( width: int | float = 0.05, gap: int | float = 0.03, label: bool | str | Sequence[str] = True, + border_width: int | float = 0, cmap: str | mcolors.Colormap = None, palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None, vmax: int | float | None = None, @@ -332,6 +337,8 @@ def annotation( label Annotation labels. If `True`, the keys are used as labels. If a string or a sequence of strings, the strings are used as labels. + border_width + The width of the border around the annotation bar. {common_plot_args} na_color The color to use for annotations with missing data. @@ -398,9 +405,17 @@ def annotation( if attrs["polar"]: ax.pcolormesh(lons, lats, rgb_array.swapaxes(0, 1), zorder=2, **kwargs) ax.set_ylim(-attrs["depth"] * 0.05, end_lat) + # ax.plot([0, np.pi, np.pi, 0, 0], [start_lat, start_lat, end_lat, end_lat, start_lat], color="black") else: ax.pcolormesh(lats, lons, rgb_array, zorder=2, **kwargs) - ax.set_xlim(-attrs["depth"] * 0.05, end_lat) + ax.set_xlim(-attrs["depth"] * 0.05, end_lat + attrs["depth"] * 0.05) + # Add border + ax.plot( + [lats[0], lats[0], lats[-1], lats[-1], lats[0]], + [lons[0], lons[-1], lons[-1], lons[0], lons[0]], + color="black", + linewidth=border_width, + ) # Add labels if labels and len(labels) > 0: labels_lats = np.linspace(start_lat, end_lat, len(labels) + 1) @@ -427,7 +442,7 @@ def annotation( ) def tree( tdata: td.TreeData, - key: str = None, + keys: str | Sequence[str] = None, nodes: str | Sequence[str] = None, annotation_keys: str | Sequence[str] = None, polar: bool = False, @@ -451,8 +466,8 @@ def tree( ---------- tdata The TreeData object. - key - The `obst` key of the tree to plot. + keys + The `obst` key or keys of the trees to plot. nodes Either "all", "leaves", "internal", or a list of nodes to plot. annotation_keys @@ -484,7 +499,7 @@ def tree( # Plot branches ax = _branches( tdata, - key=key, + keys=keys, polar=polar, extend_branches=extend_branches, angled_branches=angled_branches, diff --git a/tests/test_plot_tree.py b/tests/test_plot_tree.py index fc5c2fd..0aeec21 100755 --- a/tests/test_plot_tree.py +++ b/tests/test_plot_tree.py @@ -10,7 +10,7 @@ def test_polar_with_clades(tdata): fig, ax = plt.subplots(dpi=300, subplot_kw={"polar": True}) - pycea.pl.branches(tdata, key="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax) + pycea.pl.branches(tdata, keys="tree", polar=True, color="clade", palette="Set1", na_color="black", ax=ax) pycea.pl.nodes(tdata, color="clade", palette="Set1", style="clade", ax=ax) pycea.pl.annotation(tdata, keys="clade", ax=ax) plt.savefig(plot_path / "polar_clades.png") @@ -19,11 +19,11 @@ def test_polar_with_clades(tdata): def test_angled_numeric_annotations(tdata): pycea.pl.branches( - tdata, key="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True + tdata, keys="tree", polar=False, color="length", cmap="hsv", linewidth="length", angled_branches=True ) pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20) - pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05) - pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes") + pycea.pl.annotation(tdata, keys=["x", "y"], cmap="magma", width=0.1, gap=0.05, border_width=2) + pycea.pl.annotation(tdata, keys=["0", "1", "2", "3", "4", "5"], label="genes", border_width=2) plt.savefig(plot_path / "angled_numeric.png") plt.close() @@ -32,7 +32,7 @@ def test_matrix_annotation(tdata): fig, ax = plt.subplots(dpi=300) pycea.pl.tree( tdata, - key="tree", + keys="tree", nodes="internal", node_color="clade", node_size="time", @@ -46,12 +46,12 @@ def test_matrix_annotation(tdata): def test_branches_bad_input(tdata): fig, ax = plt.subplots() with pytest.raises(ValueError): - pycea.pl.branches(tdata, key="tree", color=["bad"] * 5) + pycea.pl.branches(tdata, keys="tree", color=["bad"] * 5) with pytest.raises(ValueError): - pycea.pl.branches(tdata, key="tree", linewidth=["bad"] * 5) + pycea.pl.branches(tdata, keys="tree", linewidth=["bad"] * 5) # Warns about polar with pytest.warns(match="Polar"): - pycea.pl.branches(tdata, key="tree", polar=True, ax=ax) + pycea.pl.branches(tdata, keys="tree", polar=True, ax=ax) plt.close() @@ -60,7 +60,7 @@ def test_annotation_bad_input(tdata): fig, ax = plt.subplots() with pytest.raises(ValueError): pycea.pl.annotation(tdata, keys="clade") - pycea.pl.branches(tdata, key="tree", ax=ax) + pycea.pl.branches(tdata, keys="tree", ax=ax) with pytest.raises(ValueError): pycea.pl.annotation(tdata, keys=None, ax=ax) with pytest.raises(ValueError):