diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index e79bee3e..76d406b2 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -632,6 +632,7 @@ def vis( ax=None, col: str = "k", dims: Tuple[int] = (0, 1), + morph_plot_kwargs: Dict = {}, ) -> None: """Visualize the module. @@ -645,10 +646,15 @@ def vis( two of them. """ return self._vis( - detail=detail, dims=dims, col=col, ax=ax, view=self.nodes + detail=detail, + dims=dims, + col=col, + ax=ax, + view=self.nodes, + morph_plot_kwargs=morph_plot_kwargs, ) - def _vis(self, detail, ax, col, dims, view): + def _vis(self, detail, ax, col, dims, view, morph_plot_kwargs): branches_inds = view["branch_index"].to_numpy() coords = [self.xyzr[branch_ind] for branch_ind in branches_inds] @@ -659,6 +665,7 @@ def _vis(self, detail, ax, col, dims, view): dims=dims, col=col, ax=ax, + morph_plot_kwargs=morph_plot_kwargs, ) # elif detail == "sticks": # fig, ax = plot_morph( @@ -770,10 +777,16 @@ def vis( ax=None, col="k", dims=(0, 1), + morph_plot_kwargs: Dict = {}, ): nodes = self.set_global_index_and_index(self.view) return self.pointer._vis( - detail=detail, ax=ax, col=col, dims=dims, view=nodes + detail=detail, + ax=ax, + col=col, + dims=dims, + view=nodes, + morph_plot_kwargs=morph_plot_kwargs, ) def adjust_view(self, key: str, index: float): diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 55320a71..a0245b0d 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -321,8 +321,12 @@ def vis( detail: str = "full", ax=None, col="k", + synapse_col="b", dims=(0, 1), layers: Optional[List] = None, + morph_plot_kwargs: Dict = {}, + synapse_plot_kwargs: Dict = {}, + synapse_scatter_kwargs: Dict = {}, ) -> None: """Visualize the module. @@ -349,9 +353,14 @@ def vis( nx.draw(graph, with_labels=True) else: ax = self._vis( - detail=detail, dims=dims, col=col, ax=ax, view=self.nodes + detail=detail, + dims=dims, + col=col, + ax=ax, + view=self.nodes, + morph_plot_kwargs=morph_plot_kwargs, ) - + pre_locs = self.syn_edges["pre_locs"].to_numpy() post_locs = self.syn_edges["post_locs"].to_numpy() pre_branch = self.syn_edges["pre_branch_index"].to_numpy() @@ -371,8 +380,19 @@ def vis( middle_ind = int((len(post_coord) - 1) * post_loc) post_coord = post_coord[middle_ind] coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T - ax.plot(coords[0], coords[1], linewidth=3.0, c="b") - ax.scatter(post_coord[dims_np[0]], post_coord[dims_np[1]], c="b") + ax.plot( + coords[0], + coords[1], + linewidth=3.0, + c=synapse_col, + **synapse_plot_kwargs, + ) + ax.scatter( + post_coord[dims_np[0]], + post_coord[dims_np[1]], + c=synapse_col, + **synapse_scatter_kwargs, + ) return ax diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index d25ed2ca..4edd5a7e 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -1,3 +1,5 @@ +from typing import Dict + import matplotlib.pyplot as plt import numpy as np @@ -84,12 +86,7 @@ def plot_morph( return fig, ax -def plot_swc( - xyzr, - dims=(0, 1), - col="k", - ax=None, -): +def plot_swc(xyzr, dims=(0, 1), col="k", ax=None, morph_plot_kwargs: Dict = None): """Plot morphology given an SWC file. Args: @@ -104,6 +101,8 @@ def plot_swc( for coords_of_branch in xyzr: coords_to_plot = coords_of_branch[:, dims] - _ = ax.plot(coords_to_plot[:, 0], coords_to_plot[:, 1], c=col) + _ = ax.plot( + coords_to_plot[:, 0], coords_to_plot[:, 1], c=col, **morph_plot_kwargs + ) return ax