From 68dba2e021dcc6d2f9ae473deb809a0574336e25 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 16:38:55 +0100 Subject: [PATCH 1/6] introduce option to pass xyzr --- jaxley/modules/cell.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index c8832760..e10374a6 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -23,7 +23,19 @@ class Cell(Module): cell_params: Dict = {} cell_states: Dict = {} - def __init__(self, branches: Union[Branch, List[Branch]], parents: List): + def __init__( + self, + branches: Union[Branch, List[Branch]], + parents: List, + xyzr: Optional[Union[np.ndarray, jnp.ndarray]], + ): + """Initialize a cell. + + Args: + branches: + parents: + xyzr: The x, y, and z coordinates and the radius at these coordinates. + """ super().__init__() assert isinstance(branches, Branch) or len(parents) == len( branches @@ -249,7 +261,7 @@ def read_swc( comp = Compartment().initialize() branch = Branch([comp for _ in range(nseg)]).initialize() - cell = Cell([branch for _ in range(nbranches)], parents=parents) + cell = Cell([branch for _ in range(nbranches)], parents=parents, xyzr=None) radiuses = np.flip( np.asarray([radius_fns[b](range_) for b in range(len(parents))]), axis=1 From 5e0dc180ea24a178d60508100cc84d2681a6b059 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 16:57:56 +0100 Subject: [PATCH 2/6] allow reading coords from SWC --- jaxley/modules/cell.py | 13 +++++++++---- jaxley/utils/swc.py | 8 +++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index e10374a6..385e8fe4 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -27,14 +27,16 @@ def __init__( self, branches: Union[Branch, List[Branch]], parents: List, - xyzr: Optional[Union[np.ndarray, jnp.ndarray]], + xyzr: Optional[List[np.ndarray]] = None, ): """Initialize a cell. Args: branches: parents: - xyzr: The x, y, and z coordinates and the radius at these coordinates. + xyzr: For every branch, the x, y, and z coordinates and the radius at the + traced coordinates. Note that this is the full tracing (from SWC), not + the stick representation coordinates. """ super().__init__() assert isinstance(branches, Branch) or len(parents) == len( @@ -45,6 +47,7 @@ def __init__( branch_list = [branches for _ in range(len(parents))] else: branch_list = branches + self.xyzr = xyzr self._append_to_params_and_state(branch_list) for branch in branch_list: @@ -251,7 +254,7 @@ def read_swc( min_radius: Optional[float] = None, ): """Reads SWC file into a `jx.Cell`.""" - parents, pathlengths, radius_fns, _ = swc_to_jaxley( + parents, pathlengths, radius_fns, _, coords_of_branches = swc_to_jaxley( fname, max_branch_len=max_branch_len, sort=True, num_lines=None ) nbranches = len(parents) @@ -261,7 +264,9 @@ def read_swc( comp = Compartment().initialize() branch = Branch([comp for _ in range(nseg)]).initialize() - cell = Cell([branch for _ in range(nbranches)], parents=parents, xyzr=None) + cell = Cell( + [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches + ) radiuses = np.flip( np.asarray([radius_fns[b](range_) for b in range(len(parents))]), axis=1 diff --git a/jaxley/utils/swc.py b/jaxley/utils/swc.py index 2c73592a..98a05965 100644 --- a/jaxley/utils/swc.py +++ b/jaxley/utils/swc.py @@ -41,8 +41,14 @@ def swc_to_jaxley( parents = parents.tolist() pathlengths = [0.1] + pathlengths radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns + sorted_branches = [[0]] + sorted_branches - return parents, pathlengths, radius_fns, types + all_coords_of_branches = [] + for i, branch in enumerate(sorted_branches): + coords_of_branch = content[np.asarray(branch) - 1, 2:5] + all_coords_of_branches.append(coords_of_branch) + + return parents, pathlengths, radius_fns, types, all_coords_of_branches def _split_into_branches_and_sort( From 5c32c7f5850f82b18167d9f6de83cfcd5d16bab9 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 17:15:59 +0100 Subject: [PATCH 3/6] cell is visualized via .vis() method --- jaxley/modules/cell.py | 35 +++++++++++++++++++++++++++++++++++ jaxley/utils/plot_utils.py | 20 +++++--------------- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 385e8fe4..71d078e2 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -15,6 +15,7 @@ compute_levels, ) from jaxley.utils.swc import swc_to_jaxley +from jaxley.utils.plot_utils import plot_morph, plot_swc class Cell(Module): @@ -229,6 +230,40 @@ def update_summed_coupling_conds( ) return summed_conds + def vis( + self, + detail: str = "full", + figsize=(4, 4), + dims=(0, 1), + cols="k", + highlight_branch_inds=[], + ) -> None: + """Visualize the network. + + Args: + detail: Either of [sticks, full]. `sticks` visualizes all branches of every + neuron, but draws branches as straight lines. `full` plots the full + morphology of every neuron, as read from the SWC file. + layers: Allows to plot the network in layers. Should provide the number of + neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input + neurons, 10 hidden layer neurons, and 1 output neuron. + options: Plotting options passed to `NetworkX.draw()`. + """ + if detail == "sticks": + raise NotImplementedError + elif detail == "full": + fig, ax = plot_swc( + self.xyzr, + figsize=figsize, + dims=dims, + cols=cols, + highlight_branch_inds=highlight_branch_inds, + ) + else: + raise ValueError("`detail must be in {sticks, full}.") + + return fig, ax + class CellView(View): """CellView.""" diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index cf2a430e..1adb109d 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -114,8 +114,7 @@ def plot_morph( def plot_swc( - fname, - max_branch_len: float = 100.0, + xyzr, figsize=(4, 4), dims=(0, 1), cols="k", @@ -129,30 +128,21 @@ def plot_swc( cols: The color for all branches except the highlighted ones. highlight_branch_inds: Branch indices that will be highlighted. """ - content = np.loadtxt(fname) - sorted_branches, _ = _split_into_branches_and_sort( - content, max_branch_len=max_branch_len, sort=True - ) - parents = _build_parents(sorted_branches) - if np.sum(np.asarray(parents) == -1) > 1.0: - sorted_branches = [[0]] + sorted_branches - cols = [cols] * len(sorted_branches) + cols = [cols] * len(xyzr) counter_highlight_branches = 0 lines = [] fig, ax = plt.subplots(1, 1, figsize=figsize) - for i, branch in enumerate(sorted_branches): - coords_of_branch = content[np.asarray(branch) - 1, 2:5] - coords_of_branch = coords_of_branch[:, dims] - + for i, coords_of_branch in enumerate(xyzr): + coords_to_plot = coords_of_branch[:, dims] col = cols[i] if i in highlight_branch_inds: col = highlight_cols[counter_highlight_branches % len(highlight_cols)] counter_highlight_branches += 1 (line,) = ax.plot( - coords_of_branch[:, 0], coords_of_branch[:, 1], c=col, label=f"ind {i}" + coords_to_plot[:, 0], coords_to_plot[:, 1], c=col, label=f"ind {i}" ) if i in highlight_branch_inds: lines.append(line) From bda9c411d2476bdfe41f2300e22224758c42b5f8 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 17:34:09 +0100 Subject: [PATCH 4/6] both cell visualization methods are working --- jaxley/modules/cell.py | 18 ++++++++++++++++-- jaxley/utils/plot_utils.py | 1 - 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 71d078e2..6476365e 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -232,11 +232,13 @@ def update_summed_coupling_conds( def vis( self, - detail: str = "full", + detail: str = "sticks", figsize=(4, 4), dims=(0, 1), cols="k", highlight_branch_inds=[], + max_y_multiplier: float = 5.0, + min_y_multiplier: float = 0.5, ) -> None: """Visualize the network. @@ -248,10 +250,22 @@ def vis( neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input neurons, 10 hidden layer neurons, and 1 output neuron. options: Plotting options passed to `NetworkX.draw()`. + dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of + two of them. + cols: The color for all branches except the highlighted ones. + highlight_branch_inds: Branch indices that will be highlighted. """ if detail == "sticks": - raise NotImplementedError + fig, ax = plot_morph( + cell=self, + figsize=figsize, + cols=cols, + highlight_branch_inds=highlight_branch_inds, + max_y_multiplier=max_y_multiplier, + min_y_multiplier=min_y_multiplier, + ) elif detail == "full": + assert self.xyzr is not None fig, ax = plot_swc( self.xyzr, figsize=figsize, diff --git a/jaxley/utils/plot_utils.py b/jaxley/utils/plot_utils.py index 1adb109d..d63e35df 100644 --- a/jaxley/utils/plot_utils.py +++ b/jaxley/utils/plot_utils.py @@ -6,7 +6,6 @@ _compute_num_children, compute_levels, ) -from jaxley.utils.swc import _build_parents, _split_into_branches_and_sort highlight_cols = [ "#1f78b4", From 2bb858a280149e1e378953c744f63c16ae38ba3a Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 17:41:28 +0100 Subject: [PATCH 5/6] isort --- jaxley/modules/cell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 6476365e..8469d1e2 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -14,8 +14,8 @@ compute_coupling_cond, compute_levels, ) -from jaxley.utils.swc import swc_to_jaxley from jaxley.utils.plot_utils import plot_morph, plot_swc +from jaxley.utils.swc import swc_to_jaxley class Cell(Module): From c5deac96cc20eb6abe48df38af97817f43dd3324 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 17:43:20 +0100 Subject: [PATCH 6/6] default to full --- jaxley/modules/cell.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 8469d1e2..e0a3787c 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -232,7 +232,7 @@ def update_summed_coupling_conds( def vis( self, - detail: str = "sticks", + detail: str = "full", figsize=(4, 4), dims=(0, 1), cols="k", @@ -265,7 +265,7 @@ def vis( min_y_multiplier=min_y_multiplier, ) elif detail == "full": - assert self.xyzr is not None + assert self.xyzr is not None, "no coordinates, use `vis(detail='sticks')`." fig, ax = plot_swc( self.xyzr, figsize=figsize,