diff --git a/jaxley/__init__.py b/jaxley/__init__.py index ce2c261a..abc461f0 100644 --- a/jaxley/__init__.py +++ b/jaxley/__init__.py @@ -8,6 +8,7 @@ sparse_connect, ) from jaxley.integrate import integrate +from jaxley.io.swc import read_swc from jaxley.modules import * from jaxley.optimize import ParamTransform from jaxley.stimulus import datapoint_to_step_currents, step_current diff --git a/jaxley/io/swc.py b/jaxley/io/swc.py new file mode 100644 index 00000000..ff131e62 --- /dev/null +++ b/jaxley/io/swc.py @@ -0,0 +1,168 @@ +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + +from copy import copy +from typing import Callable, List, Optional, Tuple +from warnings import warn + +import jax.numpy as jnp +import numpy as np + +from jaxley.modules import Branch, Cell, Compartment +from jaxley.utils.cell_utils import ( + _build_parents, + _compute_pathlengths, + _radius_generating_fns, + _split_into_branches_and_sort, + build_radiuses_from_xyzr, +) + + +def swc_to_jaxley( + fname: str, + max_branch_len: float = 100.0, + sort: bool = True, + num_lines: Optional[int] = None, +) -> Tuple[List[int], List[float], List[Callable], List[float], List[np.ndarray]]: + """Read an SWC file and bring morphology into `jaxley` compatible formats. + + Args: + fname: Path to swc file. + max_branch_len: Maximal length of one branch. If a branch exceeds this length, + it is split into equal parts such that each subbranch is below + `max_branch_len`. + num_lines: Number of lines of the SWC file to read. + """ + content = np.loadtxt(fname)[:num_lines] + types = content[:, 1] + is_single_point_soma = types[0] == 1 and types[1] != 1 + + if is_single_point_soma: + # Warn here, but the conversion of the length happens in `_compute_pathlengths`. + warn( + "Found a soma which consists of a single traced point. `Jaxley` " + "interprets this soma as a spherical compartment with radius " + "specified in the SWC file, i.e. with surface area 4*pi*r*r." + ) + sorted_branches, types = _split_into_branches_and_sort( + content, + max_branch_len=max_branch_len, + is_single_point_soma=is_single_point_soma, + sort=sort, + ) + + parents = _build_parents(sorted_branches) + each_length = _compute_pathlengths( + sorted_branches, content[:, 1:6], is_single_point_soma=is_single_point_soma + ) + pathlengths = [np.sum(length_traced) for length_traced in each_length] + for i, pathlen in enumerate(pathlengths): + if pathlen == 0.0: + warn("Found a segment with length 0. Clipping it to 1.0") + pathlengths[i] = 1.0 + radius_fns = _radius_generating_fns( + sorted_branches, content[:, 5], each_length, parents, types + ) + + if np.sum(np.asarray(parents) == -1) > 1.0: + parents = np.asarray([-1] + parents) + parents[1:] += 1 + parents = parents.tolist() + pathlengths = [0.1] + pathlengths + radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns + sorted_branches = [[0]] + sorted_branches + + # Type of padded section is assumed to be of `custom` type: + # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html + types = [5.0] + types + + all_coords_of_branches = [] + for i, branch in enumerate(sorted_branches): + # Remove 1 because `content` is an array that is indexed from 0. + branch = np.asarray(branch) - 1 + + # Deal with additional branch that might have been added above in the lines + # `if np.sum(np.asarray(parents) == -1) > 1.0:` + branch[branch < 0] = 0 + + # Get traced coordinates of the branch. + coords_of_branch = content[branch, 2:6] + all_coords_of_branches.append(coords_of_branch) + + return parents, pathlengths, radius_fns, types, all_coords_of_branches + + +def read_swc( + fname: str, + nseg: int, + max_branch_len: float = 300.0, + min_radius: Optional[float] = None, + assign_groups: bool = False, +) -> Cell: + """Reads SWC file into a `Cell`. + + Jaxley assumes cylindrical compartments and therefore defines length and radius + for every compartment. The surface area is then 2*pi*r*length. For branches + consisting of a single traced point we assume for them to have area 4*pi*r*r. + Therefore, in these cases, we set lenght=2*r. + + Args: + fname: Path to the swc file. + nseg: The number of compartments per branch. + max_branch_len: If a branch is longer than this value it is split into two + branches. + min_radius: If the radius of a reconstruction is below this value it is clipped. + assign_groups: If True, then the identity of reconstructed points in the SWC + file will be used to generate groups `undefined`, `soma`, `axon`, `basal`, + `apical`, `custom`. See here: + http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html + + Returns: + A `Cell` object. + """ + parents, pathlengths, radius_fns, types, coords_of_branches = swc_to_jaxley( + fname, max_branch_len=max_branch_len, sort=True, num_lines=None + ) + nbranches = len(parents) + + comp = Compartment() + branch = Branch([comp for _ in range(nseg)]) + cell = Cell( + [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches + ) + # Also save the radius generating functions in case users post-hoc modify the number + # of compartments with `.set_ncomp()`. + cell._radius_generating_fns = radius_fns + + lengths_each = np.repeat(pathlengths, nseg) / nseg + cell.set("length", lengths_each) + + radiuses_each = build_radiuses_from_xyzr( + radius_fns, + range(len(parents)), + min_radius, + nseg, + ) + cell.set("radius", radiuses_each) + + # Description of SWC file format: + # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html + ind_name_lookup = { + 0: "undefined", + 1: "soma", + 2: "axon", + 3: "basal", + 4: "apical", + 5: "custom", + } + types = np.asarray(types).astype(int) + if assign_groups: + for type_ind in np.unique(types): + if type_ind < 5.5: + name = ind_name_lookup[type_ind] + else: + name = f"custom{type_ind}" + indices = np.where(types == type_ind)[0].tolist() + if len(indices) > 0: + cell.branch(indices).add_to_group(name) + return cell diff --git a/jaxley/modules/__init__.py b/jaxley/modules/__init__.py index 584ca3e5..1bcca2b6 100644 --- a/jaxley/modules/__init__.py +++ b/jaxley/modules/__init__.py @@ -3,6 +3,6 @@ from jaxley.modules.base import Module from jaxley.modules.branch import Branch -from jaxley.modules.cell import Cell, read_swc +from jaxley.modules.cell import Cell from jaxley.modules.compartment import Compartment from jaxley.modules.network import Network diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 1409a6eb..8441c6e6 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -26,6 +26,7 @@ from jaxley.utils.cell_utils import ( _compute_index_of_child, _compute_num_children, + build_radiuses_from_xyzr, compute_axial_conductances, compute_levels, convert_point_process_to_distributed, @@ -39,7 +40,6 @@ from jaxley.utils.misc_utils import cumsum_leading_zero, is_str_all from jaxley.utils.plot_utils import plot_comps, plot_graph, plot_morph from jaxley.utils.solver_utils import convert_to_csc -from jaxley.utils.swc import build_radiuses_from_xyzr def only_allow_module(func): diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index cc307f22..7db12cdd 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -8,8 +8,7 @@ import pandas as pd from jaxley.modules.base import Module -from jaxley.modules.branch import Branch, Compartment -from jaxley.synapses import Synapse +from jaxley.modules.branch import Branch from jaxley.utils.cell_utils import ( build_branchpoint_group_inds, compute_children_and_parents, @@ -25,7 +24,6 @@ comp_edges_to_indices, remap_index_to_masked, ) -from jaxley.utils.swc import build_radiuses_from_xyzr, swc_to_jaxley class Cell(Module): @@ -271,79 +269,3 @@ def _init_morph_jax_spsolve(self): self._data_inds = data_inds self._indices_jax_spsolve = indices self._indptr_jax_spsolve = indptr - - -def read_swc( - fname: str, - nseg: int, - max_branch_len: float = 300.0, - min_radius: Optional[float] = None, - assign_groups: bool = False, -) -> Cell: - """Reads SWC file into a `jx.Cell`. - - Jaxley assumes cylindrical compartments and therefore defines length and radius - for every compartment. The surface area is then 2*pi*r*length. For branches - consisting of a single traced point we assume for them to have area 4*pi*r*r. - Therefore, in these cases, we set lenght=2*r. - - Args: - fname: Path to the swc file. - nseg: The number of compartments per branch. - max_branch_len: If a branch is longer than this value it is split into two - branches. - min_radius: If the radius of a reconstruction is below this value it is clipped. - assign_groups: If True, then the identity of reconstructed points in the SWC - file will be used to generate groups `undefined`, `soma`, `axon`, `basal`, - `apical`, `custom`. See here: - http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html - - Returns: - A `jx.Cell` object. - """ - parents, pathlengths, radius_fns, types, coords_of_branches = swc_to_jaxley( - fname, max_branch_len=max_branch_len, sort=True, num_lines=None - ) - nbranches = len(parents) - - comp = Compartment() - branch = Branch([comp for _ in range(nseg)]) - cell = Cell( - [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches - ) - # Also save the radius generating functions in case users post-hoc modify the number - # of compartments with `.set_ncomp()`. - cell._radius_generating_fns = radius_fns - - lengths_each = np.repeat(pathlengths, nseg) / nseg - cell.set("length", lengths_each) - - radiuses_each = build_radiuses_from_xyzr( - radius_fns, - range(len(parents)), - min_radius, - nseg, - ) - cell.set("radius", radiuses_each) - - # Description of SWC file format: - # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html - ind_name_lookup = { - 0: "undefined", - 1: "soma", - 2: "axon", - 3: "basal", - 4: "apical", - 5: "custom", - } - types = np.asarray(types).astype(int) - if assign_groups: - for type_ind in np.unique(types): - if type_ind < 5.5: - name = ind_name_lookup[type_ind] - else: - name = f"custom{type_ind}" - indices = np.where(types == type_ind)[0].tolist() - if len(indices) > 0: - cell.branch(indices).add_to_group(name) - return cell diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index 961b6533..d47ab6ea 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -2,7 +2,8 @@ # licensed under the Apache License Version 2.0, see from math import pi -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union +from warnings import warn import jax.numpy as jnp import numpy as np @@ -12,6 +13,277 @@ from jaxley.utils.misc_utils import cumsum_leading_zero +def _split_into_branches_and_sort( + content: np.ndarray, + max_branch_len: float, + is_single_point_soma: bool, + sort: bool = True, +) -> Tuple[np.ndarray, np.ndarray]: + branches, types = _split_into_branches(content, is_single_point_soma) + branches, types = _split_long_branches( + branches, + types, + content, + max_branch_len, + is_single_point_soma=is_single_point_soma, + ) + + if sort: + first_val = np.asarray([b[0] for b in branches]) + sorting = np.argsort(first_val, kind="mergesort") + sorted_branches = [branches[s] for s in sorting] + sorted_types = [types[s] for s in sorting] + else: + sorted_branches = branches + sorted_types = types + return sorted_branches, sorted_types + + +def _split_long_branches( + branches: np.ndarray, + types: np.ndarray, + content: np.ndarray, + max_branch_len: float, + is_single_point_soma: bool, +) -> Tuple[np.ndarray, np.ndarray]: + pathlengths = _compute_pathlengths( + branches, content[:, 1:6], is_single_point_soma=is_single_point_soma + ) + pathlengths = [np.sum(length_traced) for length_traced in pathlengths] + split_branches = [] + split_types = [] + for branch, type, length in zip(branches, types, pathlengths): + num_subbranches = 1 + split_branch = [branch] + while length > max_branch_len: + num_subbranches += 1 + split_branch = _split_branch_equally(branch, num_subbranches) + lengths_of_subbranches = _compute_pathlengths( + split_branch, + coords=content[:, 1:6], + is_single_point_soma=is_single_point_soma, + ) + lengths_of_subbranches = [ + np.sum(length_traced) for length_traced in lengths_of_subbranches + ] + length = max(lengths_of_subbranches) + if num_subbranches > 10: + warn( + """`num_subbranches > 10`, stopping to split. Most likely your + SWC reconstruction is not dense and some neighbouring traced + points are farther than `max_branch_len` apart.""" + ) + break + split_branches += split_branch + split_types += [type] * num_subbranches + + return split_branches, split_types + + +def _split_branch_equally(branch: np.ndarray, num_subbranches: int) -> List[np.ndarray]: + num_points_each = len(branch) // num_subbranches + branches = [branch[:num_points_each]] + for i in range(1, num_subbranches - 1): + branches.append(branch[i * num_points_each - 1 : (i + 1) * num_points_each]) + branches.append(branch[(num_subbranches - 1) * num_points_each - 1 :]) + return branches + + +def _split_into_branches( + content: np.ndarray, is_single_point_soma: bool +) -> Tuple[np.ndarray, np.ndarray]: + prev_ind = None + prev_type = None + n_branches = 0 + + # Branch inds will contain the row identifier at which a branch point occurs + # (i.e. the row of the parent of two branches). + branch_inds = [] + for c in content: + current_ind = c[0] + current_parent = c[-1] + current_type = c[1] + if current_parent != prev_ind or current_type != prev_type: + branch_inds.append(int(current_parent)) + n_branches += 1 + prev_ind = current_ind + prev_type = current_type + + all_branches = [] + current_branch = [] + all_types = [] + + # Loop over every line in the SWC file. + for c in content: + current_ind = c[0] # First col is row_identifier + current_parent = c[-1] # Last col is parent in SWC specification. + if current_parent == -1: + all_types.append(c[1]) + else: + current_type = c[1] + + if current_parent == -1 and is_single_point_soma and current_ind == 1: + all_branches.append([int(current_ind)]) + all_types.append(int(current_type)) + + # Either append the current point to the branch, or add the branch to + # `all_branches`. + if current_parent in branch_inds[1:]: + if len(current_branch) > 1: + all_branches.append(current_branch) + all_types.append(current_type) + current_branch = [int(current_parent), int(current_ind)] + else: + current_branch.append(int(current_ind)) + + # Append the final branch (intermediate branches are already appended five lines + # above.) + all_branches.append(current_branch) + return all_branches, all_types + + +def _build_parents(all_branches: List[np.ndarray]) -> List[int]: + parents = [None] * len(all_branches) + all_last_inds = [b[-1] for b in all_branches] + for i, branch in enumerate(all_branches): + parent_ind = branch[0] + ind = np.where(np.asarray(all_last_inds) == parent_ind)[0] + if len(ind) > 0 and ind != i: + parents[i] = ind[0] + else: + assert ( + parent_ind == 1 + ), """Trying to connect a segment to the beginning of + another segment. This is not allowed. Please create an issue on github.""" + parents[i] = -1 + + return parents + + +def _radius_generating_fns( + all_branches: np.ndarray, + radiuses: np.ndarray, + each_length: np.ndarray, + parents: np.ndarray, + types: np.ndarray, +) -> List[Callable]: + """For all branches in a cell, returns callable that return radius given loc.""" + radius_fns = [] + for i, branch in enumerate(all_branches): + rads_in_branch = radiuses[np.asarray(branch) - 1] + if parents[i] > -1 and types[i] != types[parents[i]]: + # We do not want to linearly interpolate between the radius of the previous + # branch if a new type of neurite is found (e.g. switch from soma to + # apical). From looking at the SWC from n140.swc I believe that this is + # also what NEURON does. + rads_in_branch[0] = rads_in_branch[1] + radius_fn = _radius_generating_fn( + radiuses=rads_in_branch, each_length=each_length[i] + ) + # Beause SWC starts counting at 1, but numpy counts from 0. + # ind_of_branch_endpoint = np.asarray(b) - 1 + radius_fns.append(radius_fn) + return radius_fns + + +def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable: + # Avoid division by 0 with the `summed_len` below. + each_length[each_length < 1e-8] = 1e-8 + summed_len = np.sum(each_length) + cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len + cutoffs[0] -= 1e-8 + cutoffs[-1] += 1e-8 + + # We have to linearly interpolate radiuses, therefore we need at least two radiuses. + # However, jaxley allows somata which consist of a single traced point (i.e. + # just one radius). Therefore, we just `tile` in order to generate an artificial + # endpoint and startpoint radius of the soma. + if len(radiuses) == 1: + radiuses = np.tile(radiuses, 2) + + def radius(loc: float) -> float: + """Function which returns the radius via linear interpolation.""" + index = np.digitize(loc, cutoffs, right=False) + left_rad = radiuses[index - 1] + right_rad = radiuses[index] + left_loc = cutoffs[index - 1] + right_loc = cutoffs[index] + loc_within_bin = (loc - left_loc) / (right_loc - left_loc) + return left_rad + (right_rad - left_rad) * loc_within_bin + + return radius + + +def _compute_pathlengths( + all_branches: np.ndarray, coords: np.ndarray, is_single_point_soma: bool +) -> List[np.ndarray]: + """ + Args: + coords: Has shape (num_traced_points, 5), where `5` is (type, x, y, z, radius). + """ + branch_pathlengths = [] + for b in all_branches: + coords_in_branch = coords[np.asarray(b) - 1] + if len(coords_in_branch) > 1: + # If the branch starts at a different neurite (e.g. the soma) then NEURON + # ignores the distance from that initial point. To reproduce, use the + # following SWC dummy file and read it in NEURON (and Jaxley): + # 1 1 0.00 0.0 0.0 6.0 -1 + # 2 2 9.00 0.0 0.0 0.5 1 + # 3 2 10.0 0.0 0.0 0.3 2 + types = coords_in_branch[:, 0] + if int(types[0]) == 1 and int(types[1]) != 1 and is_single_point_soma: + coords_in_branch[0] = coords_in_branch[1] + + # Compute distances between all traced points in a branch. + point_diffs = np.diff(coords_in_branch, axis=0) + dists = np.sqrt( + point_diffs[:, 1] ** 2 + point_diffs[:, 2] ** 2 + point_diffs[:, 3] ** 2 + ) + else: + # Jaxley uses length and radius for every compartment and assumes the + # surface area to be 2*pi*r*length. For branches consisting of a single + # traced point we assume for them to have area 4*pi*r*r. Therefore, we have + # to set length = 2*r. + radius = coords_in_branch[0, 4] # txyzr -> 4 is radius. + dists = np.asarray([2 * radius]) + branch_pathlengths.append(dists) + return branch_pathlengths + + +def build_radiuses_from_xyzr( + radius_fns: List[Callable], + branch_indices: List[int], + min_radius: Optional[float], + nseg: int, +) -> jnp.ndarray: + """Return the radiuses of branches given SWC file xyzr. + + Returns an array of shape `(num_branches, nseg)`. + + Args: + radius_fns: Functions which, given compartment locations return the radius. + branch_indices: The indices of the branches for which to return the radiuses. + min_radius: If passed, the radiuses are clipped to be at least as large. + nseg: The number of compartments that every branch is discretized into. + """ + # Compartment locations are at the center of the internal nodes. + non_split = 1 / nseg + range_ = np.linspace(non_split / 2, 1 - non_split / 2, nseg) + + # Build radiuses. + radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices]) + radiuses_each = radiuses.ravel(order="C") + if min_radius is None: + assert np.all( + radiuses_each > 0.0 + ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`." + else: + radiuses_each[radiuses_each < min_radius] = min_radius + + return radiuses_each + + def equal_segments(branch_property: list, nseg_per_branch: int): """Generates segments where some property is the same in each segment. diff --git a/jaxley/utils/swc.py b/jaxley/utils/swc.py deleted file mode 100644 index 8659d418..00000000 --- a/jaxley/utils/swc.py +++ /dev/null @@ -1,354 +0,0 @@ -# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is -# licensed under the Apache License Version 2.0, see - -from copy import copy -from typing import Callable, List, Optional, Tuple -from warnings import warn - -import jax.numpy as jnp -import numpy as np - - -def swc_to_jaxley( - fname: str, - max_branch_len: float = 100.0, - sort: bool = True, - num_lines: Optional[int] = None, -) -> Tuple[List[int], List[float], List[Callable], List[float], List[np.ndarray]]: - """Read an SWC file and bring morphology into `jaxley` compatible formats. - - Args: - fname: Path to swc file. - max_branch_len: Maximal length of one branch. If a branch exceeds this length, - it is split into equal parts such that each subbranch is below - `max_branch_len`. - num_lines: Number of lines of the SWC file to read. - """ - content = np.loadtxt(fname)[:num_lines] - types = content[:, 1] - is_single_point_soma = types[0] == 1 and types[1] != 1 - - if is_single_point_soma: - # Warn here, but the conversion of the length happens in `_compute_pathlengths`. - warn( - "Found a soma which consists of a single traced point. `Jaxley` " - "interprets this soma as a spherical compartment with radius " - "specified in the SWC file, i.e. with surface area 4*pi*r*r." - ) - sorted_branches, types = _split_into_branches_and_sort( - content, - max_branch_len=max_branch_len, - is_single_point_soma=is_single_point_soma, - sort=sort, - ) - - parents = _build_parents(sorted_branches) - each_length = _compute_pathlengths( - sorted_branches, content[:, 1:6], is_single_point_soma=is_single_point_soma - ) - pathlengths = [np.sum(length_traced) for length_traced in each_length] - for i, pathlen in enumerate(pathlengths): - if pathlen == 0.0: - warn("Found a segment with length 0. Clipping it to 1.0") - pathlengths[i] = 1.0 - radius_fns = _radius_generating_fns( - sorted_branches, content[:, 5], each_length, parents, types - ) - - if np.sum(np.asarray(parents) == -1) > 1.0: - parents = np.asarray([-1] + parents) - parents[1:] += 1 - parents = parents.tolist() - pathlengths = [0.1] + pathlengths - radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns - sorted_branches = [[0]] + sorted_branches - - # Type of padded section is assumed to be of `custom` type: - # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html - types = [5.0] + types - - all_coords_of_branches = [] - for i, branch in enumerate(sorted_branches): - # Remove 1 because `content` is an array that is indexed from 0. - branch = np.asarray(branch) - 1 - - # Deal with additional branch that might have been added above in the lines - # `if np.sum(np.asarray(parents) == -1) > 1.0:` - branch[branch < 0] = 0 - - # Get traced coordinates of the branch. - coords_of_branch = content[branch, 2:6] - all_coords_of_branches.append(coords_of_branch) - - return parents, pathlengths, radius_fns, types, all_coords_of_branches - - -def _split_into_branches_and_sort( - content: np.ndarray, - max_branch_len: float, - is_single_point_soma: bool, - sort: bool = True, -) -> Tuple[np.ndarray, np.ndarray]: - branches, types = _split_into_branches(content, is_single_point_soma) - branches, types = _split_long_branches( - branches, - types, - content, - max_branch_len, - is_single_point_soma=is_single_point_soma, - ) - - if sort: - first_val = np.asarray([b[0] for b in branches]) - sorting = np.argsort(first_val, kind="mergesort") - sorted_branches = [branches[s] for s in sorting] - sorted_types = [types[s] for s in sorting] - else: - sorted_branches = branches - sorted_types = types - return sorted_branches, sorted_types - - -def _split_long_branches( - branches: np.ndarray, - types: np.ndarray, - content: np.ndarray, - max_branch_len: float, - is_single_point_soma: bool, -) -> Tuple[np.ndarray, np.ndarray]: - pathlengths = _compute_pathlengths( - branches, content[:, 1:6], is_single_point_soma=is_single_point_soma - ) - pathlengths = [np.sum(length_traced) for length_traced in pathlengths] - split_branches = [] - split_types = [] - for branch, type, length in zip(branches, types, pathlengths): - num_subbranches = 1 - split_branch = [branch] - while length > max_branch_len: - num_subbranches += 1 - split_branch = _split_branch_equally(branch, num_subbranches) - lengths_of_subbranches = _compute_pathlengths( - split_branch, - coords=content[:, 1:6], - is_single_point_soma=is_single_point_soma, - ) - lengths_of_subbranches = [ - np.sum(length_traced) for length_traced in lengths_of_subbranches - ] - length = max(lengths_of_subbranches) - if num_subbranches > 10: - warn( - """`num_subbranches > 10`, stopping to split. Most likely your - SWC reconstruction is not dense and some neighbouring traced - points are farther than `max_branch_len` apart.""" - ) - break - split_branches += split_branch - split_types += [type] * num_subbranches - - return split_branches, split_types - - -def _split_branch_equally(branch: np.ndarray, num_subbranches: int) -> List[np.ndarray]: - num_points_each = len(branch) // num_subbranches - branches = [branch[:num_points_each]] - for i in range(1, num_subbranches - 1): - branches.append(branch[i * num_points_each - 1 : (i + 1) * num_points_each]) - branches.append(branch[(num_subbranches - 1) * num_points_each - 1 :]) - return branches - - -def _split_into_branches( - content: np.ndarray, is_single_point_soma: bool -) -> Tuple[np.ndarray, np.ndarray]: - prev_ind = None - prev_type = None - n_branches = 0 - - # Branch inds will contain the row identifier at which a branch point occurs - # (i.e. the row of the parent of two branches). - branch_inds = [] - for c in content: - current_ind = c[0] - current_parent = c[-1] - current_type = c[1] - if current_parent != prev_ind or current_type != prev_type: - branch_inds.append(int(current_parent)) - n_branches += 1 - prev_ind = current_ind - prev_type = current_type - - all_branches = [] - current_branch = [] - all_types = [] - - # Loop over every line in the SWC file. - for c in content: - current_ind = c[0] # First col is row_identifier - current_parent = c[-1] # Last col is parent in SWC specification. - if current_parent == -1: - all_types.append(c[1]) - else: - current_type = c[1] - - if current_parent == -1 and is_single_point_soma and current_ind == 1: - all_branches.append([int(current_ind)]) - all_types.append(int(current_type)) - - # Either append the current point to the branch, or add the branch to - # `all_branches`. - if current_parent in branch_inds[1:]: - if len(current_branch) > 1: - all_branches.append(current_branch) - all_types.append(current_type) - current_branch = [int(current_parent), int(current_ind)] - else: - current_branch.append(int(current_ind)) - - # Append the final branch (intermediate branches are already appended five lines - # above.) - all_branches.append(current_branch) - return all_branches, all_types - - -def _build_parents(all_branches: List[np.ndarray]) -> List[int]: - parents = [None] * len(all_branches) - all_last_inds = [b[-1] for b in all_branches] - for i, branch in enumerate(all_branches): - parent_ind = branch[0] - ind = np.where(np.asarray(all_last_inds) == parent_ind)[0] - if len(ind) > 0 and ind != i: - parents[i] = ind[0] - else: - assert ( - parent_ind == 1 - ), """Trying to connect a segment to the beginning of - another segment. This is not allowed. Please create an issue on github.""" - parents[i] = -1 - - return parents - - -def _radius_generating_fns( - all_branches: np.ndarray, - radiuses: np.ndarray, - each_length: np.ndarray, - parents: np.ndarray, - types: np.ndarray, -) -> List[Callable]: - """For all branches in a cell, returns callable that return radius given loc.""" - radius_fns = [] - for i, branch in enumerate(all_branches): - rads_in_branch = radiuses[np.asarray(branch) - 1] - if parents[i] > -1 and types[i] != types[parents[i]]: - # We do not want to linearly interpolate between the radius of the previous - # branch if a new type of neurite is found (e.g. switch from soma to - # apical). From looking at the SWC from n140.swc I believe that this is - # also what NEURON does. - rads_in_branch[0] = rads_in_branch[1] - radius_fn = _radius_generating_fn( - radiuses=rads_in_branch, each_length=each_length[i] - ) - # Beause SWC starts counting at 1, but numpy counts from 0. - # ind_of_branch_endpoint = np.asarray(b) - 1 - radius_fns.append(radius_fn) - return radius_fns - - -def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable: - # Avoid division by 0 with the `summed_len` below. - each_length[each_length < 1e-8] = 1e-8 - summed_len = np.sum(each_length) - cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len - cutoffs[0] -= 1e-8 - cutoffs[-1] += 1e-8 - - # We have to linearly interpolate radiuses, therefore we need at least two radiuses. - # However, jaxley allows somata which consist of a single traced point (i.e. - # just one radius). Therefore, we just `tile` in order to generate an artificial - # endpoint and startpoint radius of the soma. - if len(radiuses) == 1: - radiuses = np.tile(radiuses, 2) - - def radius(loc: float) -> float: - """Function which returns the radius via linear interpolation.""" - index = np.digitize(loc, cutoffs, right=False) - left_rad = radiuses[index - 1] - right_rad = radiuses[index] - left_loc = cutoffs[index - 1] - right_loc = cutoffs[index] - loc_within_bin = (loc - left_loc) / (right_loc - left_loc) - return left_rad + (right_rad - left_rad) * loc_within_bin - - return radius - - -def _compute_pathlengths( - all_branches: np.ndarray, coords: np.ndarray, is_single_point_soma: bool -) -> List[np.ndarray]: - """ - Args: - coords: Has shape (num_traced_points, 5), where `5` is (type, x, y, z, radius). - """ - branch_pathlengths = [] - for b in all_branches: - coords_in_branch = coords[np.asarray(b) - 1] - if len(coords_in_branch) > 1: - # If the branch starts at a different neurite (e.g. the soma) then NEURON - # ignores the distance from that initial point. To reproduce, use the - # following SWC dummy file and read it in NEURON (and Jaxley): - # 1 1 0.00 0.0 0.0 6.0 -1 - # 2 2 9.00 0.0 0.0 0.5 1 - # 3 2 10.0 0.0 0.0 0.3 2 - types = coords_in_branch[:, 0] - if int(types[0]) == 1 and int(types[1]) != 1 and is_single_point_soma: - coords_in_branch[0] = coords_in_branch[1] - - # Compute distances between all traced points in a branch. - point_diffs = np.diff(coords_in_branch, axis=0) - dists = np.sqrt( - point_diffs[:, 1] ** 2 + point_diffs[:, 2] ** 2 + point_diffs[:, 3] ** 2 - ) - else: - # Jaxley uses length and radius for every compartment and assumes the - # surface area to be 2*pi*r*length. For branches consisting of a single - # traced point we assume for them to have area 4*pi*r*r. Therefore, we have - # to set length = 2*r. - radius = coords_in_branch[0, 4] # txyzr -> 4 is radius. - dists = np.asarray([2 * radius]) - branch_pathlengths.append(dists) - return branch_pathlengths - - -def build_radiuses_from_xyzr( - radius_fns: List[Callable], - branch_indices: List[int], - min_radius: Optional[float], - nseg: int, -) -> jnp.ndarray: - """Return the radiuses of branches given SWC file xyzr. - - Returns an array of shape `(num_branches, nseg)`. - - Args: - radius_fns: Functions which, given compartment locations return the radius. - branch_indices: The indices of the branches for which to return the radiuses. - min_radius: If passed, the radiuses are clipped to be at least as large. - nseg: The number of compartments that every branch is discretized into. - """ - # Compartment locations are at the center of the internal nodes. - non_split = 1 / nseg - range_ = np.linspace(non_split / 2, 1 - non_split / 2, nseg) - - # Build radiuses. - radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices]) - radiuses_each = radiuses.ravel(order="C") - if min_radius is None: - assert np.all( - radiuses_each > 0.0 - ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`." - else: - radiuses_each[radiuses_each < min_radius] = min_radius - - return radiuses_each diff --git a/tests/conftest.py b/tests/conftest.py index 89621245..01a97976 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -194,7 +194,7 @@ def get_or_compute_swc2jaxley_params( default_fname = os.path.join(dirname, "swc_files", "morph.swc") fname = default_fname if fname is None else fname if key := (fname, max_branch_len, sort) not in params or force_init: - params[key] = jx.utils.swc.swc_to_jaxley(fname, max_branch_len, sort) + params[key] = jx.io.swc.swc_to_jaxley(fname, max_branch_len, sort) return params[key] yield get_or_compute_swc2jaxley_params