Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 8, 2023
1 parent 43adda6 commit 089cd07
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
10 changes: 6 additions & 4 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,11 +664,13 @@ def vis(

def _vis(self, ax, col, dims, view, morph_plot_kwargs):
branches_inds = view["branch_index"].to_numpy()
coords = [self.xyzr[branch_ind] for branch_ind in branches_inds]
coords = []
for branch_ind in branches_inds:
assert not np.any(
np.isnan(self.xyzr[branch_ind][:, dims])
), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`."
coords.append(self.xyzr[branch_ind])

assert not np.any(
np.isnan(np.asarray(self.xyzr)[:, :, dims])
), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`."
ax = plot_morph(
coords,
dims=dims,
Expand Down
14 changes: 7 additions & 7 deletions tests/test_plotting_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ def test_cell():

# Plot 1.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
ax = cell.vis(detail="full", ax=ax)
ax = cell.branch([0, 1, 2]).vis(detail="full", ax=ax, col="r")
ax = cell.vis(ax=ax)
ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r")

# Plot 2.
cell.branch(0).add_to_group("soma")
cell.branch(1).add_to_group("soma")
ax = cell.soma.vis(detail="full")
ax = cell.soma.vis()


def test_network():
Expand All @@ -45,9 +45,9 @@ def test_network():

# Plot 1.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
ax = net.cell([0, 1]).vis(detail="full", ax=ax)
ax = net.cell(2).vis(detail="full", ax=ax, col="r")
ax = net.cell(0).branch(np.arange(10).tolist()).vis(detail="full", ax=ax, col="b")
ax = net.cell([0, 1]).vis(ax=ax)
ax = net.cell(2).vis(ax=ax, col="r")
ax = net.cell(0).branch(np.arange(10).tolist()).vis(ax=ax, col="b")

# Plot 2.
ax = net.vis(detail="full")
Expand All @@ -61,7 +61,7 @@ def test_network():
# Plot 5.
net.cell(0).add_to_group("excitatory")
net.cell(1).add_to_group("excitatory")
ax = net.excitatory.vis(detail="full")
ax = net.excitatory.vis()


def test_vis_networks_built_from_scartch():
Expand Down

0 comments on commit 089cd07

Please sign in to comment.