diff --git a/jaxley/utils/swc.py b/jaxley/utils/swc.py index fbd0a160a..294ce0f87 100644 --- a/jaxley/utils/swc.py +++ b/jaxley/utils/swc.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, List, Optional, Tuple from warnings import warn import numpy as np @@ -31,7 +31,7 @@ def swc_to_jaxley( if pathlen == 0.0: warn("Found a segment with length 0. Clipping it to 1.0") pathlengths[i] = 1.0 - radius_fns = _extract_endpoint_radiuses( + radius_fns = _radius_generating_fns( sorted_branches, content[:, 5], each_length, parents, types ) @@ -45,7 +45,9 @@ def swc_to_jaxley( return parents, pathlengths, radius_fns, types -def _split_into_branches_and_sort(content, max_branch_len, sort=True): +def _split_into_branches_and_sort( + content: np.ndarray, max_branch_len: float, sort: bool = True +) -> Tuple[np.ndarray, np.ndarray]: branches, types = _split_into_branches(content) branches, types = _split_long_branches(branches, types, content, max_branch_len) @@ -60,7 +62,9 @@ def _split_into_branches_and_sort(content, max_branch_len, sort=True): return sorted_branches, sorted_types -def _split_long_branches(branches, types, content, max_branch_len): +def _split_long_branches( + branches, types, content, max_branch_len +) -> Tuple[np.ndarray, np.ndarray]: pathlengths = _compute_pathlengths(branches, content[:, 2:5]) pathlengths = [np.sum(length_traced) for length_traced in pathlengths] split_branches = [] @@ -100,7 +104,7 @@ def _split_branch_equally(branch, num_subbranches): return branches -def _split_into_branches(content): +def _split_into_branches(content: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: prev_ind = None prev_type = None n_branches = 0 @@ -156,7 +160,10 @@ def _build_parents(all_branches): return parents -def _extract_endpoint_radiuses(all_branches, radiuses, each_length, parents, types): +def _radius_generating_fns( + all_branches, radiuses, each_length, parents, types +) -> List[Callable]: + """For all branches in a cell, returns callable that return radius given loc.""" radius_fns = [] for i, branch in enumerate(all_branches): rads_in_branch = radiuses[np.asarray(branch) - 1] @@ -175,13 +182,15 @@ def _extract_endpoint_radiuses(all_branches, radiuses, each_length, parents, typ return radius_fns -def _radius_generating_fn(radiuses, each_length): +def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable: + each_length[each_length < 1e-8] = 1e-8 summed_len = np.sum(each_length) cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len cutoffs[0] -= 1e-8 cutoffs[-1] += 1e-8 def radius(loc): + """Function which returns the radius via linear interpolation.""" index = np.digitize(loc, cutoffs, right=False) left_rad = radiuses[index - 1] right_rad = radiuses[index]