diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index af5dc20d..f91071a9 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -103,6 +103,7 @@ def create_cone_frustum_mesh( radius_top: float, bottom_dome: bool = False, top_dome: bool = False, + resolution: int = 100, ) -> ndarray: """Generates mesh points for a cone frustum, with optional domes at either end. @@ -120,12 +121,14 @@ def create_cone_frustum_mesh( The dome is a hemisphere with radius `radius_bottom`. top_dome: If True, a dome is added to the top of the frustum. The dome is a hemisphere with radius `radius_top`. + resolution: defines the resolution of the mesh. + If too low (typically <10), can result in errors. + Useful too have a simpler mesh for plotting. Returns: An array of mesh points. """ - resolution = 100 t = np.linspace(0, 2 * np.pi, resolution) # Determine the total height including domes @@ -175,7 +178,9 @@ def create_cone_frustum_mesh( return np.stack([x_coords, y_coords, z_coords]) -def create_cylinder_mesh(length: float, radius: float) -> ndarray: +def create_cylinder_mesh( + length: float, radius: float, resolution: int = 100 +) -> ndarray: """Generates mesh points for a cylinder. This is used to render cylindrical compartments in 3D (and to project it to 2D) @@ -184,12 +189,14 @@ def create_cylinder_mesh(length: float, radius: float) -> ndarray: Args: length: The length of the cylinder. radius: The radius of the cylinder. + resolution: defines the resolution of the mesh. + If too low (typically <10), can result in errors. + Useful too have a simpler mesh for plotting. Returns: An array of mesh points. """ # Define cylinder - resolution = 100 t = np.linspace(0, 2 * np.pi, resolution) z_coords = np.linspace(-length / 2, length / 2, resolution) t_grid, z_coords = np.meshgrid(t, z_coords) @@ -199,7 +206,7 @@ def create_cylinder_mesh(length: float, radius: float) -> ndarray: return np.stack([x_coords, y_coords, z_coords]) -def create_sphere_mesh(radius: float) -> np.ndarray: +def create_sphere_mesh(radius: float, resolution: int = 100) -> np.ndarray: """Generates mesh points for a sphere. This is used to render spherical compartments in 3D (and to project it to 2D) @@ -207,11 +214,13 @@ def create_sphere_mesh(radius: float) -> np.ndarray: Args: radius: The radius of the sphere. + resolution: defines the resolution of the mesh. + If too low (typically <10), can result in errors. + Useful too have a simpler mesh for plotting. Returns: An array of mesh points. """ - resolution = 100 phi = np.linspace(0, np.pi, resolution) theta = np.linspace(0, 2 * np.pi, resolution) @@ -302,8 +311,9 @@ def plot_comps( ax: Optional[Axes] = None, comp_plot_kwargs: Dict = {}, true_comp_length: bool = True, + resolution: int = 100, ) -> Axes: - """Plot compartmentalized neural mrophology. + """Plot compartmentalized neural morphology. Plots the projection of the cylindrical compartments. @@ -320,6 +330,9 @@ def plot_comps( start and end point of the neurite. This can lead to overlapping and miss-aligned cylinders. Setting this False will use the straight-line distance instead for nicer plots. + resolution: defines the resolution of the mesh. + If too low (typically <10), can result in errors. + Useful too have a simpler mesh for plotting. Returns: Plot of the compartmentalized morphology. @@ -340,7 +353,7 @@ def plot_comps( radius = xyzr[:, -1] center = xyzr[0, :3] if len(dims) == 3: - xyz = create_sphere_mesh(radius) + xyz = create_sphere_mesh(radius, resolution) ax = plot_mesh( xyz, np.array([0, 0, 1]), @@ -368,7 +381,7 @@ def plot_comps( center = comp[["x", "y", "z"]] radius = comp["radius"] length = comp["length"] if true_comp_length else l - xyz = create_cylinder_mesh(length, radius) + xyz = create_cylinder_mesh(length, radius, resolution) ax = plot_mesh( xyz, axis, @@ -386,6 +399,7 @@ def plot_morph( dims: Tuple[int] = (0, 1), col: str = "k", ax: Optional[Axes] = None, + resolution: int = 100, morph_plot_kwargs: Dict = {}, ) -> Axes: """Plot the detailed morphology. @@ -404,6 +418,10 @@ def plot_morph( ax: The matplotlib axis to plot on. morph_plot_kwargs: The plot kwargs for plt.fill. + resolution: defines the resolution of the mesh. + If too low (typically <10), can result in errors. + Useful too have a simpler mesh for plotting. + Returns: Plot of the detailed morphology.""" if ax is None: @@ -424,7 +442,12 @@ def plot_morph( dxyz = xyzr2[:3] - xyzr1[:3] length = np.sqrt(np.sum(dxyz**2)) points = create_cone_frustum_mesh( - length, xyzr1[-1], xyzr2[-1], bottom_dome=True, top_dome=True + length, + xyzr1[-1], + xyzr2[-1], + bottom_dome=True, + top_dome=True, + resolution=resolution, ) plot_mesh( points, @@ -437,7 +460,12 @@ def plot_morph( ) else: points = create_cone_frustum_mesh( - 0, xyzr[:, -1], xyzr[:, -1], bottom_dome=True, top_dome=True + 0, + xyzr[:, -1], + xyzr[:, -1], + bottom_dome=True, + top_dome=True, + resolution=resolution, ) plot_mesh( points, diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index c7c34022..d0be7190 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -19,31 +19,58 @@ from jaxley.synapses import IonotropicSynapse -def test_cell(): - dirname = os.path.dirname(__file__) - fname = os.path.join(dirname, "swc_files", "morph.swc") - cell = jx.read_swc(fname, nseg=4) +@pytest.fixture(scope="module") +def comp() -> jx.Compartment: + comp = jx.Compartment() + comp.compute_xyz() + return comp + + +@pytest.fixture(scope="module") +def branch(comp) -> jx.Branch: + branch = jx.Branch(comp, 4) + branch.compute_xyz() + return branch + +@pytest.fixture(scope="module") +def cell(branch) -> jx.Cell: + cell = jx.Cell(branch, [-1, 0, 0, 1, 1]) + cell.compute_xyz() + return cell + + +@pytest.fixture(scope="module") +def simple_net(cell) -> jx.Network: + net = jx.Network([cell] * 4) + net.compute_xyz() + return net + + +@pytest.fixture(scope="module") +def morph_cell() -> jx.Cell: + morph_cell = jx.read_swc( + os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"), + nseg=1, + ) + return morph_cell + + +def test_cell(morph_cell): # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) - ax = cell.vis(ax=ax) - ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r") - ax = cell.branch(1).loc(0.9).vis(ax=ax, col="b") + ax = morph_cell.vis(ax=ax) + ax = morph_cell.branch([0, 1, 2]).vis(ax=ax, col="r") + ax = morph_cell.branch(1).loc(0.9).vis(ax=ax, col="b") # Plot 2. - cell.branch(0).add_to_group("soma") - cell.branch(1).add_to_group("soma") - ax = cell.soma.vis() - + morph_cell.branch(0).add_to_group("soma") + morph_cell.branch(1).add_to_group("soma") + ax = morph_cell.soma.vis() -def test_network(): - dirname = os.path.dirname(__file__) - fname = os.path.join(dirname, "swc_files", "morph.swc") - cell1 = jx.read_swc(fname, nseg=4) - cell2 = jx.read_swc(fname, nseg=4) - cell3 = jx.read_swc(fname, nseg=4) - net = jx.Network([cell1, cell2, cell3]) +def test_network(morph_cell): + net = jx.Network([morph_cell, morph_cell, morph_cell]) connect( net.cell(0).branch(0).loc(0.0), net.cell(1).branch(0).loc(0.0), @@ -81,11 +108,7 @@ def test_network(): ax = net.excitatory.vis() -def test_vis_networks_built_from_scartch(): - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) - +def test_vis_networks_built_from_scratch(comp, branch, cell): net = jx.Network([cell, cell]) connect( net.cell(0).branch(0).loc(0.0), @@ -110,25 +133,15 @@ def test_vis_networks_built_from_scartch(): # Plot 3. _, ax = plt.subplots(1, 1, figsize=(3, 3)) - comp.compute_xyz() ax = comp.vis(ax=ax) # Plot 4. _, ax = plt.subplots(1, 1, figsize=(3, 3)) - branch.compute_xyz() ax = branch.vis(ax=ax) -def test_mixed_network(): - dirname = os.path.dirname(__file__) - fname = os.path.join(dirname, "swc_files", "morph.swc") - cell1 = jx.read_swc(fname, nseg=4) - - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell2 = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) - - net = jx.Network([cell1, cell2]) +def test_mixed_network(morph_cell, cell): + net = jx.Network([morph_cell, cell]) connect( net.cell(0).branch(0).loc(0.0), net.cell(1).branch(0).loc(0.0), @@ -145,9 +158,9 @@ def test_mixed_network(): net.cell(1).move(0, -800) net.rotate(180) - before_xyzrs = deepcopy(net.xyzr[len(cell1.xyzr) :]) + before_xyzrs = deepcopy(net.xyzr[len(morph_cell.xyzr) :]) net.cell(1).rotate(90) - after_xyzrs = net.xyzr[len(cell1.xyzr) :] + after_xyzrs = net.xyzr[len(morph_cell.xyzr) :] # Test that rotation worked as expected. for b, a in zip(before_xyzrs, after_xyzrs): assert np.allclose(b[:, 0], -a[:, 1], atol=1e-6) @@ -156,34 +169,24 @@ def test_mixed_network(): _ = net.vis(detail="full") -def test_volume_plotting(): - comp = jx.Compartment() - comp.compute_xyz() - branch = jx.Branch(comp, 4) - branch.compute_xyz() - cell = jx.Cell([branch] * 3, [-1, 0, 0]) - cell.compute_xyz() - net = jx.Network([cell] * 4) - net.compute_xyz() - - morph_cell = jx.read_swc( - os.path.join(os.path.dirname(__file__), "swc_files", "morph.swc"), - nseg=1, - ) - +def test_volume_plotting_2d(comp, branch, cell, simple_net, morph_cell): fig, ax = plt.subplots() - for module in [comp, branch, cell, net, morph_cell]: - module.vis(type="comp", ax=ax) + for module in [comp, branch, cell, simple_net, morph_cell]: + module.vis(type="comp", ax=ax, morph_plot_kwargs={"resolution": 6}) plt.close(fig) + +def test_volume_plotting_3d(comp, branch, cell, simple_net, morph_cell): # test 3D plotting - for module in [comp, branch, cell, net, morph_cell]: - module.vis(type="comp", dims=[0, 1, 2]) + for module in [comp, branch, cell, simple_net, morph_cell]: + module.vis(type="comp", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6}) plt.close() + +def test_morph_plotting(morph_cell): # test morph plotting (does not work if no radii in xyzr) - morph_cell.vis(type="morph") + morph_cell.vis(type="morph", morph_plot_kwargs={"resolution": 6}) morph_cell.branch(1).vis( - type="morph", dims=[0, 1, 2] + type="morph", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6} ) # plotting whole thing takes too long plt.close() diff --git a/tests/test_swc.py b/tests/test_swc.py index 3da0baec..09c2b44c 100644 --- a/tests/test_swc.py +++ b/tests/test_swc.py @@ -180,7 +180,7 @@ def test_swc_voltages(file): for i in trunk_inds + tuft_inds + basal_inds: cell.branch(i).loc(0.05).record() - voltages_jaxley = jx.integrate(cell, delta_t=dt) + voltages_jaxley = jx.integrate(cell, delta_t=dt, voltage_solver="jax.sparse") ################### NEURON ################# stim = h.IClamp(h.soma[0](0.1))