Skip to content

Commit

Permalink
Bugfix if branchlen==0 in SWC
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 1, 2023
1 parent c71b29b commit 224dd45
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions jaxley/utils/swc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, List, Optional, Tuple
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -31,7 +31,7 @@ def swc_to_jaxley(
if pathlen == 0.0:
warn("Found a segment with length 0. Clipping it to 1.0")
pathlengths[i] = 1.0
radius_fns = _extract_endpoint_radiuses(
radius_fns = _radius_generating_fns(
sorted_branches, content[:, 5], each_length, parents, types
)

Expand All @@ -45,7 +45,9 @@ def swc_to_jaxley(
return parents, pathlengths, radius_fns, types


def _split_into_branches_and_sort(content, max_branch_len, sort=True):
def _split_into_branches_and_sort(
content: np.ndarray, max_branch_len: float, sort: bool = True
) -> Tuple[np.ndarray, np.ndarray]:
branches, types = _split_into_branches(content)
branches, types = _split_long_branches(branches, types, content, max_branch_len)

Expand All @@ -60,7 +62,9 @@ def _split_into_branches_and_sort(content, max_branch_len, sort=True):
return sorted_branches, sorted_types


def _split_long_branches(branches, types, content, max_branch_len):
def _split_long_branches(
branches, types, content, max_branch_len
) -> Tuple[np.ndarray, np.ndarray]:
pathlengths = _compute_pathlengths(branches, content[:, 2:5])
pathlengths = [np.sum(length_traced) for length_traced in pathlengths]
split_branches = []
Expand Down Expand Up @@ -100,7 +104,7 @@ def _split_branch_equally(branch, num_subbranches):
return branches


def _split_into_branches(content):
def _split_into_branches(content: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
prev_ind = None
prev_type = None
n_branches = 0
Expand Down Expand Up @@ -156,7 +160,10 @@ def _build_parents(all_branches):
return parents


def _extract_endpoint_radiuses(all_branches, radiuses, each_length, parents, types):
def _radius_generating_fns(
all_branches, radiuses, each_length, parents, types
) -> List[Callable]:
"""For all branches in a cell, returns callable that return radius given loc."""
radius_fns = []
for i, branch in enumerate(all_branches):
rads_in_branch = radiuses[np.asarray(branch) - 1]
Expand All @@ -175,14 +182,21 @@ def _extract_endpoint_radiuses(all_branches, radiuses, each_length, parents, typ
return radius_fns


def _radius_generating_fn(radiuses, each_length):
def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable:
each_length[each_length < 1e-8] = 1e-8
summed_len = np.sum(each_length)
cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len
print("cutoffs", cutoffs)
cutoffs[0] -= 1e-8
cutoffs[-1] += 1e-8

def radius(loc):
"""Function which returns the radius via linear interpolation."""
index = np.digitize(loc, cutoffs, right=False)
print("radiuses", radiuses.shape)
print("cutoffs", cutoffs)
print("index", index)
print("loc", loc)
left_rad = radiuses[index - 1]
right_rad = radiuses[index]
left_loc = cutoffs[index - 1]
Expand Down

0 comments on commit 224dd45

Please sign in to comment.