Skip to content

Commit

Permalink
Clean up and test
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jan 21, 2024
1 parent 2f6a20d commit 1f56fd6
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
20 changes: 13 additions & 7 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_compute_index_of_child,
_compute_num_children,
compute_levels,
loc_of_index,
)
from jaxley.utils.plot_utils import plot_morph

Expand Down Expand Up @@ -648,25 +649,30 @@ def _vis(self, ax, col, dims, view, morph_plot_kwargs):
return ax

def _scatter(self, ax, col, dims, view, morph_plot_kwargs):
"""Scatter visualization (used for compartments)."""
"""Scatter visualization (used only for compartments)."""
assert len(view) == 1, "Scatter only deals with compartments."
branch_ind = view["branch_index"].to_numpy().item()
comp_ind = view["comp_index"].to_numpy().item()
assert not np.any(
np.isnan(self.xyzr[branch_ind][:, dims])
), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`."

comp_fraction = comp_ind / self.nseg
comp_fraction = loc_of_index(comp_ind, self.nseg)
coords = self.xyzr[branch_ind]
interpolated_loc_x = np.interp(
comp_fraction, np.linspace(0, 1, len(coords)), coords[:, dims[0]]

# Perform a linear interpolation between coordinates to get the location.
interp_loc_x = np.interp(
comp_fraction, np.linspace(0, 1, len(coords)), coords[:, 0]
)
interp_loc_y = np.interp(
comp_fraction, np.linspace(0, 1, len(coords)), coords[:, 1]
)
interpolated_loc_y = np.interp(
comp_fraction, np.linspace(0, 1, len(coords)), coords[:, dims[1]]
interp_loc_z = np.interp(
comp_fraction, np.linspace(0, 1, len(coords)), coords[:, 2]
)

ax = plot_morph(
np.asarray([[[interpolated_loc_x, interpolated_loc_y]]]),
np.asarray([[[interp_loc_x, interp_loc_y, interp_loc_z]]]),
dims=dims,
col=col,
ax=ax,
Expand Down
3 changes: 1 addition & 2 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,10 @@ def vis(
dims=(0, 1),
morph_plot_kwargs: Dict = {},
):
nodes = self.set_global_index_and_index(self.view)
return self.pointer._scatter(
ax=ax,
col=col,
dims=dims,
view=nodes,
view=self.view,
morph_plot_kwargs=morph_plot_kwargs,
)
1 change: 1 addition & 0 deletions tests/test_plotting_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_cell():
_, ax = plt.subplots(1, 1, figsize=(3, 3))
ax = cell.vis(ax=ax)
ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r")
ax = cell.branch(1).comp(0.9).vis(ax=ax, col="b")

# Plot 2.
cell.branch(0).add_to_group("soma")
Expand Down

0 comments on commit 1f56fd6

Please sign in to comment.