diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 8cd7971c..18991825 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -1,6 +1,7 @@ import itertools from copy import deepcopy -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Tuple, Optional, Union + import jax.numpy as jnp import numpy as np @@ -300,7 +301,13 @@ def _step_synapse( return new_syn_states, syn_voltage_terms, syn_constant_terms - def vis(self, detail: str = "point", layers: Optional[List] = None, **options): + def vis( + self, + detail: str = "point", + layers: Optional[List] = None, + figsize: Tuple = (5, 5), + **options, + ) -> None: """Visualize the network. Args: @@ -314,7 +321,18 @@ def vis(self, detail: str = "point", layers: Optional[List] = None, **options): neurons, 10 hidden layer neurons, and 1 output neuron. options: Plotting options passed to `NetworkX.draw()`. """ - G = nx.DiGraph() + assert detail == "point", "Only `point` is implemented." + + graph = self._build_graph(layers, **options) + + if layers is not None: + pos = nx.multipartite_layout(graph, subset_key="layer") + nx.draw(graph, pos, with_labels=True) + else: + nx.draw(graph, with_labels=True) + + def _build_graph(self, layers: Optional[List] = None, **options): + graph = nx.DiGraph() def build_extents(*subset_sizes): return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes)) @@ -323,16 +341,17 @@ def build_extents(*subset_sizes): extents = build_extents(*layers) layers = [range(start, end) for start, end in extents] for i, layer in enumerate(layers): - G.add_nodes_from(layer, layer=i) + graph.add_nodes_from(layer, layer=i) else: - G.add_nodes_from(range(len(self.cells))) + graph.add_nodes_from(range(len(self.cells))) pre_cell = self.syn_edges["pre_cell_index"].to_numpy() post_cell = self.syn_edges["post_cell_index"].to_numpy() inds = np.stack([pre_cell, post_cell]).T - G.add_edges_from(inds) - return G + graph.add_edges_from(inds) + + return graph class SynapseView(View):