diff --git a/CHANGELOG.md b/CHANGELOG.md index c678ab75..dbdc3d4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,15 @@ net.record("i_IonotropicSynapse") - Regression tests can be done locally by running `NEW_BASELINE=1 pytest -m regression` i.e. on `main` and then `pytest -m regression` on `feature`, which will produce a test report (printed to the console and saved to .txt). - If a PR introduces new baseline tests or reduces runtimes, then a new baseline can be created by commenting "/update_regression_baselines" on the PR. +- refactor plotting (#539, @jnsbck). + - rm networkx dependency + - add `Network.arrange_in_layers` + - disentangle moving of cells and plotting in `Network.vis`. To get the same as `net.vis(layers=[3,3])`, one now has to do: +```python +net.arrange_in_layers([3,3]) +net.vis() +``` + # 0.5.0 ### API changes diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 5fa50cc0..59ccbd5c 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -2086,10 +2086,10 @@ def _get_external_input( def vis( self, ax: Optional[Axes] = None, - col: str = "k", + color: str = "k", dims: Tuple[int] = (0, 1), type: str = "line", - morph_plot_kwargs: Dict = {}, + **kwargs, ) -> Axes: """Visualize the module. @@ -2102,7 +2102,7 @@ def vis( - `scatter`: All traced points, are plotted as scatter points. - `comp`: Plots the compartmentalized morphology, including radius and shape. (shows the true compartment lengths per default, but this can - be changed via the `morph_plot_kwargs`, for details see + be changed via the `kwargs`, for details see `jaxley.utils.plot_utils.plot_comps`). - `morph`: Reconstructs the 3D shape of the traced morphology. For details see `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies @@ -2110,16 +2110,21 @@ def vis( Args: ax: An axis into which to plot. - col: The color for all branches. + color: The color for all branches. dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them. type: The type of plot. One of ["line", "scatter", "comp", "morph"]. - morph_plot_kwargs: Keyword arguments passed to the plotting function. + kwargs: Keyword arguments passed to the plotting function. """ + res = 100 if "resolution" not in kwargs else kwargs.pop("resolution") if "comp" in type.lower(): - return plot_comps(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs) + return plot_comps( + self, dims=dims, ax=ax, color=color, resolution=res, **kwargs + ) if "morph" in type.lower(): - return plot_morph(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs) + return plot_morph( + self, dims=dims, ax=ax, color=color, resolution=res, **kwargs + ) assert not np.any( [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr] @@ -2128,10 +2133,10 @@ def vis( ax = plot_graph( self.xyzr, dims=dims, - col=col, + color=color, ax=ax, type=type, - morph_plot_kwargs=morph_plot_kwargs, + **kwargs, ) return ax diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 0a65c58c..a944fa93 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -4,12 +4,13 @@ import itertools from copy import deepcopy from typing import Dict, List, Optional, Tuple, Union +from warnings import warn import jax.numpy as jnp -import networkx as nx import numpy as np import pandas as pd from jax import vmap +from matplotlib import pyplot as plt from matplotlib.axes import Axes from jaxley.modules.base import Module @@ -383,171 +384,127 @@ def _synapse_currents( return states, (syn_voltage_terms, syn_constant_terms) + def arrange_in_layers( + self, + layers: List[int], + within_layer_offset: float = 500.0, + between_layer_offset: float = 1500.0, + vertical_layers: bool = False, + ): + """Arrange the cells in the network to form layers. + + Moves the cells in the network to arrange them into layers. + + Args: + layers: List of integers specifying the number of cells in each layer. + within_layer_offset: Offset between cells within the same layer. + between_layer_offset: Offset between layers. + vertical_layers: If True, layers are arranged vertically. + """ + assert ( + np.sum(layers) == self.shape[0] + ), "The number of cells in the layers must match the number of cells in the network." + cells_in_layers = [ + list(range(sum(layers[:i]), sum(layers[: i + 1]))) + for i in range(len(layers)) + ] + + for l, cell_inds in enumerate(cells_in_layers): + layer = self.cell(cell_inds) + for i, cell in enumerate(layer.cells): + if vertical_layers: + x_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset + y_offset = (len(layers) - 1 - l) * between_layer_offset + else: + x_offset = l * between_layer_offset + y_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset + + cell.move_to(x=x_offset, y=y_offset, z=0) + def vis( self, detail: str = "full", ax: Optional[Axes] = None, - col: str = "k", - synapse_col: str = "b", + color: str = "k", + synapse_color: str = "b", dims: Tuple[int] = (0, 1), type: str = "line", - layers: Optional[List] = None, - morph_plot_kwargs: Dict = {}, + cell_plot_kwargs: Dict = {}, synapse_plot_kwargs: Dict = {}, - synapse_scatter_kwargs: Dict = {}, - networkx_options: Dict = {}, - layer_kwargs: Dict = {}, ) -> Axes: """Visualize the module. Args: detail: Either of [point, full]. `point` visualizes every neuron in the - network as a dot (and it uses `networkx` to obtain cell positions). + network as a dot. `full` plots the full morphology of every neuron. It requires that `compute_xyz()` has been run and allows for indivual neurons to be moved with `.move()`. - col: The color in which cells are plotted. Only takes effect if + color: The color in which cells are plotted. Only takes effect if `detail='full'`. type: Either `line` or `scatter`. Only takes effect if `detail='full'`. - synapse_col: The color in which synapses are plotted. Only takes effect if + synapse_color: The color in which synapses are plotted. Only takes effect if `detail='full'`. dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them. - layers: Allows to plot the network in layers. Should provide the number of - neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input - neurons, 10 hidden layer neurons, and 1 output neuron. - morph_plot_kwargs: Keyword arguments passed to the plotting function for + cell_plot_kwargs: Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for `detail='full'`. synapse_plot_kwargs: Keyword arguments passed to the plotting function for syanpses. Only takes effect for `detail='full'`. - synapse_scatter_kwargs: Keyword arguments passed to the scatter function - for the end point of synapses. Only takes effect for `detail='full'`. - networkx_options: Options passed to `networkx.draw()`. Only takes effect if - `detail='point'`. - layer_kwargs: Only used if `layers` is specified and if `detail='full'`. - Can have the following entries: `within_layer_offset` (float), - `between_layer_offset` (float), `vertical_layers` (bool). """ - if detail == "point": - graph = self._build_graph(layers) + xyz0 = self.cell(0).xyzr[0][:, :3] + same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells]) + if same_xyz: + warn( + "Same coordinates for all cells. Consider using `move`, `move_to` or `arrange_in_layers` to move them." + ) + + if ax is None: + fig = plt.figure(figsize=(3, 3)) + ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection="3d") - if layers is not None: - pos = nx.multipartite_layout(graph, subset_key="layer") - nx.draw(graph, pos, with_labels=True, **networkx_options) - else: - nx.draw(graph, with_labels=True, **networkx_options) + # detail="point" -> pos taken to be the mean of all traced points on the cell. + cell_to_point_xyz = lambda cell: np.mean(np.vstack(cell.xyzr)[:, :3], axis=0) + + dims_np = np.asarray(dims) + if detail == "point": + for cell in self.cells: + pos = cell_to_point_xyz(cell)[dims_np] + ax.scatter(*pos, color=color, **cell_plot_kwargs) elif detail == "full": - if layers is not None: - # Assemble cells in the network into layers. - global_counter = 0 - layers_config = { - "within_layer_offset": 500.0, - "between_layer_offset": 1500.0, - "vertical_layers": False, - } - layers_config.update(layer_kwargs) - for layer_ind, num_in_layer in enumerate(layers): - for ind_within_layer in range(num_in_layer): - if layers_config["vertical_layers"]: - x_offset = ( - ind_within_layer - (num_in_layer - 1) / 2 - ) * layers_config["within_layer_offset"] - y_offset = (len(layers) - 1 - layer_ind) * layers_config[ - "between_layer_offset" - ] - else: - x_offset = layer_ind * layers_config["between_layer_offset"] - y_offset = ( - ind_within_layer - (num_in_layer - 1) / 2 - ) * layers_config["within_layer_offset"] - - self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0) - global_counter += 1 ax = super().vis( - dims=dims, - col=col, - ax=ax, - type=type, - morph_plot_kwargs=morph_plot_kwargs, + dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs ) + else: + raise ValueError("detail must be in {full, point}.") - pre_locs = self.edges["pre_locs"].to_numpy() - post_locs = self.edges["post_locs"].to_numpy() - pre_comp = self.edges["pre_global_comp_index"].to_numpy() - nodes = self.nodes.set_index("global_comp_index") - pre_branch = nodes.loc[pre_comp, "global_branch_index"].to_numpy() - post_comp = self.edges["post_global_comp_index"].to_numpy() - post_branch = nodes.loc[post_comp, "global_branch_index"].to_numpy() - - dims_np = np.asarray(dims) - - for pre_loc, post_loc, pre_b, post_b in zip( - pre_locs, post_locs, pre_branch, post_branch - ): - pre_coord = self.xyzr[pre_b] - if len(pre_coord) == 2: - # If only start and end point of a branch are traced, perform a - # linear interpolation to get the synpase location. - pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc - else: - # If densely traced, use intermediate trace values for synapse loc. - middle_ind = int((len(pre_coord) - 1) * pre_loc) - pre_coord = pre_coord[middle_ind] - - post_coord = self.xyzr[post_b] - if len(post_coord) == 2: + nodes = self.nodes.set_index("global_comp_index") + for i, edge in self.edges.iterrows(): + prepost_locs = [] + for prepost in ["pre", "post"]: + loc, comp = edge[[prepost + "_locs", prepost + "_global_comp_index"]] + branch = nodes.loc[comp, "global_branch_index"] + cell = nodes.loc[comp, "global_cell_index"] + branch_xyz = self.xyzr[branch] + + xyz_loc = branch_xyz + if detail == "point": + xyz_loc = cell_to_point_xyz(self.cell(cell)) + elif len(branch_xyz) == 2: # If only start and end point of a branch are traced, perform a # linear interpolation to get the synpase location. - post_coord = ( - post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc - ) + xyz_loc = branch_xyz[0] + (branch_xyz[1] - branch_xyz[0]) * loc else: # If densely traced, use intermediate trace values for synapse loc. - middle_ind = int((len(post_coord) - 1) * post_loc) - post_coord = post_coord[middle_ind] - - coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T - ax.plot( - coords[0], - coords[1], - c=synapse_col, - **synapse_plot_kwargs, - ) - ax.scatter( - post_coord[dims_np[0]], - post_coord[dims_np[1]], - c=synapse_col, - **synapse_scatter_kwargs, - ) - else: - raise ValueError("detail must be in {full, point}.") + middle_ind = int((len(branch_xyz) - 1) * loc) + xyz_loc = xyz_loc[middle_ind] - return ax - - def _build_graph(self, layers: Optional[List] = None, **options): - graph = nx.DiGraph() - - def build_extents(*subset_sizes): - return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes)) - - if layers is not None: - extents = build_extents(*layers) - layers = [range(start, end) for start, end in extents] - for i, layer in enumerate(layers): - graph.add_nodes_from(layer, layer=i) - else: - graph.add_nodes_from(range(len(self._cells_in_view))) - - pre_comp = self.edges["pre_global_comp_index"].to_numpy() - nodes = self.nodes.set_index("global_comp_index") - pre_cell = nodes.loc[pre_comp, "global_cell_index"].to_numpy() - post_comp = self.edges["post_global_comp_index"].to_numpy() - post_cell = nodes.loc[post_comp, "global_cell_index"].to_numpy() + prepost_locs.append(xyz_loc) + prepost_locs = np.stack(prepost_locs).T - inds = np.stack([pre_cell, post_cell]).T - graph.add_edges_from(inds) + ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs) - return graph + return ax def _infer_synapse_type_ind(self, synapse_name): syn_names = self.base.synapse_names diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index e7a0b13c..dea72bc1 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -17,10 +17,10 @@ def plot_graph( xyzr: ndarray, dims: Tuple[int] = (0, 1), - col: str = "k", + color: str = "k", ax: Optional[Axes] = None, type: str = "line", - morph_plot_kwargs: Dict = {}, + **kwargs, ) -> Axes: """Plot morphology. @@ -28,10 +28,10 @@ def plot_graph( xyzr: The coordinates of the morphology. dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two or three of them. - col: The color for all branches. + color: The color for all branches. ax: The matplotlib axis to plot on. type: Either `line` or `scatter`. - morph_plot_kwargs: The plot kwargs for plt.plot or plt.scatter. + kwargs: The plot kwargs for plt.plot or plt.scatter. """ if ax is None: @@ -42,9 +42,9 @@ def plot_graph( points = coords_of_branch[:, dims].T if "line" in type.lower(): - _ = ax.plot(*points, color=col, **morph_plot_kwargs) + _ = ax.plot(*points, color=color, **kwargs) elif "scatter" in type.lower(): - _ = ax.scatter(*points, color=col, **morph_plot_kwargs) + _ = ax.scatter(*points, color=color, **kwargs) else: raise NotImplementedError @@ -307,11 +307,11 @@ def plot_mesh( def plot_comps( module_or_view: Union["jx.Module", "jx.View"], dims: Tuple[int] = (0, 1), - col: str = "k", + color: str = "k", ax: Optional[Axes] = None, - comp_plot_kwargs: Dict = {}, true_comp_length: bool = True, resolution: int = 100, + **kwargs, ) -> Axes: """Plot compartmentalized neural morphology. @@ -321,9 +321,8 @@ def plot_comps( module_or_view: The module or view to plot. dims: The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D. - col: The color for all compartments + color: The color for all compartments ax: The matplotlib axis to plot on. - comp_plot_kwargs: The plot kwargs for plt.fill. true_comp_length: If True, the length of the compartment is used, i.e. the length of the traced neurite. This means for zig-zagging neurites the cylinders will be longer than the straight-line distance between the @@ -333,6 +332,7 @@ def plot_comps( resolution: defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting. + kwargs: The plot kwargs for plt.fill. Returns: Plot of the compartmentalized morphology. @@ -360,11 +360,11 @@ def plot_comps( center, np.array(dims), ax, - color=col, - **comp_plot_kwargs, + color=color, + **kwargs, ) else: - ax.add_artist(plt.Circle(locs[0, dims], radius, color=col)) + ax.add_artist(plt.Circle(locs[0, dims], radius, color=color)) else: lens = np.sqrt(np.nansum(np.diff(locs, axis=0) ** 2, axis=1)) lens = np.cumsum([0] + lens.tolist()) @@ -388,8 +388,8 @@ def plot_comps( center, np.array(dims), ax, - color=col, - **comp_plot_kwargs, + color=color, + **kwargs, ) return ax @@ -397,10 +397,10 @@ def plot_comps( def plot_morph( module_or_view: Union["jx.Module", "jx.View"], dims: Tuple[int] = (0, 1), - col: str = "k", + color: str = "k", ax: Optional[Axes] = None, resolution: int = 100, - morph_plot_kwargs: Dict = {}, + **kwargs, ) -> Axes: """Plot the detailed morphology. @@ -414,9 +414,9 @@ def plot_morph( module_or_view: The module or view to plot. dims: The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D. - col: The color for all branches + color: The color for all branches ax: The matplotlib axis to plot on. - morph_plot_kwargs: The plot kwargs for plt.fill. + kwargs: The plot kwargs for plt.fill. resolution: defines the resolution of the mesh. If too low (typically <10), can result in errors. @@ -454,9 +454,9 @@ def plot_morph( dxyz, xyzr1[:3], np.array(dims), - color=col, + color=color, ax=ax, - **morph_plot_kwargs, + **kwargs, ) else: points = create_cone_frustum_mesh( @@ -472,9 +472,9 @@ def plot_morph( np.ones(3), xyzr[0, :3], dims=np.array(dims), - color=col, + color=color, ax=ax, - **morph_plot_kwargs, + **kwargs, ) return ax diff --git a/pyproject.toml b/pyproject.toml index 4a81bbe6..50faea71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ keywords = ["neuroscience", "biophysics", "simulator", "jax"] dependencies = [ "jax", "matplotlib", - "networkx", "numpy", "pandas>=2.2.0", "tridiax", diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index a2e3b9e8..0149f4db 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -28,8 +28,8 @@ def test_cell(SimpleMorphCell): # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) ax = cell.vis(ax=ax) - ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r") - ax = cell.branch(1).loc(0.9).vis(ax=ax, col="b") + ax = cell.branch([0, 1, 2]).vis(ax=ax, color="r") + ax = cell.branch(1).loc(0.9).vis(ax=ax, color="b") # Plot 2. cell.branch(0).add_to_group("soma") @@ -59,9 +59,9 @@ def test_network(SimpleMorphCell): # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) ax = net.cell([0, 1]).vis(ax=ax) - ax = net.cell(2).vis(ax=ax, col="r", type="line") - ax = net.cell(2).vis(ax=ax, col="r", type="scatter") - ax = net.cell(0).branch(np.arange(10).tolist()).vis(ax=ax, col="b") + ax = net.cell(2).vis(ax=ax, color="r", type="line") + ax = net.cell(2).vis(ax=ax, color="r", type="scatter") + ax = net.cell(0).branch(np.arange(10).tolist()).vis(ax=ax, color="b") # Plot 2. ax = net.vis(detail="full", type="line") @@ -71,10 +71,12 @@ def test_network(SimpleMorphCell): net.vis(detail="point") # Plot 4. - net.vis(detail="point", layers=[2, 1]) + net.arrange_in_layers([2, 1]) + net.vis(detail="point") # Plot 5. - net.vis(detail="full", layers=[2, 1]) + net.arrange_in_layers([2, 1]) + net.vis(detail="full") # Plot 5. net.cell(0).add_to_group("excitatory") @@ -174,18 +176,20 @@ def test_volume_plotting( morph_cell = SimpleMorphCell(fname, ncomp=1) fig, ax = plt.subplots() - for module in [comp, branch, cell, net, morph_cell]: - module.vis(type="comp", ax=ax, morph_plot_kwargs={"resolution": 6}) + for module in [comp, branch, cell, morph_cell]: + module.vis(type="comp", ax=ax, resolution=6) + net.vis(type="comp", ax=ax, cell_plot_kwargs={"resolution": 6}) plt.close(fig) # test 3D plotting - for module in [comp, branch, cell, net, morph_cell]: - module.vis(type="comp", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6}) + for module in [comp, branch, cell, morph_cell]: + module.vis(type="comp", dims=[0, 1, 2], resolution=6) + net.vis(type="comp", dims=[0, 1, 2], cell_plot_kwargs={"resolution": 6}) plt.close() # test morph plotting (does not work if no radii in xyzr) morph_cell.branch(1).vis(type="morph") morph_cell.branch(1).vis( - type="morph", dims=[0, 1, 2], morph_plot_kwargs={"resolution": 6} + type="morph", dims=[0, 1, 2], resolution=6 ) # plotting whole thing takes too long plt.close()