Skip to content

Commit

Permalink
write plotting function for point neurons
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 6, 2023
1 parent 5de78b4 commit 108c740
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Tuple, Optional, Union


import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -300,7 +301,13 @@ def _step_synapse(

return new_syn_states, syn_voltage_terms, syn_constant_terms

def vis(self, detail: str = "point", layers: Optional[List] = None, **options):
def vis(
self,
detail: str = "point",
layers: Optional[List] = None,
figsize: Tuple = (5, 5),
**options,
) -> None:
"""Visualize the network.
Args:
Expand All @@ -314,7 +321,18 @@ def vis(self, detail: str = "point", layers: Optional[List] = None, **options):
neurons, 10 hidden layer neurons, and 1 output neuron.
options: Plotting options passed to `NetworkX.draw()`.
"""
G = nx.DiGraph()
assert detail == "point", "Only `point` is implemented."

graph = self._build_graph(layers, **options)

if layers is not None:
pos = nx.multipartite_layout(graph, subset_key="layer")
nx.draw(graph, pos, with_labels=True)
else:
nx.draw(graph, with_labels=True)

def _build_graph(self, layers: Optional[List] = None, **options):
graph = nx.DiGraph()

def build_extents(*subset_sizes):
return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes))
Expand All @@ -323,16 +341,17 @@ def build_extents(*subset_sizes):
extents = build_extents(*layers)
layers = [range(start, end) for start, end in extents]
for i, layer in enumerate(layers):
G.add_nodes_from(layer, layer=i)
graph.add_nodes_from(layer, layer=i)
else:
G.add_nodes_from(range(len(self.cells)))
graph.add_nodes_from(range(len(self.cells)))

pre_cell = self.syn_edges["pre_cell_index"].to_numpy()
post_cell = self.syn_edges["post_cell_index"].to_numpy()

inds = np.stack([pre_cell, post_cell]).T
G.add_edges_from(inds)
return G
graph.add_edges_from(inds)

return graph


class SynapseView(View):
Expand Down

0 comments on commit 108c740

Please sign in to comment.