From 5e0dc180ea24a178d60508100cc84d2681a6b059 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Wed, 6 Dec 2023 16:57:56 +0100 Subject: [PATCH] allow reading coords from SWC --- jaxley/modules/cell.py | 13 +++++++++---- jaxley/utils/swc.py | 8 +++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index e10374a6..385e8fe4 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -27,14 +27,16 @@ def __init__( self, branches: Union[Branch, List[Branch]], parents: List, - xyzr: Optional[Union[np.ndarray, jnp.ndarray]], + xyzr: Optional[List[np.ndarray]] = None, ): """Initialize a cell. Args: branches: parents: - xyzr: The x, y, and z coordinates and the radius at these coordinates. + xyzr: For every branch, the x, y, and z coordinates and the radius at the + traced coordinates. Note that this is the full tracing (from SWC), not + the stick representation coordinates. """ super().__init__() assert isinstance(branches, Branch) or len(parents) == len( @@ -45,6 +47,7 @@ def __init__( branch_list = [branches for _ in range(len(parents))] else: branch_list = branches + self.xyzr = xyzr self._append_to_params_and_state(branch_list) for branch in branch_list: @@ -251,7 +254,7 @@ def read_swc( min_radius: Optional[float] = None, ): """Reads SWC file into a `jx.Cell`.""" - parents, pathlengths, radius_fns, _ = swc_to_jaxley( + parents, pathlengths, radius_fns, _, coords_of_branches = swc_to_jaxley( fname, max_branch_len=max_branch_len, sort=True, num_lines=None ) nbranches = len(parents) @@ -261,7 +264,9 @@ def read_swc( comp = Compartment().initialize() branch = Branch([comp for _ in range(nseg)]).initialize() - cell = Cell([branch for _ in range(nbranches)], parents=parents, xyzr=None) + cell = Cell( + [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches + ) radiuses = np.flip( np.asarray([radius_fns[b](range_) for b in range(len(parents))]), axis=1 diff --git a/jaxley/utils/swc.py b/jaxley/utils/swc.py index 2c73592a..98a05965 100644 --- a/jaxley/utils/swc.py +++ b/jaxley/utils/swc.py @@ -41,8 +41,14 @@ def swc_to_jaxley( parents = parents.tolist() pathlengths = [0.1] + pathlengths radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns + sorted_branches = [[0]] + sorted_branches - return parents, pathlengths, radius_fns, types + all_coords_of_branches = [] + for i, branch in enumerate(sorted_branches): + coords_of_branch = content[np.asarray(branch) - 1, 2:5] + all_coords_of_branches.append(coords_of_branch) + + return parents, pathlengths, radius_fns, types, all_coords_of_branches def _split_into_branches_and_sort(