diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index b46c90b4..ea8328f2 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -13,7 +13,12 @@ from jaxley.channels import Channel from jaxley.solver_voltage import step_voltage_explicit, step_voltage_implicit from jaxley.synapses import Synapse -from jaxley.utils.plot_utils import plot_morph, plot_swc +from jaxley.utils.cell_utils import ( + _compute_index_of_child, + _compute_num_children, + compute_levels, +) +from jaxley.utils.plot_utils import plot_morph class Module(ABC): @@ -636,7 +641,6 @@ def get_external_input( def vis( self, - detail: str = "full", ax=None, col: str = "k", dims: Tuple[int] = (0, 1), @@ -645,16 +649,12 @@ def vis( """Visualize the module. Args: - detail: Either of [sticks, full]. `sticks` visualizes all branches of every - neuron, but draws branches as straight lines. `full` plots the full - morphology of every neuron, as read from the SWC file. ax: An axis into which to plot. col: The color for all branches. dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them. """ return self._vis( - detail=detail, dims=dims, col=col, ax=ax, @@ -662,32 +662,84 @@ def vis( morph_plot_kwargs=morph_plot_kwargs, ) - def _vis(self, detail, ax, col, dims, view, morph_plot_kwargs): + def _vis(self, ax, col, dims, view, morph_plot_kwargs): branches_inds = view["branch_index"].to_numpy() - coords = [self.xyzr[branch_ind] for branch_ind in branches_inds] - - if detail == "full": - assert self.xyzr, "no coordinates available, use `vis(detail='point')`." - ax = plot_swc( - coords, - dims=dims, - col=col, - ax=ax, - morph_plot_kwargs=morph_plot_kwargs, - ) - # elif detail == "sticks": - # fig, ax = plot_morph( - # cell=self, - # col=col, - # max_y_multiplier=5.0, - # min_y_multiplier=0.5, - # ax=ax, - # ) - else: - raise ValueError("`detail must be in {point, full}.") + coords = [] + for branch_ind in branches_inds: + assert not np.any( + np.isnan(self.xyzr[branch_ind][:, dims]) + ), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`." + coords.append(self.xyzr[branch_ind]) + + ax = plot_morph( + coords, + dims=dims, + col=col, + ax=ax, + morph_plot_kwargs=morph_plot_kwargs, + ) return ax + def compute_xyz(self): + """Return xyz coordinates of every branch, based on the branch length.""" + max_y_multiplier = 5.0 + min_y_multiplier = 0.5 + + parents = self.comb_parents + num_children = _compute_num_children(parents) + index_of_child = _compute_index_of_child(parents) + levels = compute_levels(parents) + + # Extract branch. + inds_branch = self.nodes.groupby("branch_index")["comp_index"].apply(list) + branch_lens = [ + np.sum(self.params["length"][np.asarray(i)]) for i in inds_branch + ] + endpoints = [] + + # Different levels will get a different "angle" at which the children emerge from + # the parents. This angle is defined by the `y_offset_multiplier`. This value + # defines the range between y-location of the first and of the last child of a + # parent. + y_offset_multiplier = np.linspace( + max_y_multiplier, min_y_multiplier, np.max(levels) + 1 + ) + + for b in range(self.total_nbranches): + if parents[b] > -1: + start_point = endpoints[parents[b]] + num_children_of_parent = num_children[parents[b]] + y_offset = ( + ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5 + ) * y_offset_multiplier[levels[b]] + else: + start_point = [0, 0] + y_offset = 0.0 + + len_of_path = np.sqrt(y_offset**2 + 1.0) + + end_point = [ + start_point[0] + branch_lens[b] / len_of_path * 1.0, + start_point[1] + branch_lens[b] / len_of_path * y_offset, + ] + endpoints.append(end_point) + + self.xyzr[b][:, :2] = np.asarray([start_point, end_point]) + + def move(self, x: float = 0.0, y: float = 0.0, z: float = 0.0): + """Move cells or networks in the (x, y, z) plane.""" + self._move(x, y, z, self.nodes) + + def _move(self, x: float, y: float, z: float, view): + # Need to cast to set because this will return one columnn per compartment, + # not one column per branch. + indizes = set(view["branch_index"].to_numpy().tolist()) + for i in indizes: + self.xyzr[i][:, 0] += x + self.xyzr[i][:, 1] += y + self.xyzr[i][:, 2] += z + class View: """View of a `Module`.""" @@ -781,7 +833,6 @@ def add_to_group(self, group_name: str): def vis( self, - detail: str = "full", ax=None, col="k", dims=(0, 1), @@ -789,7 +840,6 @@ def vis( ): nodes = self.set_global_index_and_index(self.view) return self.pointer._vis( - detail=detail, ax=ax, col=col, dims=dims, @@ -797,6 +847,10 @@ def vis( morph_plot_kwargs=morph_plot_kwargs, ) + def move(self, x: float = 0.0, y: float = 0.0, z: float = 0.0): + nodes = self.set_global_index_and_index(self.view) + self.pointer._move(x, y, z, nodes) + def adjust_view(self, key: str, index: float): """Update view.""" if isinstance(index, int) or isinstance(index, np.int64): diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 314e3a7d..5727d413 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -47,7 +47,17 @@ def __init__( branch_list = [branches for _ in range(len(parents))] else: branch_list = branches - self.xyzr = xyzr + + if xyzr is not None: + assert len(xyzr) == len(parents) + self.xyzr = xyzr + else: + # For every branch (`len(parents)`), we have a start and end point (`2`) and + # a (x,y,z,r) coordinate for each of them (`4`). + # Since `xyzr` is only inspected at `.vis()` and because it depends on the + # (potentially learned) length of every compartment, we only populate + # self.xyzr at `.vis()`. + self.xyzr = [float("NaN") * np.zeros((2, 4)) for _ in range(len(parents))] self._append_to_params_and_state(branch_list) for branch in branch_list: diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index a0245b0d..52a54f8b 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -39,7 +39,7 @@ def __init__( self._append_to_params_and_state(cells) for cell in cells: self._append_to_channel_params_and_state(cell) - self.xyzr += cell.xyzr + self.xyzr += deepcopy(cell.xyzr) self._append_synapses_to_params_and_state(connectivities) self.cells = cells @@ -203,6 +203,8 @@ def init_conds(self, params): def init_syns(self): global_pre_comp_inds = [] global_post_comp_inds = [] + global_pre_branch_inds = [] + global_post_branch_inds = [] pre_locs = [] post_locs = [] pre_branch_inds = [] @@ -220,6 +222,18 @@ def init_syns(self): global_post_comp_inds.append( self.cumsum_nbranches[post_cell_inds_] * self.nseg + post_inds ) + global_pre_branch_inds.append( + [ + self.cumsum_nbranches[c.pre_cell_ind] + c.pre_branch_ind + for c in connectivity.conns + ] + ) + global_post_branch_inds.append( + [ + self.cumsum_nbranches[c.post_cell_ind] + c.post_branch_ind + for c in connectivity.conns + ] + ) # Local compartment inds. pre_locs.append(np.asarray([c.pre_loc for c in connectivity.conns])) post_locs.append(np.asarray([c.post_loc for c in connectivity.conns])) @@ -246,6 +260,8 @@ def init_syns(self): "type_ind", "global_pre_comp_index", "global_post_comp_index", + "global_pre_branch_index", + "global_post_branch_index", ] ) for i, connectivity in enumerate(self.connectivities): @@ -264,6 +280,8 @@ def init_syns(self): type_ind=i, global_pre_comp_index=global_pre_comp_inds[i], global_post_comp_index=global_post_comp_inds[i], + global_pre_branch_index=global_pre_branch_inds[i], + global_post_branch_index=global_post_branch_inds[i], ) ), ], @@ -351,9 +369,8 @@ def vis( nx.draw(graph, pos, with_labels=True) else: nx.draw(graph, with_labels=True) - else: + elif detail == "full": ax = self._vis( - detail=detail, dims=dims, col=col, ax=ax, @@ -363,22 +380,36 @@ def vis( pre_locs = self.syn_edges["pre_locs"].to_numpy() post_locs = self.syn_edges["post_locs"].to_numpy() - pre_branch = self.syn_edges["pre_branch_index"].to_numpy() - post_branch = self.syn_edges["post_branch_index"].to_numpy() - pre_cell = self.syn_edges["pre_cell_index"].to_numpy() - post_cell = self.syn_edges["post_cell_index"].to_numpy() + pre_branch = self.syn_edges["global_pre_branch_index"].to_numpy() + post_branch = self.syn_edges["global_post_branch_index"].to_numpy() dims_np = np.asarray(dims) - for pre_loc, post_loc, pre_b, post_b, pre_c, post_c in zip( - pre_locs, post_locs, pre_branch, post_branch, pre_cell, post_cell + for pre_loc, post_loc, pre_b, post_b in zip( + pre_locs, post_locs, pre_branch, post_branch ): - pre_coord = self.cells[pre_c].xyzr[pre_b] - middle_ind = int((len(pre_coord) - 1) * pre_loc) - pre_coord = pre_coord[middle_ind] - post_coord = self.cells[post_c].xyzr[post_b] - middle_ind = int((len(post_coord) - 1) * post_loc) - post_coord = post_coord[middle_ind] + pre_coord = self.xyzr[pre_b] + if len(pre_coord) == 2: + # If only start and end point of a branch are traced, perform a + # linear interpolation to get the synpase location. + pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc + else: + # If densely traced, use intermediate trace values for synapse loc. + middle_ind = int((len(pre_coord) - 1) * pre_loc) + pre_coord = pre_coord[middle_ind] + + post_coord = self.xyzr[post_b] + if len(post_coord) == 2: + # If only start and end point of a branch are traced, perform a + # linear interpolation to get the synpase location. + post_coord = ( + post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc + ) + else: + # If densely traced, use intermediate trace values for synapse loc. + middle_ind = int((len(post_coord) - 1) * post_loc) + post_coord = post_coord[middle_ind] + coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T ax.plot( coords[0], @@ -393,6 +424,8 @@ def vis( c=synapse_col, **synapse_scatter_kwargs, ) + else: + raise ValueError("detail must be in {full, point}.") return ax diff --git a/jaxley/utils/__init__.py b/jaxley/utils/__init__.py index f2bbf832..e69de29b 100644 --- a/jaxley/utils/__init__.py +++ b/jaxley/utils/__init__.py @@ -1 +0,0 @@ -from jaxley.utils.plot_utils import plot_morph, plot_swc diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index 4edd5a7e..965206e9 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -1,93 +1,10 @@ from typing import Dict import matplotlib.pyplot as plt -import numpy as np -from jaxley.utils.cell_utils import ( - _compute_index_of_child, - _compute_num_children, - compute_levels, -) - -def plot_morph( - cell: "jx.Cell", - col="k", - max_y_multiplier: float = 5.0, - min_y_multiplier: float = 0.5, - ax=None, -): - """Plot the stick representation of a morphology. - - This method operates at the branch level. It does not plot individual compartments, - but only individual branches. It also ignores the radius, but it takes into account - the lengths. - - Args: - cell: The `Cell` object to be plotted. - figsize: Size of the figure. - - Returns: - `fig, ax` of the plot. - """ - parents = cell.comb_parents - num_children = _compute_num_children(parents) - index_of_child = _compute_index_of_child(parents) - levels = compute_levels(parents) - - # Extract branch. - inds_branch = cell.nodes.groupby("branch_index")["comp_index"].apply(list) - branch_lens = [np.sum(cell.params["length"][np.asarray(i)]) for i in inds_branch] - endpoints = [] - - # Different levels will get a different "angle" at which the children emerge from - # the parents. This angle is defined by the `y_offset_multiplier`. This value - # defines the range between y-location of the first and of the last child of a - # parent. - y_offset_multiplier = np.linspace( - max_y_multiplier, min_y_multiplier, np.max(levels) + 1 - ) - - if ax is None: - fig, ax = plt.subplots(1, 1, figsize=(3, 3)) - - for b in range(cell.total_nbranches): - if parents[b] > -1: - start_point = endpoints[parents[b]] - num_children_of_parent = num_children[parents[b]] - y_offset = ( - ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5 - ) * y_offset_multiplier[levels[b]] - else: - start_point = [0, 0] - y_offset = 0.0 - - len_of_path = np.sqrt(y_offset**2 + 1.0) - - end_point = [ - start_point[0] + branch_lens[b] / len_of_path * 1.0, - start_point[1] + branch_lens[b] / len_of_path * y_offset, - ] - endpoints.append(end_point) - - _ = ax.plot( - [start_point[0], end_point[0]], - [start_point[1], end_point[1]], - c=col, - label=f"ind {b}", - ) - - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) - ax.set_xlabel(r"$\mu$m") - ax.set_ylabel(r"$\mu$m") - plt.axis("square") - - return fig, ax - - -def plot_swc(xyzr, dims=(0, 1), col="k", ax=None, morph_plot_kwargs: Dict = None): - """Plot morphology given an SWC file. +def plot_morph(xyzr, dims=(0, 1), col="k", ax=None, morph_plot_kwargs: Dict = None): + """Plot morphology. Args: dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of diff --git a/tests/test_plotting_api.py b/tests/test_plotting_api.py index f5fb866f..844f5960 100644 --- a/tests/test_plotting_api.py +++ b/tests/test_plotting_api.py @@ -19,13 +19,13 @@ def test_cell(): # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) - ax = cell.vis(detail="full", ax=ax) - ax = cell.branch([0, 1, 2]).vis(detail="full", ax=ax, col="r") + ax = cell.vis(ax=ax) + ax = cell.branch([0, 1, 2]).vis(ax=ax, col="r") # Plot 2. cell.branch(0).add_to_group("soma") cell.branch(1).add_to_group("soma") - ax = cell.soma.vis(detail="full") + ax = cell.soma.vis() def test_network(): @@ -45,9 +45,9 @@ def test_network(): # Plot 1. _, ax = plt.subplots(1, 1, figsize=(3, 3)) - ax = net.cell([0, 1]).vis(detail="full", ax=ax) - ax = net.cell(2).vis(detail="full", ax=ax, col="r") - ax = net.cell(0).branch(np.arange(10).tolist()).vis(detail="full", ax=ax, col="b") + ax = net.cell([0, 1]).vis(ax=ax) + ax = net.cell(2).vis(ax=ax, col="r") + ax = net.cell(0).branch(np.arange(10).tolist()).vis(ax=ax, col="b") # Plot 2. ax = net.vis(detail="full") @@ -61,4 +61,28 @@ def test_network(): # Plot 5. net.cell(0).add_to_group("excitatory") net.cell(1).add_to_group("excitatory") - ax = net.excitatory.vis(detail="full") + ax = net.excitatory.vis() + + +def test_vis_networks_built_from_scartch(): + comp = jx.Compartment() + branch = jx.Branch(comp, 4) + cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1]) + + conns = [ + jx.Connectivity( + GlutamateSynapse(), + [jx.Connection(0, 0, 0.0, 1, 0, 0.0), jx.Connection(0, 0, 0.0, 1, 2, 0.0)], + ) + ] + net = jx.Network([cell, cell], conns) + net.compute_xyz() + + # Plot 1. + _, ax = plt.subplots(1, 1, figsize=(3, 3)) + ax = net.vis(detail="full", ax=ax) + + # Plot 2. + _, ax = plt.subplots(1, 1, figsize=(3, 3)) + net.cell(0).move(0, 100) + ax = net.vis(detail="full", ax=ax)