diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 278a7e68..97285f70 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -664,11 +664,13 @@ def vis( def _vis(self, ax, col, dims, view, morph_plot_kwargs): branches_inds = view["branch_index"].to_numpy() - coords = [self.xyzr[branch_ind] for branch_ind in branches_inds] + coords = [] + for branch_ind in branches_inds: + assert not np.any( + np.isnan(self.xyzr[branch_ind][:, dims]) + ), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`." + coords.append(self.xyzr[branch_ind]) - assert not np.any( - np.isnan(np.asarray(self.xyzr)[:, :, dims]) - ), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`." ax = plot_morph( coords, dims=dims, diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index f37dd383..d0dbe0c0 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -19,13 +19,13 @@ def test_cell(): # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) - ax = cell.vis(detail="full", ax=ax) - ax = cell.branch([0, 1, 2]).vis(detail="full", ax=ax, col="r") + ax = cell.vis(ax=ax) + ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r") # Plot 2. cell.branch(0).add_to_group("soma") cell.branch(1).add_to_group("soma") - ax = cell.soma.vis(detail="full") + ax = cell.soma.vis() def test_network(): @@ -45,9 +45,9 @@ def test_network(): # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) - ax = net.cell([0, 1]).vis(detail="full", ax=ax) - ax = net.cell(2).vis(detail="full", ax=ax, col="r") - ax = net.cell(0).branch(np.arange(10).tolist()).vis(detail="full", ax=ax, col="b") + ax = net.cell([0, 1]).vis(ax=ax) + ax = net.cell(2).vis(ax=ax, col="r") + ax = net.cell(0).branch(np.arange(10).tolist()).vis(ax=ax, col="b") # Plot 2. ax = net.vis(detail="full") @@ -61,7 +61,7 @@ def test_network(): # Plot 5. net.cell(0).add_to_group("excitatory") net.cell(1).add_to_group("excitatory") - ax = net.excitatory.vis(detail="full") + ax = net.excitatory.vis() def test_vis_networks_built_from_scartch():