diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 67678264..0a917e3e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -1,8 +1,9 @@ import itertools from copy import deepcopy -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import jax.numpy as jnp +import networkx as nx import numpy as np import pandas as pd from jax import vmap @@ -200,20 +201,37 @@ def init_conds(self, params): def init_syns(self): pre_comp_inds = [] post_comp_inds = [] + pre_branch_inds = [] + post_branch_inds = [] + pre_cell_inds = [] + post_cell_inds = [] for connectivity in self.connectivities: - pre_cell_inds, pre_inds, post_cell_inds, post_inds = prepare_syn( + pre_cell_inds_, pre_inds, post_cell_inds_, post_inds = prepare_syn( connectivity.conns, self.nseg ) pre_comp_inds.append( - self.cumsum_nbranches[pre_cell_inds] * self.nseg + pre_inds + self.cumsum_nbranches[pre_cell_inds_] * self.nseg + pre_inds ) post_comp_inds.append( - self.cumsum_nbranches[post_cell_inds] * self.nseg + post_inds + self.cumsum_nbranches[post_cell_inds_] * self.nseg + post_inds ) + pre_branch_inds.append(self.cumsum_nbranches[pre_cell_inds_]) + post_branch_inds.append(self.cumsum_nbranches[post_cell_inds_]) + pre_cell_inds.append(pre_cell_inds_) + post_cell_inds.append(post_cell_inds_) # Prepare synapses. self.syn_edges = pd.DataFrame( - columns=["pre_comp_index", "post_comp_index", "type", "type_ind"] + columns=[ + "pre_comp_index", + "pre_branch_index", + "pre_cell_index", + "post_comp_index", + "post_branch_index", + "post_cell_index", + "type", + "type_ind", + ] ) for i, connectivity in enumerate(self.connectivities): self.syn_edges = pd.concat( @@ -222,7 +240,11 @@ def init_syns(self): pd.DataFrame( dict( pre_comp_index=pre_comp_inds[i], + pre_branch_index=pre_branch_inds[i], + pre_cell_index=pre_cell_inds[i], post_comp_index=post_comp_inds[i], + post_branch_index=post_branch_inds[i], + post_cell_index=post_cell_inds[i], type=type(connectivity.synapse_type).__name__, type_ind=i, ) @@ -277,6 +299,57 @@ def _step_synapse( return new_syn_states, syn_voltage_terms, syn_constant_terms + def vis( + self, + detail: str = "point", + layers: Optional[List] = None, + **options, + ) -> None: + """Visualize the network. + + Args: + detail: Currently, only `point` is supported. In the future, either of + [point, sticks, full] will be supported. `point` visualizes every neuron + as a point. `sticks` visualizes all branches of every neuron, but draws + branches as straight lines. `full` plots the full morphology of every + neuron, as read from the SWC file. + layers: Allows to plot the network in layers. Should provide the number of + neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input + neurons, 10 hidden layer neurons, and 1 output neuron. + options: Plotting options passed to `NetworkX.draw()`. + """ + assert detail == "point", "Only `point` is implemented." + + graph = self._build_graph(layers, **options) + + if layers is not None: + pos = nx.multipartite_layout(graph, subset_key="layer") + nx.draw(graph, pos, with_labels=True) + else: + nx.draw(graph, with_labels=True) + + def _build_graph(self, layers: Optional[List] = None, **options): + graph = nx.DiGraph() + + def build_extents(*subset_sizes): + return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes)) + + if layers is not None: + extents = build_extents(*layers) + layers = [range(start, end) for start, end in extents] + for i, layer in enumerate(layers): + graph.add_nodes_from(layer, layer=i) + else: + graph.add_nodes_from(range(len(self.cells))) + + pre_cell = self.syn_edges["pre_cell_index"].to_numpy() + post_cell = self.syn_edges["post_cell_index"].to_numpy() + + inds = np.stack([pre_cell, post_cell]).T + graph.add_edges_from(inds) + + return graph + class SynapseView(View): """SynapseView.""" diff --git a/setup.py b/setup.py index 4f76b054..fb450831 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ "numpy", "pandas", "matplotlib", + "networkx", ]