Skip to content

Commit

Permalink
fix: make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 3, 2024
1 parent 0a608a5 commit 40b2ba6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
5 changes: 3 additions & 2 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2116,10 +2116,11 @@ def vis(
type: The type of plot. One of ["line", "scatter", "comp", "morph"].
kwargs: Keyword arguments passed to the plotting function.
"""
res = 100 if "resolution" not in kwargs else kwargs.pop("resolution")
if "comp" in type.lower():
return plot_comps(self, dims=dims, ax=ax, color=color, **kwargs)
return plot_comps(self, dims=dims, ax=ax, color=color, resolution=res, **kwargs)
if "morph" in type.lower():
return plot_morph(self, dims=dims, ax=ax, color=color, **kwargs)
return plot_morph(self, dims=dims, ax=ax, color=color, resolution=res, **kwargs)

assert not np.any(
[np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]
Expand Down
1 change: 1 addition & 0 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def vis(
cell = nodes.loc[comp, "global_cell_index"]
branch_xyz = self.xyzr[branch]

xyz_loc = branch_xyz
if detail == "point":
xyz_loc = np.mean(self.cell(cell).xyzr[0], axis=0)
elif len(branch_xyz) == 2:
Expand Down
28 changes: 16 additions & 12 deletions tests/test_plotting_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def test_cell(SimpleMorphCell):
# Plot 1.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
ax = cell.vis(ax=ax)
ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r")
ax = cell.branch(1).loc(0.9).vis(ax=ax, col="b")
ax = cell.branch([0, 1, 2]).vis(ax=ax, color="r")
ax = cell.branch(1).loc(0.9).vis(ax=ax, color="b")

# Plot 2.
cell.branch(0).add_to_group("soma")
Expand Down Expand Up @@ -59,9 +59,9 @@ def test_network(SimpleMorphCell):
# Plot 1.
_, ax = plt.subplots(1, 1, figsize=(3, 3))
ax = net.cell([0, 1]).vis(ax=ax)
ax = net.cell(2).vis(ax=ax, col="r", type="line")
ax = net.cell(2).vis(ax=ax, col="r", type="scatter")
ax = net.cell(0).branch(np.arange(10).tolist()).vis(ax=ax, col="b")
ax = net.cell(2).vis(ax=ax, color="r", type="line")
ax = net.cell(2).vis(ax=ax, color="r", type="scatter")
ax = net.cell(0).branch(np.arange(10).tolist()).vis(ax=ax, color="b")

# Plot 2.
ax = net.vis(detail="full", type="line")
Expand All @@ -71,10 +71,12 @@ def test_network(SimpleMorphCell):
net.vis(detail="point")

# Plot 4.
net.vis(detail="point", layers=[2, 1])
net.arrange_in_layers([2, 1])
net.vis(detail="point")

# Plot 5.
net.vis(detail="full", layers=[2, 1])
net.arrange_in_layers([2, 1])
net.vis(detail="full")

# Plot 5.
net.cell(0).add_to_group("excitatory")
Expand Down Expand Up @@ -174,18 +176,20 @@ def test_volume_plotting(
morph_cell = SimpleMorphCell(fname, ncomp=1)

fig, ax = plt.subplots()
for module in [comp, branch, cell, net, morph_cell]:
module.vis(type="comp", ax=ax, morph_plot_kwargs={"resolution": 6})
for module in [comp, branch, cell, morph_cell]:
module.vis(type="comp", ax=ax, resolution=6)
net.vis(type="comp", ax=ax, cell_plot_kwargs={"resolution": 6})
plt.close(fig)

# test 3D plotting
for module in [comp, branch, cell, net, morph_cell]:
module.vis(type="comp", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6})
for module in [comp, branch, cell, morph_cell]:
module.vis(type="comp", dims=[0, 1, 2], resolution=6)
net.vis(type="comp", dims=[0, 1, 2], cell_plot_kwargs={"resolution": 6})
plt.close()

# test morph plotting (does not work if no radii in xyzr)
morph_cell.branch(1).vis(type="morph")
morph_cell.branch(1).vis(
type="morph", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6}
type="morph", dims=[0, 1, 2], resolution=6
) # plotting whole thing takes too long
plt.close()

0 comments on commit 40b2ba6

Please sign in to comment.