From 79f311c80f27f0f1aa2b21db3d3f6b67fbc42d32 Mon Sep 17 00:00:00 2001 From: jnsbck <65561470+jnsbck@users.noreply.github.com> Date: Thu, 5 Dec 2024 11:28:33 +0100 Subject: [PATCH] fix: fix synapse_terminals (#545) --- jaxley/modules/network.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index a944fa93..15183bd6 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -428,9 +428,10 @@ def vis( color: str = "k", synapse_color: str = "b", dims: Tuple[int] = (0, 1), - type: str = "line", cell_plot_kwargs: Dict = {}, synapse_plot_kwargs: Dict = {}, + synapse_scatter_kwargs: Dict = {}, + **kwargs, # absorb add. kwargs, i.e. to enable net.cell(0).vis(type="line") ) -> Axes: """Visualize the module. @@ -438,19 +439,17 @@ def vis( detail: Either of [point, full]. `point` visualizes every neuron in the network as a dot. `full` plots the full morphology of every neuron. It requires that - `compute_xyz()` has been run and allows for indivual neurons to be - moved with `.move()`. - color: The color in which cells are plotted. Only takes effect if - `detail='full'`. - type: Either `line` or `scatter`. Only takes effect if `detail='full'`. - synapse_color: The color in which synapses are plotted. Only takes effect if - `detail='full'`. + `compute_xyz()` has been run. + color: The color in which cells are plotted. + synapse_color: The color in which synapses are plotted. dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them. cell_plot_kwargs: Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for `detail='full'`. synapse_plot_kwargs: Keyword arguments passed to the plotting function for - syanpses. Only takes effect for `detail='full'`. + syanpses. + synapse_scatter_kwargs: Keyword arguments passed to the scatter function for + syanpse terminals. """ xyz0 = self.cell(0).xyzr[0][:, :3] same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells]) @@ -472,9 +471,7 @@ def vis( pos = cell_to_point_xyz(cell)[dims_np] ax.scatter(*pos, color=color, **cell_plot_kwargs) elif detail == "full": - ax = super().vis( - dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs - ) + ax = super().vis(dims=dims, color=color, ax=ax, **cell_plot_kwargs) else: raise ValueError("detail must be in {full, point}.") @@ -485,7 +482,7 @@ def vis( loc, comp = edge[[prepost + "_locs", prepost + "_global_comp_index"]] branch = nodes.loc[comp, "global_branch_index"] cell = nodes.loc[comp, "global_cell_index"] - branch_xyz = self.xyzr[branch] + branch_xyz = self.xyzr[branch][:, :3] xyz_loc = branch_xyz if detail == "point": @@ -501,8 +498,10 @@ def vis( prepost_locs.append(xyz_loc) prepost_locs = np.stack(prepost_locs).T - ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs) + ax.scatter( + *prepost_locs[dims_np, 1], color=synapse_color, **synapse_scatter_kwargs + ) return ax