Skip to content

Commit

Permalink
A basic version of Compartment plotting is working
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jan 21, 2024
1 parent 4b16f50 commit 2f6a20d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 6 deletions.
30 changes: 30 additions & 0 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,36 @@ def _vis(self, ax, col, dims, view, morph_plot_kwargs):
dims=dims,
col=col,
ax=ax,
type="plot",
morph_plot_kwargs=morph_plot_kwargs,
)

return ax

def _scatter(self, ax, col, dims, view, morph_plot_kwargs):
"""Scatter visualization (used for compartments)."""
assert len(view) == 1, "Scatter only deals with compartments."
branch_ind = view["branch_index"].to_numpy().item()
comp_ind = view["comp_index"].to_numpy().item()
assert not np.any(
np.isnan(self.xyzr[branch_ind][:, dims])
), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`."

comp_fraction = comp_ind / self.nseg
coords = self.xyzr[branch_ind]
interpolated_loc_x = np.interp(
comp_fraction, np.linspace(0, 1, len(coords)), coords[:, dims[0]]
)
interpolated_loc_y = np.interp(
comp_fraction, np.linspace(0, 1, len(coords)), coords[:, dims[1]]
)

ax = plot_morph(
np.asarray([[[interpolated_loc_x, interpolated_loc_y]]]),
dims=dims,
col=col,
ax=ax,
type="scatter",
morph_plot_kwargs=morph_plot_kwargs,
)

Expand Down
16 changes: 16 additions & 0 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,19 @@ def connect(self, post: "CompartmentView", synapse_type):
self.pointer.synapse_param_names += list(synapse_type.synapse_params.keys())
self.pointer.synapse_state_names += list(synapse_type.synapse_states.keys())
self.pointer.synapses.append(synapse_type)

def vis(
self,
ax=None,
col="k",
dims=(0, 1),
morph_plot_kwargs: Dict = {},
):
nodes = self.set_global_index_and_index(self.view)
return self.pointer._scatter(
ax=ax,
col=col,
dims=dims,
view=nodes,
morph_plot_kwargs=morph_plot_kwargs,
)
28 changes: 22 additions & 6 deletions jaxley/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,39 @@
import matplotlib.pyplot as plt


def plot_morph(xyzr, dims=(0, 1), col="k", ax=None, morph_plot_kwargs: Dict = None):
def plot_morph(
xyzr,
dims=(0, 1),
col="k",
ax=None,
type: str = "plot",
morph_plot_kwargs: Dict = None,
):
"""Plot morphology.
Args:
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
cols: The color for all branches except the highlighted ones.
highlight_branch_inds: Branch indices that will be highlighted.
type: Either `plot` or `scatter`.
col: The color for all branches.
"""

if ax is None:
_, ax = plt.subplots(1, 1, figsize=(3, 3))
for coords_of_branch in xyzr:
coords_to_plot = coords_of_branch[:, dims]
x_coords = coords_to_plot[:, 0]
y_coords = coords_to_plot[:, 1]

_ = ax.plot(
coords_to_plot[:, 0], coords_to_plot[:, 1], c=col, **morph_plot_kwargs
)
if type == "plot":
_ = ax.plot(
coords_to_plot[:, 0], coords_to_plot[:, 1], c=col, **morph_plot_kwargs
)
elif type == "scatter":
_ = ax.scatter(
coords_to_plot[:, 0], coords_to_plot[:, 1], c=col, **morph_plot_kwargs
)
else:
raise NotImplementedError

return ax

0 comments on commit 2f6a20d

Please sign in to comment.