Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small refactor of plotting #539

Merged
merged 6 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,10 +2086,10 @@ def _get_external_input(
def vis(
self,
ax: Optional[Axes] = None,
col: str = "k",
color: str = "k",
dims: Tuple[int] = (0, 1),
type: str = "line",
morph_plot_kwargs: Dict = {},
**kwargs,
) -> Axes:
"""Visualize the module.

Expand All @@ -2102,24 +2102,29 @@ def vis(
- `scatter`: All traced points, are plotted as scatter points.
- `comp`: Plots the compartmentalized morphology, including radius
and shape. (shows the true compartment lengths per default, but this can
be changed via the `morph_plot_kwargs`, for details see
be changed via the `kwargs`, for details see
`jaxley.utils.plot_utils.plot_comps`).
- `morph`: Reconstructs the 3D shape of the traced morphology. For details see
`jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies
with many traced points this can be very slow.

Args:
ax: An axis into which to plot.
col: The color for all branches.
color: The color for all branches.
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
type: The type of plot. One of ["line", "scatter", "comp", "morph"].
morph_plot_kwargs: Keyword arguments passed to the plotting function.
kwargs: Keyword arguments passed to the plotting function.
"""
res = 100 if "resolution" not in kwargs else kwargs.pop("resolution")
if "comp" in type.lower():
return plot_comps(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)
return plot_comps(
self, dims=dims, ax=ax, color=color, resolution=res, **kwargs
)
if "morph" in type.lower():
return plot_morph(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)
return plot_morph(
self, dims=dims, ax=ax, color=color, resolution=res, **kwargs
)

assert not np.any(
[np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]
Expand All @@ -2128,10 +2133,10 @@ def vis(
ax = plot_graph(
self.xyzr,
dims=dims,
col=col,
color=color,
ax=ax,
type=type,
morph_plot_kwargs=morph_plot_kwargs,
**kwargs,
)

return ax
Expand Down
217 changes: 87 additions & 130 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import itertools
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
from warnings import warn

import jax.numpy as jnp
import networkx as nx
import numpy as np
import pandas as pd
from jax import vmap
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

from jaxley.modules.base import Module
Expand Down Expand Up @@ -383,20 +384,53 @@ def _synapse_currents(

return states, (syn_voltage_terms, syn_constant_terms)

def arrange_in_layers(
self,
layers: List[int],
within_layer_offset: float = 500.0,
between_layer_offset: float = 1500.0,
vertical_layers: bool = False,
):
"""Arrange the cells in the network to form layers.

Moves the cells in the network to arrange them into layers.

Args:
layers: List of integers specifying the number of cells in each layer.
within_layer_offset: Offset between cells within the same layer.
between_layer_offset: Offset between layers.
vertical_layers: If True, layers are arranged vertically.
"""
assert (
np.sum(layers) == self.shape[0]
), "The number of cells in the layers must match the number of cells in the network."
cells_in_layers = [
list(range(sum(layers[:i]), sum(layers[: i + 1])))
for i in range(len(layers))
]

for l, cell_inds in enumerate(cells_in_layers):
layer = self.cell(cell_inds)
for i, cell in enumerate(layer.cells):
if vertical_layers:
x_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset
y_offset = (len(layers) - 1 - l) * between_layer_offset
else:
x_offset = l * between_layer_offset
y_offset = (i - (len(cell_inds) - 1) / 2) * within_layer_offset

cell.move_to(x=x_offset, y=y_offset, z=0)

def vis(
self,
detail: str = "full",
ax: Optional[Axes] = None,
col: str = "k",
synapse_col: str = "b",
color: str = "k",
synapse_color: str = "b",
dims: Tuple[int] = (0, 1),
type: str = "line",
layers: Optional[List] = None,
morph_plot_kwargs: Dict = {},
cell_plot_kwargs: Dict = {},
synapse_plot_kwargs: Dict = {},
synapse_scatter_kwargs: Dict = {},
networkx_options: Dict = {},
layer_kwargs: Dict = {},
) -> Axes:
"""Visualize the module.

Expand All @@ -406,148 +440,71 @@ def vis(
`full` plots the full morphology of every neuron. It requires that
`compute_xyz()` has been run and allows for indivual neurons to be
moved with `.move()`.
col: The color in which cells are plotted. Only takes effect if
color: The color in which cells are plotted. Only takes effect if
`detail='full'`.
type: Either `line` or `scatter`. Only takes effect if `detail='full'`.
synapse_col: The color in which synapses are plotted. Only takes effect if
synapse_color: The color in which synapses are plotted. Only takes effect if
`detail='full'`.
dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of
two of them.
layers: Allows to plot the network in layers. Should provide the number of
neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input
neurons, 10 hidden layer neurons, and 1 output neuron.
morph_plot_kwargs: Keyword arguments passed to the plotting function for
cell_plot_kwargs: Keyword arguments passed to the plotting function for
cell morphologies. Only takes effect for `detail='full'`.
synapse_plot_kwargs: Keyword arguments passed to the plotting function for
synapse_kwargs: Keyword arguments passed to the plotting function for
syanpses. Only takes effect for `detail='full'`.
synapse_scatter_kwargs: Keyword arguments passed to the scatter function
for the end point of synapses. Only takes effect for `detail='full'`.
networkx_options: Options passed to `networkx.draw()`. Only takes effect if
`detail='point'`.
layer_kwargs: Only used if `layers` is specified and if `detail='full'`.
Can have the following entries: `within_layer_offset` (float),
`between_layer_offset` (float), `vertical_layers` (bool).
"""
if detail == "point":
graph = self._build_graph(layers)
xyz0 = self.cell(0).xyzr[0][:, :3]
same_xyz = np.all([np.all(xyz0 == cell.xyzr[0][:, :3]) for cell in self.cells])
if same_xyz:
warn(
"Same coordinates for all cells. Consider using `move`, `move_to` or `arrange_in_layers` to move them."
)

if ax is None:
fig = plt.figure(figsize=(3, 3))
ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection="3d")

if layers is not None:
pos = nx.multipartite_layout(graph, subset_key="layer")
nx.draw(graph, pos, with_labels=True, **networkx_options)
else:
nx.draw(graph, with_labels=True, **networkx_options)
# detail="point" -> pos taken to be the mean of all traced points on the cell.
cell_to_point_xyz = lambda cell: np.mean(np.vstack(cell.xyzr)[:, :3], axis=0)

dims_np = np.asarray(dims)
if detail == "point":
for cell in self.cells:
pos = cell_to_point_xyz(cell)[dims_np]
ax.scatter(*pos, color=color, **cell_plot_kwargs)
elif detail == "full":
if layers is not None:
# Assemble cells in the network into layers.
global_counter = 0
layers_config = {
"within_layer_offset": 500.0,
"between_layer_offset": 1500.0,
"vertical_layers": False,
}
layers_config.update(layer_kwargs)
for layer_ind, num_in_layer in enumerate(layers):
for ind_within_layer in range(num_in_layer):
if layers_config["vertical_layers"]:
x_offset = (
ind_within_layer - (num_in_layer - 1) / 2
) * layers_config["within_layer_offset"]
y_offset = (len(layers) - 1 - layer_ind) * layers_config[
"between_layer_offset"
]
else:
x_offset = layer_ind * layers_config["between_layer_offset"]
y_offset = (
ind_within_layer - (num_in_layer - 1) / 2
) * layers_config["within_layer_offset"]

self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)
global_counter += 1
ax = super().vis(
dims=dims,
col=col,
ax=ax,
type=type,
morph_plot_kwargs=morph_plot_kwargs,
dims=dims, color=color, ax=ax, type=type, **cell_plot_kwargs
)
else:
raise ValueError("detail must be in {full, point}.")

pre_locs = self.edges["pre_locs"].to_numpy()
post_locs = self.edges["post_locs"].to_numpy()
pre_comp = self.edges["pre_global_comp_index"].to_numpy()
nodes = self.nodes.set_index("global_comp_index")
pre_branch = nodes.loc[pre_comp, "global_branch_index"].to_numpy()
post_comp = self.edges["post_global_comp_index"].to_numpy()
post_branch = nodes.loc[post_comp, "global_branch_index"].to_numpy()

dims_np = np.asarray(dims)

for pre_loc, post_loc, pre_b, post_b in zip(
pre_locs, post_locs, pre_branch, post_branch
):
pre_coord = self.xyzr[pre_b]
if len(pre_coord) == 2:
# If only start and end point of a branch are traced, perform a
# linear interpolation to get the synpase location.
pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc
else:
# If densely traced, use intermediate trace values for synapse loc.
middle_ind = int((len(pre_coord) - 1) * pre_loc)
pre_coord = pre_coord[middle_ind]

post_coord = self.xyzr[post_b]
if len(post_coord) == 2:
nodes = self.nodes.set_index("global_comp_index")
for i, edge in self.edges.iterrows():
prepost_locs = []
for prepost in ["pre", "post"]:
loc, comp = edge[[prepost + "_locs", prepost + "_global_comp_index"]]
branch = nodes.loc[comp, "global_branch_index"]
cell = nodes.loc[comp, "global_cell_index"]
branch_xyz = self.xyzr[branch]

xyz_loc = branch_xyz
if detail == "point":
xyz_loc = cell_to_point_xyz(self.cell(cell))
elif len(branch_xyz) == 2:
# If only start and end point of a branch are traced, perform a
# linear interpolation to get the synpase location.
post_coord = (
post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc
)
xyz_loc = branch_xyz[0] + (branch_xyz[1] - branch_xyz[0]) * loc
else:
# If densely traced, use intermediate trace values for synapse loc.
middle_ind = int((len(post_coord) - 1) * post_loc)
post_coord = post_coord[middle_ind]

coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T
ax.plot(
coords[0],
coords[1],
c=synapse_col,
**synapse_plot_kwargs,
)
ax.scatter(
post_coord[dims_np[0]],
post_coord[dims_np[1]],
c=synapse_col,
**synapse_scatter_kwargs,
)
else:
raise ValueError("detail must be in {full, point}.")
middle_ind = int((len(branch_xyz) - 1) * loc)
xyz_loc = xyz_loc[middle_ind]

return ax

def _build_graph(self, layers: Optional[List] = None, **options):
graph = nx.DiGraph()

def build_extents(*subset_sizes):
return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes))

if layers is not None:
extents = build_extents(*layers)
layers = [range(start, end) for start, end in extents]
for i, layer in enumerate(layers):
graph.add_nodes_from(layer, layer=i)
else:
graph.add_nodes_from(range(len(self._cells_in_view)))

pre_comp = self.edges["pre_global_comp_index"].to_numpy()
nodes = self.nodes.set_index("global_comp_index")
pre_cell = nodes.loc[pre_comp, "global_cell_index"].to_numpy()
post_comp = self.edges["post_global_comp_index"].to_numpy()
post_cell = nodes.loc[post_comp, "global_cell_index"].to_numpy()
prepost_locs.append(xyz_loc)
prepost_locs = np.stack(prepost_locs).T

inds = np.stack([pre_cell, post_cell]).T
graph.add_edges_from(inds)
ax.plot(*prepost_locs[dims_np], color=synapse_color, **synapse_plot_kwargs)

return graph
return ax

def _infer_synapse_type_ind(self, synapse_name):
syn_names = self.base.synapse_names
Expand Down
Loading
Loading