Skip to content

Commit

Permalink
allow reading coords from SWC
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 6, 2023
1 parent 68dba2e commit 5e0dc18
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
13 changes: 9 additions & 4 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ def __init__(
self,
branches: Union[Branch, List[Branch]],
parents: List,
xyzr: Optional[Union[np.ndarray, jnp.ndarray]],
xyzr: Optional[List[np.ndarray]] = None,
):
"""Initialize a cell.
Args:
branches:
parents:
xyzr: The x, y, and z coordinates and the radius at these coordinates.
xyzr: For every branch, the x, y, and z coordinates and the radius at the
traced coordinates. Note that this is the full tracing (from SWC), not
the stick representation coordinates.
"""
super().__init__()
assert isinstance(branches, Branch) or len(parents) == len(
Expand All @@ -45,6 +47,7 @@ def __init__(
branch_list = [branches for _ in range(len(parents))]
else:
branch_list = branches
self.xyzr = xyzr

self._append_to_params_and_state(branch_list)
for branch in branch_list:
Expand Down Expand Up @@ -251,7 +254,7 @@ def read_swc(
min_radius: Optional[float] = None,
):
"""Reads SWC file into a `jx.Cell`."""
parents, pathlengths, radius_fns, _ = swc_to_jaxley(
parents, pathlengths, radius_fns, _, coords_of_branches = swc_to_jaxley(
fname, max_branch_len=max_branch_len, sort=True, num_lines=None
)
nbranches = len(parents)
Expand All @@ -261,7 +264,9 @@ def read_swc(

comp = Compartment().initialize()
branch = Branch([comp for _ in range(nseg)]).initialize()
cell = Cell([branch for _ in range(nbranches)], parents=parents, xyzr=None)
cell = Cell(
[branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches
)

radiuses = np.flip(
np.asarray([radius_fns[b](range_) for b in range(len(parents))]), axis=1
Expand Down
8 changes: 7 additions & 1 deletion jaxley/utils/swc.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,14 @@ def swc_to_jaxley(
parents = parents.tolist()
pathlengths = [0.1] + pathlengths
radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns
sorted_branches = [[0]] + sorted_branches

return parents, pathlengths, radius_fns, types
all_coords_of_branches = []
for i, branch in enumerate(sorted_branches):
coords_of_branch = content[np.asarray(branch) - 1, 2:5]
all_coords_of_branches.append(coords_of_branch)

return parents, pathlengths, radius_fns, types, all_coords_of_branches


def _split_into_branches_and_sort(
Expand Down

0 comments on commit 5e0dc18

Please sign in to comment.