Skip to content

Commit

Permalink
allow plotting networks
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 6, 2023
1 parent add3886 commit f37adf6
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 14 deletions.
6 changes: 6 additions & 0 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def vis(
dims=(0, 1),
cols="k",
highlight_branch_inds=[],
fig=None,
ax=None,
max_y_multiplier: float = 5.0,
min_y_multiplier: float = 0.5,
) -> None:
Expand All @@ -263,6 +265,8 @@ def vis(
highlight_branch_inds=highlight_branch_inds,
max_y_multiplier=max_y_multiplier,
min_y_multiplier=min_y_multiplier,
fig=fig,
ax=ax,
)
elif detail == "full":
assert self.xyzr is not None, "no coordinates, use `vis(detail='sticks')`."
Expand All @@ -272,6 +276,8 @@ def vis(
dims=dims,
cols=cols,
highlight_branch_inds=highlight_branch_inds,
fig=fig,
ax=ax,
)
else:
raise ValueError("`detail must be in {sticks, full}.")
Expand Down
42 changes: 29 additions & 13 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
from copy import deepcopy
from typing import Callable, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import jax.numpy as jnp
import networkx as nx
import numpy as np
Expand Down Expand Up @@ -303,30 +303,46 @@ def vis(
self,
detail: str = "point",
layers: Optional[List] = None,
figsize=(4, 4),
dims=(0, 1),
cols: Union[str, List[str]] = "k",
highlight=[],
fig=None,
ax=None,
**options,
) -> None:
"""Visualize the network.
Args:
detail: Currently, only `point` is supported. In the future, either of
[point, sticks, full] will be supported. `point` visualizes every neuron
as a point. `sticks` visualizes all branches of every neuron, but draws
branches as straight lines. `full` plots the full morphology of every
detail: Either of [point, full]. `point` visualizes every neuron
as a point. `full` plots the full morphology of every
neuron, as read from the SWC file.
layers: Allows to plot the network in layers. Should provide the number of
neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input
neurons, 10 hidden layer neurons, and 1 output neuron.
options: Plotting options passed to `NetworkX.draw()`.
cols: One color in total or one color per cell.
"""
assert detail == "point", "Only `point` is implemented."

graph = self._build_graph(layers, **options)

if layers is not None:
pos = nx.multipartite_layout(graph, subset_key="layer")
nx.draw(graph, pos, with_labels=True)
if detail == "point":
graph = self._build_graph(layers, **options)

if layers is not None:
pos = nx.multipartite_layout(graph, subset_key="layer")
nx.draw(graph, pos, with_labels=True)
else:
nx.draw(graph, with_labels=True)
elif detail == "full":
if fig is None or ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)

if isinstance(cols, str):
cols = [cols] * len(self.cells)

for cell, col in zip(self.cells, cols):
fig, ax = cell.vis(detail="full", dims=dims, cols=col, fig=fig, ax=ax)
return fig, ax
else:
nx.draw(graph, with_labels=True)
raise ValueError("detail must be in {point, full}")

def _build_graph(self, layers: Optional[List] = None, **options):
graph = nx.DiGraph()
Expand Down
7 changes: 6 additions & 1 deletion jaxley/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def plot_morph(
highlight_branch_inds=[],
max_y_multiplier: float = 5.0,
min_y_multiplier: float = 0.5,
fig=None,
ax=None,
):
"""Plot the stick representation of a morphology.
Expand Down Expand Up @@ -118,6 +120,8 @@ def plot_swc(
dims=(0, 1),
cols="k",
highlight_branch_inds=[],
fig=None,
ax=None,
):
"""Plot morphology given an SWC file.
Expand All @@ -132,7 +136,8 @@ def plot_swc(
counter_highlight_branches = 0
lines = []

fig, ax = plt.subplots(1, 1, figsize=figsize)
if fig is None or ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)
for i, coords_of_branch in enumerate(xyzr):
coords_to_plot = coords_of_branch[:, dims]
col = cols[i]
Expand Down

0 comments on commit f37adf6

Please sign in to comment.