diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index e0a3787c..18681bd1 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -237,6 +237,8 @@ def vis( dims=(0, 1), cols="k", highlight_branch_inds=[], + fig=None, + ax=None, max_y_multiplier: float = 5.0, min_y_multiplier: float = 0.5, ) -> None: @@ -263,6 +265,8 @@ def vis( highlight_branch_inds=highlight_branch_inds, max_y_multiplier=max_y_multiplier, min_y_multiplier=min_y_multiplier, + fig=fig, + ax=ax, ) elif detail == "full": assert self.xyzr is not None, "no coordinates, use `vis(detail='sticks')`." @@ -272,6 +276,8 @@ def vis( dims=dims, cols=cols, highlight_branch_inds=highlight_branch_inds, + fig=fig, + ax=ax, ) else: raise ValueError("`detail must be in {sticks, full}.") diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 0a917e3e..7d5190b8 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -1,7 +1,7 @@ import itertools from copy import deepcopy from typing import Callable, Dict, List, Optional, Tuple, Union - +import matplotlib.pyplot as plt import jax.numpy as jnp import networkx as nx import numpy as np @@ -303,30 +303,46 @@ def vis( self, detail: str = "point", layers: Optional[List] = None, + figsize=(4, 4), + dims=(0, 1), + cols: Union[str, List[str]] = "k", + highlight=[], + fig=None, + ax=None, **options, ) -> None: """Visualize the network. Args: - detail: Currently, only `point` is supported. In the future, either of - [point, sticks, full] will be supported. `point` visualizes every neuron - as a point. `sticks` visualizes all branches of every neuron, but draws - branches as straight lines. `full` plots the full morphology of every + detail: Either of [point, full]. `point` visualizes every neuron + as a point. `full` plots the full morphology of every neuron, as read from the SWC file. layers: Allows to plot the network in layers. Should provide the number of neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input neurons, 10 hidden layer neurons, and 1 output neuron. options: Plotting options passed to `NetworkX.draw()`. + cols: One color in total or one color per cell. """ - assert detail == "point", "Only `point` is implemented." - - graph = self._build_graph(layers, **options) - - if layers is not None: - pos = nx.multipartite_layout(graph, subset_key="layer") - nx.draw(graph, pos, with_labels=True) + if detail == "point": + graph = self._build_graph(layers, **options) + + if layers is not None: + pos = nx.multipartite_layout(graph, subset_key="layer") + nx.draw(graph, pos, with_labels=True) + else: + nx.draw(graph, with_labels=True) + elif detail == "full": + if fig is None or ax is None: + fig, ax = plt.subplots(1, 1, figsize=figsize) + + if isinstance(cols, str): + cols = [cols] * len(self.cells) + + for cell, col in zip(self.cells, cols): + fig, ax = cell.vis(detail="full", dims=dims, cols=col, fig=fig, ax=ax) + return fig, ax else: - nx.draw(graph, with_labels=True) + raise ValueError("detail must be in {point, full}") def _build_graph(self, layers: Optional[List] = None, **options): graph = nx.DiGraph() diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index d63e35df..80ca052c 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -30,6 +30,8 @@ def plot_morph( highlight_branch_inds=[], max_y_multiplier: float = 5.0, min_y_multiplier: float = 0.5, + fig=None, + ax=None, ): """Plot the stick representation of a morphology. @@ -118,6 +120,8 @@ def plot_swc( dims=(0, 1), cols="k", highlight_branch_inds=[], + fig=None, + ax=None, ): """Plot morphology given an SWC file. @@ -132,7 +136,8 @@ def plot_swc( counter_highlight_branches = 0 lines = [] - fig, ax = plt.subplots(1, 1, figsize=figsize) + if fig is None or ax is None: + fig, ax = plt.subplots(1, 1, figsize=figsize) for i, coords_of_branch in enumerate(xyzr): coords_to_plot = coords_of_branch[:, dims] col = cols[i]