Skip to content

Commit

Permalink
kwargs to modify the plot
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 8, 2023
1 parent 583d75d commit 7f6282b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 14 deletions.
19 changes: 16 additions & 3 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def vis(
ax=None,
col: str = "k",
dims: Tuple[int] = (0, 1),
morph_plot_kwargs: Dict = {},
) -> None:
"""Visualize the module.
Expand All @@ -645,10 +646,15 @@ def vis(
two of them.
"""
return self._vis(
detail=detail, dims=dims, col=col, ax=ax, view=self.nodes
detail=detail,
dims=dims,
col=col,
ax=ax,
view=self.nodes,
morph_plot_kwargs=morph_plot_kwargs,
)

def _vis(self, detail, ax, col, dims, view):
def _vis(self, detail, ax, col, dims, view, morph_plot_kwargs):
branches_inds = view["branch_index"].to_numpy()
coords = [self.xyzr[branch_ind] for branch_ind in branches_inds]

Expand All @@ -659,6 +665,7 @@ def _vis(self, detail, ax, col, dims, view):
dims=dims,
col=col,
ax=ax,
morph_plot_kwargs=morph_plot_kwargs,
)
# elif detail == "sticks":
# fig, ax = plot_morph(
Expand Down Expand Up @@ -770,10 +777,16 @@ def vis(
ax=None,
col="k",
dims=(0, 1),
morph_plot_kwargs: Dict = {},
):
nodes = self.set_global_index_and_index(self.view)
return self.pointer._vis(
detail=detail, ax=ax, col=col, dims=dims, view=nodes
detail=detail,
ax=ax,
col=col,
dims=dims,
view=nodes,
morph_plot_kwargs=morph_plot_kwargs,
)

def adjust_view(self, key: str, index: float):
Expand Down
28 changes: 24 additions & 4 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,12 @@ def vis(
detail: str = "full",
ax=None,
col="k",
synapse_col="b",
dims=(0, 1),
layers: Optional[List] = None,
morph_plot_kwargs: Dict = {},
synapse_plot_kwargs: Dict = {},
synapse_scatter_kwargs: Dict = {},
) -> None:
"""Visualize the module.
Expand All @@ -349,9 +353,14 @@ def vis(
nx.draw(graph, with_labels=True)
else:
ax = self._vis(
detail=detail, dims=dims, col=col, ax=ax, view=self.nodes
detail=detail,
dims=dims,
col=col,
ax=ax,
view=self.nodes,
morph_plot_kwargs=morph_plot_kwargs,
)

pre_locs = self.syn_edges["pre_locs"].to_numpy()
post_locs = self.syn_edges["post_locs"].to_numpy()
pre_branch = self.syn_edges["pre_branch_index"].to_numpy()
Expand All @@ -371,8 +380,19 @@ def vis(
middle_ind = int((len(post_coord) - 1) * post_loc)
post_coord = post_coord[middle_ind]
coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T
ax.plot(coords[0], coords[1], linewidth=3.0, c="b")
ax.scatter(post_coord[dims_np[0]], post_coord[dims_np[1]], c="b")
ax.plot(
coords[0],
coords[1],
linewidth=3.0,
c=synapse_col,
**synapse_plot_kwargs,
)
ax.scatter(
post_coord[dims_np[0]],
post_coord[dims_np[1]],
c=synapse_col,
**synapse_scatter_kwargs,
)

return ax

Expand Down
13 changes: 6 additions & 7 deletions jaxley/utils/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -84,12 +86,7 @@ def plot_morph(
return fig, ax


def plot_swc(
xyzr,
dims=(0, 1),
col="k",
ax=None,
):
def plot_swc(xyzr, dims=(0, 1), col="k", ax=None, morph_plot_kwargs: Dict = None):
"""Plot morphology given an SWC file.
Args:
Expand All @@ -104,6 +101,8 @@ def plot_swc(
for coords_of_branch in xyzr:
coords_to_plot = coords_of_branch[:, dims]

_ = ax.plot(coords_to_plot[:, 0], coords_to_plot[:, 1], c=col)
_ = ax.plot(
coords_to_plot[:, 0], coords_to_plot[:, 1], c=col, **morph_plot_kwargs
)

return ax

0 comments on commit 7f6282b

Please sign in to comment.