From 11064790a1bcca4c404934829388bbf87b66ce99 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 10:08:21 +0100 Subject: [PATCH 1/6] Draft plotting function for network API --- jaxley/modules/network.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 67678264..c118f01e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -277,6 +277,17 @@ def _step_synapse( return new_syn_states, syn_voltage_terms, syn_constant_terms + def vis(self, detail: str = "point", **options): + """Visualize the network. + + Args: + detail: Either of [point, sticks, full]. `point` visualizes every neuron as a + point. `sticks` visualizes all branches of every neuron, but draws + branches as straight lines. `full` plots the full morphology of every + neuron, as read from the SWC file. + options: Plotting options passed to `NetworkX.draw()`. + """ + class SynapseView(View): """SynapseView.""" From 5de78b48700f3631ed0ac10d9bf1897d17c9e520 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 15:28:35 +0100 Subject: [PATCH 2/6] can visualize point neurons --- jaxley/modules/network.py | 60 ++++++++++++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 7 deletions(-) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index c118f01e..8cd7971c 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -6,6 +6,8 @@ import numpy as np import pandas as pd from jax import vmap +import networkx as nx + from jaxley.connection import Connectivity from jaxley.modules.base import Module, View @@ -200,20 +202,37 @@ def init_conds(self, params): def init_syns(self): pre_comp_inds = [] post_comp_inds = [] + pre_branch_inds = [] + post_branch_inds = [] + pre_cell_inds = [] + post_cell_inds = [] for connectivity in self.connectivities: - pre_cell_inds, pre_inds, post_cell_inds, post_inds = prepare_syn( + pre_cell_inds_, pre_inds, post_cell_inds_, post_inds = prepare_syn( connectivity.conns, self.nseg ) pre_comp_inds.append( - self.cumsum_nbranches[pre_cell_inds] * self.nseg + pre_inds + self.cumsum_nbranches[pre_cell_inds_] * self.nseg + pre_inds ) post_comp_inds.append( - self.cumsum_nbranches[post_cell_inds] * self.nseg + post_inds + self.cumsum_nbranches[post_cell_inds_] * self.nseg + post_inds ) + pre_branch_inds.append(self.cumsum_nbranches[pre_cell_inds_]) + post_branch_inds.append(self.cumsum_nbranches[post_cell_inds_]) + pre_cell_inds.append(pre_cell_inds_) + post_cell_inds.append(post_cell_inds_) # Prepare synapses. self.syn_edges = pd.DataFrame( - columns=["pre_comp_index", "post_comp_index", "type", "type_ind"] + columns=[ + "pre_comp_index", + "pre_branch_index", + "pre_cell_index", + "post_comp_index", + "post_branch_index", + "post_cell_index", + "type", + "type_ind", + ] ) for i, connectivity in enumerate(self.connectivities): self.syn_edges = pd.concat( @@ -222,7 +241,11 @@ def init_syns(self): pd.DataFrame( dict( pre_comp_index=pre_comp_inds[i], + pre_branch_index=pre_branch_inds[i], + pre_cell_index=pre_cell_inds[i], post_comp_index=post_comp_inds[i], + post_branch_index=post_branch_inds[i], + post_cell_index=post_cell_inds[i], type=type(connectivity.synapse_type).__name__, type_ind=i, ) @@ -277,16 +300,39 @@ def _step_synapse( return new_syn_states, syn_voltage_terms, syn_constant_terms - def vis(self, detail: str = "point", **options): + def vis(self, detail: str = "point", layers: Optional[List] = None, **options): """Visualize the network. Args: - detail: Either of [point, sticks, full]. `point` visualizes every neuron as a - point. `sticks` visualizes all branches of every neuron, but draws + detail: Currently, only `point` is supported. In the future, either of + [point, sticks, full] will be supported. `point` visualizes every neuron + as a point. `sticks` visualizes all branches of every neuron, but draws branches as straight lines. `full` plots the full morphology of every neuron, as read from the SWC file. + layers: Allows to plot the network in layers. Should provide the number of + neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input + neurons, 10 hidden layer neurons, and 1 output neuron. options: Plotting options passed to `NetworkX.draw()`. """ + G = nx.DiGraph() + + def build_extents(*subset_sizes): + return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes)) + + if layers is not None: + extents = build_extents(*layers) + layers = [range(start, end) for start, end in extents] + for i, layer in enumerate(layers): + G.add_nodes_from(layer, layer=i) + else: + G.add_nodes_from(range(len(self.cells))) + + pre_cell = self.syn_edges["pre_cell_index"].to_numpy() + post_cell = self.syn_edges["post_cell_index"].to_numpy() + + inds = np.stack([pre_cell, post_cell]).T + G.add_edges_from(inds) + return G class SynapseView(View): From 108c740925f9f9aa0eef1ddc08ce865f49bded05 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 15:38:24 +0100 Subject: [PATCH 3/6] write plotting function for point neurons --- jaxley/modules/network.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 8cd7971c..18991825 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -1,6 +1,7 @@ import itertools from copy import deepcopy -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Tuple, Optional, Union + import jax.numpy as jnp import numpy as np @@ -300,7 +301,13 @@ def _step_synapse( return new_syn_states, syn_voltage_terms, syn_constant_terms - def vis(self, detail: str = "point", layers: Optional[List] = None, **options): + def vis( + self, + detail: str = "point", + layers: Optional[List] = None, + figsize: Tuple = (5, 5), + **options, + ) -> None: """Visualize the network. Args: @@ -314,7 +321,18 @@ def vis(self, detail: str = "point", layers: Optional[List] = None, **options): neurons, 10 hidden layer neurons, and 1 output neuron. options: Plotting options passed to `NetworkX.draw()`. """ - G = nx.DiGraph() + assert detail == "point", "Only `point` is implemented." + + graph = self._build_graph(layers, **options) + + if layers is not None: + pos = nx.multipartite_layout(graph, subset_key="layer") + nx.draw(graph, pos, with_labels=True) + else: + nx.draw(graph, with_labels=True) + + def _build_graph(self, layers: Optional[List] = None, **options): + graph = nx.DiGraph() def build_extents(*subset_sizes): return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes)) @@ -323,16 +341,17 @@ def build_extents(*subset_sizes): extents = build_extents(*layers) layers = [range(start, end) for start, end in extents] for i, layer in enumerate(layers): - G.add_nodes_from(layer, layer=i) + graph.add_nodes_from(layer, layer=i) else: - G.add_nodes_from(range(len(self.cells))) + graph.add_nodes_from(range(len(self.cells))) pre_cell = self.syn_edges["pre_cell_index"].to_numpy() post_cell = self.syn_edges["post_cell_index"].to_numpy() inds = np.stack([pre_cell, post_cell]).T - G.add_edges_from(inds) - return G + graph.add_edges_from(inds) + + return graph class SynapseView(View): From 06820d82979e37b94e2ee55b2de4c29f9502c397 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 15:39:21 +0100 Subject: [PATCH 4/6] add networkx to setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 4f76b054..fb450831 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,7 @@ "numpy", "pandas", "matplotlib", + "networkx", ] From 448c3e18edd32beb5cd0763995075e92a52d262b Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 15:39:43 +0100 Subject: [PATCH 5/6] bugfix --- jaxley/modules/network.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 18991825..e528230e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -305,7 +305,6 @@ def vis( self, detail: str = "point", layers: Optional[List] = None, - figsize: Tuple = (5, 5), **options, ) -> None: """Visualize the network. From 46c7187785e9b2427f4e46d3634773a0f8a8c8a4 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 15:42:00 +0100 Subject: [PATCH 6/6] isort --- jaxley/modules/network.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index e528230e..0a917e3e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -1,14 +1,12 @@ import itertools from copy import deepcopy -from typing import Callable, Dict, List, Tuple, Optional, Union - +from typing import Callable, Dict, List, Optional, Tuple, Union import jax.numpy as jnp +import networkx as nx import numpy as np import pandas as pd from jax import vmap -import networkx as nx - from jaxley.connection import Connectivity from jaxley.modules.base import Module, View