diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 71d078e2..6476365e 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -232,11 +232,13 @@ def update_summed_coupling_conds( def vis( self, - detail: str = "full", + detail: str = "sticks", figsize=(4, 4), dims=(0, 1), cols="k", highlight_branch_inds=[], + max_y_multiplier: float = 5.0, + min_y_multiplier: float = 0.5, ) -> None: """Visualize the network. @@ -248,10 +250,22 @@ def vis( neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input neurons, 10 hidden layer neurons, and 1 output neuron. options: Plotting options passed to `NetworkX.draw()`. + dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of + two of them. + cols: The color for all branches except the highlighted ones. + highlight_branch_inds: Branch indices that will be highlighted. """ if detail == "sticks": - raise NotImplementedError + fig, ax = plot_morph( + cell=self, + figsize=figsize, + cols=cols, + highlight_branch_inds=highlight_branch_inds, + max_y_multiplier=max_y_multiplier, + min_y_multiplier=min_y_multiplier, + ) elif detail == "full": + assert self.xyzr is not None fig, ax = plot_swc( self.xyzr, figsize=figsize, diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index 1adb109d..d63e35df 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -6,7 +6,6 @@ _compute_num_children, compute_levels, ) -from jaxley.utils.swc import _build_parents, _split_into_branches_and_sort highlight_cols = [ "#1f78b4",