From d55c87bd63bd2856c364bdaf8591375096b7ef59 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 21 Nov 2024 14:03:01 +0100 Subject: [PATCH] fix: fix circular imports --- jaxley/io/swc.py | 292 ++--------------------------------------- jaxley/modules/base.py | 1 - 2 files changed, 14 insertions(+), 279 deletions(-) diff --git a/jaxley/io/swc.py b/jaxley/io/swc.py index 321ea659..ff131e62 100644 --- a/jaxley/io/swc.py +++ b/jaxley/io/swc.py @@ -8,7 +8,14 @@ import jax.numpy as jnp import numpy as np -import jaxley as jx +from jaxley.modules import Branch, Cell, Compartment +from jaxley.utils.cell_utils import ( + _build_parents, + _compute_pathlengths, + _radius_generating_fns, + _split_into_branches_and_sort, + build_radiuses_from_xyzr, +) def swc_to_jaxley( @@ -85,285 +92,14 @@ def swc_to_jaxley( return parents, pathlengths, radius_fns, types, all_coords_of_branches -def _split_into_branches_and_sort( - content: np.ndarray, - max_branch_len: float, - is_single_point_soma: bool, - sort: bool = True, -) -> Tuple[np.ndarray, np.ndarray]: - branches, types = _split_into_branches(content, is_single_point_soma) - branches, types = _split_long_branches( - branches, - types, - content, - max_branch_len, - is_single_point_soma=is_single_point_soma, - ) - - if sort: - first_val = np.asarray([b[0] for b in branches]) - sorting = np.argsort(first_val, kind="mergesort") - sorted_branches = [branches[s] for s in sorting] - sorted_types = [types[s] for s in sorting] - else: - sorted_branches = branches - sorted_types = types - return sorted_branches, sorted_types - - -def _split_long_branches( - branches: np.ndarray, - types: np.ndarray, - content: np.ndarray, - max_branch_len: float, - is_single_point_soma: bool, -) -> Tuple[np.ndarray, np.ndarray]: - pathlengths = _compute_pathlengths( - branches, content[:, 1:6], is_single_point_soma=is_single_point_soma - ) - pathlengths = [np.sum(length_traced) for length_traced in pathlengths] - split_branches = [] - split_types = [] - for branch, type, length in zip(branches, types, pathlengths): - num_subbranches = 1 - split_branch = [branch] - while length > max_branch_len: - num_subbranches += 1 - split_branch = _split_branch_equally(branch, num_subbranches) - lengths_of_subbranches = _compute_pathlengths( - split_branch, - coords=content[:, 1:6], - is_single_point_soma=is_single_point_soma, - ) - lengths_of_subbranches = [ - np.sum(length_traced) for length_traced in lengths_of_subbranches - ] - length = max(lengths_of_subbranches) - if num_subbranches > 10: - warn( - """`num_subbranches > 10`, stopping to split. Most likely your - SWC reconstruction is not dense and some neighbouring traced - points are farther than `max_branch_len` apart.""" - ) - break - split_branches += split_branch - split_types += [type] * num_subbranches - - return split_branches, split_types - - -def _split_branch_equally(branch: np.ndarray, num_subbranches: int) -> List[np.ndarray]: - num_points_each = len(branch) // num_subbranches - branches = [branch[:num_points_each]] - for i in range(1, num_subbranches - 1): - branches.append(branch[i * num_points_each - 1 : (i + 1) * num_points_each]) - branches.append(branch[(num_subbranches - 1) * num_points_each - 1 :]) - return branches - - -def _split_into_branches( - content: np.ndarray, is_single_point_soma: bool -) -> Tuple[np.ndarray, np.ndarray]: - prev_ind = None - prev_type = None - n_branches = 0 - - # Branch inds will contain the row identifier at which a branch point occurs - # (i.e. the row of the parent of two branches). - branch_inds = [] - for c in content: - current_ind = c[0] - current_parent = c[-1] - current_type = c[1] - if current_parent != prev_ind or current_type != prev_type: - branch_inds.append(int(current_parent)) - n_branches += 1 - prev_ind = current_ind - prev_type = current_type - - all_branches = [] - current_branch = [] - all_types = [] - - # Loop over every line in the SWC file. - for c in content: - current_ind = c[0] # First col is row_identifier - current_parent = c[-1] # Last col is parent in SWC specification. - if current_parent == -1: - all_types.append(c[1]) - else: - current_type = c[1] - - if current_parent == -1 and is_single_point_soma and current_ind == 1: - all_branches.append([int(current_ind)]) - all_types.append(int(current_type)) - - # Either append the current point to the branch, or add the branch to - # `all_branches`. - if current_parent in branch_inds[1:]: - if len(current_branch) > 1: - all_branches.append(current_branch) - all_types.append(current_type) - current_branch = [int(current_parent), int(current_ind)] - else: - current_branch.append(int(current_ind)) - - # Append the final branch (intermediate branches are already appended five lines - # above.) - all_branches.append(current_branch) - return all_branches, all_types - - -def _build_parents(all_branches: List[np.ndarray]) -> List[int]: - parents = [None] * len(all_branches) - all_last_inds = [b[-1] for b in all_branches] - for i, branch in enumerate(all_branches): - parent_ind = branch[0] - ind = np.where(np.asarray(all_last_inds) == parent_ind)[0] - if len(ind) > 0 and ind != i: - parents[i] = ind[0] - else: - assert ( - parent_ind == 1 - ), """Trying to connect a segment to the beginning of - another segment. This is not allowed. Please create an issue on github.""" - parents[i] = -1 - - return parents - - -def _radius_generating_fns( - all_branches: np.ndarray, - radiuses: np.ndarray, - each_length: np.ndarray, - parents: np.ndarray, - types: np.ndarray, -) -> List[Callable]: - """For all branches in a cell, returns callable that return radius given loc.""" - radius_fns = [] - for i, branch in enumerate(all_branches): - rads_in_branch = radiuses[np.asarray(branch) - 1] - if parents[i] > -1 and types[i] != types[parents[i]]: - # We do not want to linearly interpolate between the radius of the previous - # branch if a new type of neurite is found (e.g. switch from soma to - # apical). From looking at the SWC from n140.swc I believe that this is - # also what NEURON does. - rads_in_branch[0] = rads_in_branch[1] - radius_fn = _radius_generating_fn( - radiuses=rads_in_branch, each_length=each_length[i] - ) - # Beause SWC starts counting at 1, but numpy counts from 0. - # ind_of_branch_endpoint = np.asarray(b) - 1 - radius_fns.append(radius_fn) - return radius_fns - - -def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable: - # Avoid division by 0 with the `summed_len` below. - each_length[each_length < 1e-8] = 1e-8 - summed_len = np.sum(each_length) - cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len - cutoffs[0] -= 1e-8 - cutoffs[-1] += 1e-8 - - # We have to linearly interpolate radiuses, therefore we need at least two radiuses. - # However, jaxley allows somata which consist of a single traced point (i.e. - # just one radius). Therefore, we just `tile` in order to generate an artificial - # endpoint and startpoint radius of the soma. - if len(radiuses) == 1: - radiuses = np.tile(radiuses, 2) - - def radius(loc: float) -> float: - """Function which returns the radius via linear interpolation.""" - index = np.digitize(loc, cutoffs, right=False) - left_rad = radiuses[index - 1] - right_rad = radiuses[index] - left_loc = cutoffs[index - 1] - right_loc = cutoffs[index] - loc_within_bin = (loc - left_loc) / (right_loc - left_loc) - return left_rad + (right_rad - left_rad) * loc_within_bin - - return radius - - -def _compute_pathlengths( - all_branches: np.ndarray, coords: np.ndarray, is_single_point_soma: bool -) -> List[np.ndarray]: - """ - Args: - coords: Has shape (num_traced_points, 5), where `5` is (type, x, y, z, radius). - """ - branch_pathlengths = [] - for b in all_branches: - coords_in_branch = coords[np.asarray(b) - 1] - if len(coords_in_branch) > 1: - # If the branch starts at a different neurite (e.g. the soma) then NEURON - # ignores the distance from that initial point. To reproduce, use the - # following SWC dummy file and read it in NEURON (and Jaxley): - # 1 1 0.00 0.0 0.0 6.0 -1 - # 2 2 9.00 0.0 0.0 0.5 1 - # 3 2 10.0 0.0 0.0 0.3 2 - types = coords_in_branch[:, 0] - if int(types[0]) == 1 and int(types[1]) != 1 and is_single_point_soma: - coords_in_branch[0] = coords_in_branch[1] - - # Compute distances between all traced points in a branch. - point_diffs = np.diff(coords_in_branch, axis=0) - dists = np.sqrt( - point_diffs[:, 1] ** 2 + point_diffs[:, 2] ** 2 + point_diffs[:, 3] ** 2 - ) - else: - # Jaxley uses length and radius for every compartment and assumes the - # surface area to be 2*pi*r*length. For branches consisting of a single - # traced point we assume for them to have area 4*pi*r*r. Therefore, we have - # to set length = 2*r. - radius = coords_in_branch[0, 4] # txyzr -> 4 is radius. - dists = np.asarray([2 * radius]) - branch_pathlengths.append(dists) - return branch_pathlengths - - -def build_radiuses_from_xyzr( - radius_fns: List[Callable], - branch_indices: List[int], - min_radius: Optional[float], - nseg: int, -) -> jnp.ndarray: - """Return the radiuses of branches given SWC file xyzr. - - Returns an array of shape `(num_branches, nseg)`. - - Args: - radius_fns: Functions which, given compartment locations return the radius. - branch_indices: The indices of the branches for which to return the radiuses. - min_radius: If passed, the radiuses are clipped to be at least as large. - nseg: The number of compartments that every branch is discretized into. - """ - # Compartment locations are at the center of the internal nodes. - non_split = 1 / nseg - range_ = np.linspace(non_split / 2, 1 - non_split / 2, nseg) - - # Build radiuses. - radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices]) - radiuses_each = radiuses.ravel(order="C") - if min_radius is None: - assert np.all( - radiuses_each > 0.0 - ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`." - else: - radiuses_each[radiuses_each < min_radius] = min_radius - - return radiuses_each - - def read_swc( fname: str, nseg: int, max_branch_len: float = 300.0, min_radius: Optional[float] = None, assign_groups: bool = False, -) -> jx.Cell: - """Reads SWC file into a `jx.Cell`. +) -> Cell: + """Reads SWC file into a `Cell`. Jaxley assumes cylindrical compartments and therefore defines length and radius for every compartment. The surface area is then 2*pi*r*length. For branches @@ -382,16 +118,16 @@ def read_swc( http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html Returns: - A `jx.Cell` object. + A `Cell` object. """ parents, pathlengths, radius_fns, types, coords_of_branches = swc_to_jaxley( fname, max_branch_len=max_branch_len, sort=True, num_lines=None ) nbranches = len(parents) - comp = jx.Compartment() - branch = jx.Branch([comp for _ in range(nseg)]) - cell = jx.Cell( + comp = Compartment() + branch = Branch([comp for _ in range(nseg)]) + cell = Cell( [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches ) # Also save the radius generating functions in case users post-hoc modify the number diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index f50a0d1a..b40f3951 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -17,7 +17,6 @@ from matplotlib.axes import Axes from jaxley.channels import Channel -from jaxley.io.swc import build_radiuses_from_xyzr from jaxley.solver_voltage import ( step_voltage_explicit, step_voltage_implicit_with_jax_spsolve,