diff --git a/jaxley/__init__.py b/jaxley/__init__.py
index ce2c261a..abc461f0 100644
--- a/jaxley/__init__.py
+++ b/jaxley/__init__.py
@@ -8,6 +8,7 @@
sparse_connect,
)
from jaxley.integrate import integrate
+from jaxley.io.swc import read_swc
from jaxley.modules import *
from jaxley.optimize import ParamTransform
from jaxley.stimulus import datapoint_to_step_currents, step_current
diff --git a/jaxley/io/swc.py b/jaxley/io/swc.py
new file mode 100644
index 00000000..ff131e62
--- /dev/null
+++ b/jaxley/io/swc.py
@@ -0,0 +1,168 @@
+# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
+# licensed under the Apache License Version 2.0, see
+
+from copy import copy
+from typing import Callable, List, Optional, Tuple
+from warnings import warn
+
+import jax.numpy as jnp
+import numpy as np
+
+from jaxley.modules import Branch, Cell, Compartment
+from jaxley.utils.cell_utils import (
+ _build_parents,
+ _compute_pathlengths,
+ _radius_generating_fns,
+ _split_into_branches_and_sort,
+ build_radiuses_from_xyzr,
+)
+
+
+def swc_to_jaxley(
+ fname: str,
+ max_branch_len: float = 100.0,
+ sort: bool = True,
+ num_lines: Optional[int] = None,
+) -> Tuple[List[int], List[float], List[Callable], List[float], List[np.ndarray]]:
+ """Read an SWC file and bring morphology into `jaxley` compatible formats.
+
+ Args:
+ fname: Path to swc file.
+ max_branch_len: Maximal length of one branch. If a branch exceeds this length,
+ it is split into equal parts such that each subbranch is below
+ `max_branch_len`.
+ num_lines: Number of lines of the SWC file to read.
+ """
+ content = np.loadtxt(fname)[:num_lines]
+ types = content[:, 1]
+ is_single_point_soma = types[0] == 1 and types[1] != 1
+
+ if is_single_point_soma:
+ # Warn here, but the conversion of the length happens in `_compute_pathlengths`.
+ warn(
+ "Found a soma which consists of a single traced point. `Jaxley` "
+ "interprets this soma as a spherical compartment with radius "
+ "specified in the SWC file, i.e. with surface area 4*pi*r*r."
+ )
+ sorted_branches, types = _split_into_branches_and_sort(
+ content,
+ max_branch_len=max_branch_len,
+ is_single_point_soma=is_single_point_soma,
+ sort=sort,
+ )
+
+ parents = _build_parents(sorted_branches)
+ each_length = _compute_pathlengths(
+ sorted_branches, content[:, 1:6], is_single_point_soma=is_single_point_soma
+ )
+ pathlengths = [np.sum(length_traced) for length_traced in each_length]
+ for i, pathlen in enumerate(pathlengths):
+ if pathlen == 0.0:
+ warn("Found a segment with length 0. Clipping it to 1.0")
+ pathlengths[i] = 1.0
+ radius_fns = _radius_generating_fns(
+ sorted_branches, content[:, 5], each_length, parents, types
+ )
+
+ if np.sum(np.asarray(parents) == -1) > 1.0:
+ parents = np.asarray([-1] + parents)
+ parents[1:] += 1
+ parents = parents.tolist()
+ pathlengths = [0.1] + pathlengths
+ radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns
+ sorted_branches = [[0]] + sorted_branches
+
+ # Type of padded section is assumed to be of `custom` type:
+ # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
+ types = [5.0] + types
+
+ all_coords_of_branches = []
+ for i, branch in enumerate(sorted_branches):
+ # Remove 1 because `content` is an array that is indexed from 0.
+ branch = np.asarray(branch) - 1
+
+ # Deal with additional branch that might have been added above in the lines
+ # `if np.sum(np.asarray(parents) == -1) > 1.0:`
+ branch[branch < 0] = 0
+
+ # Get traced coordinates of the branch.
+ coords_of_branch = content[branch, 2:6]
+ all_coords_of_branches.append(coords_of_branch)
+
+ return parents, pathlengths, radius_fns, types, all_coords_of_branches
+
+
+def read_swc(
+ fname: str,
+ nseg: int,
+ max_branch_len: float = 300.0,
+ min_radius: Optional[float] = None,
+ assign_groups: bool = False,
+) -> Cell:
+ """Reads SWC file into a `Cell`.
+
+ Jaxley assumes cylindrical compartments and therefore defines length and radius
+ for every compartment. The surface area is then 2*pi*r*length. For branches
+ consisting of a single traced point we assume for them to have area 4*pi*r*r.
+ Therefore, in these cases, we set lenght=2*r.
+
+ Args:
+ fname: Path to the swc file.
+ nseg: The number of compartments per branch.
+ max_branch_len: If a branch is longer than this value it is split into two
+ branches.
+ min_radius: If the radius of a reconstruction is below this value it is clipped.
+ assign_groups: If True, then the identity of reconstructed points in the SWC
+ file will be used to generate groups `undefined`, `soma`, `axon`, `basal`,
+ `apical`, `custom`. See here:
+ http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
+
+ Returns:
+ A `Cell` object.
+ """
+ parents, pathlengths, radius_fns, types, coords_of_branches = swc_to_jaxley(
+ fname, max_branch_len=max_branch_len, sort=True, num_lines=None
+ )
+ nbranches = len(parents)
+
+ comp = Compartment()
+ branch = Branch([comp for _ in range(nseg)])
+ cell = Cell(
+ [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches
+ )
+ # Also save the radius generating functions in case users post-hoc modify the number
+ # of compartments with `.set_ncomp()`.
+ cell._radius_generating_fns = radius_fns
+
+ lengths_each = np.repeat(pathlengths, nseg) / nseg
+ cell.set("length", lengths_each)
+
+ radiuses_each = build_radiuses_from_xyzr(
+ radius_fns,
+ range(len(parents)),
+ min_radius,
+ nseg,
+ )
+ cell.set("radius", radiuses_each)
+
+ # Description of SWC file format:
+ # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
+ ind_name_lookup = {
+ 0: "undefined",
+ 1: "soma",
+ 2: "axon",
+ 3: "basal",
+ 4: "apical",
+ 5: "custom",
+ }
+ types = np.asarray(types).astype(int)
+ if assign_groups:
+ for type_ind in np.unique(types):
+ if type_ind < 5.5:
+ name = ind_name_lookup[type_ind]
+ else:
+ name = f"custom{type_ind}"
+ indices = np.where(types == type_ind)[0].tolist()
+ if len(indices) > 0:
+ cell.branch(indices).add_to_group(name)
+ return cell
diff --git a/jaxley/modules/__init__.py b/jaxley/modules/__init__.py
index 584ca3e5..1bcca2b6 100644
--- a/jaxley/modules/__init__.py
+++ b/jaxley/modules/__init__.py
@@ -3,6 +3,6 @@
from jaxley.modules.base import Module
from jaxley.modules.branch import Branch
-from jaxley.modules.cell import Cell, read_swc
+from jaxley.modules.cell import Cell
from jaxley.modules.compartment import Compartment
from jaxley.modules.network import Network
diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py
index 1409a6eb..8441c6e6 100644
--- a/jaxley/modules/base.py
+++ b/jaxley/modules/base.py
@@ -26,6 +26,7 @@
from jaxley.utils.cell_utils import (
_compute_index_of_child,
_compute_num_children,
+ build_radiuses_from_xyzr,
compute_axial_conductances,
compute_levels,
convert_point_process_to_distributed,
@@ -39,7 +40,6 @@
from jaxley.utils.misc_utils import cumsum_leading_zero, is_str_all
from jaxley.utils.plot_utils import plot_comps, plot_graph, plot_morph
from jaxley.utils.solver_utils import convert_to_csc
-from jaxley.utils.swc import build_radiuses_from_xyzr
def only_allow_module(func):
diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py
index cc307f22..7db12cdd 100644
--- a/jaxley/modules/cell.py
+++ b/jaxley/modules/cell.py
@@ -8,8 +8,7 @@
import pandas as pd
from jaxley.modules.base import Module
-from jaxley.modules.branch import Branch, Compartment
-from jaxley.synapses import Synapse
+from jaxley.modules.branch import Branch
from jaxley.utils.cell_utils import (
build_branchpoint_group_inds,
compute_children_and_parents,
@@ -25,7 +24,6 @@
comp_edges_to_indices,
remap_index_to_masked,
)
-from jaxley.utils.swc import build_radiuses_from_xyzr, swc_to_jaxley
class Cell(Module):
@@ -271,79 +269,3 @@ def _init_morph_jax_spsolve(self):
self._data_inds = data_inds
self._indices_jax_spsolve = indices
self._indptr_jax_spsolve = indptr
-
-
-def read_swc(
- fname: str,
- nseg: int,
- max_branch_len: float = 300.0,
- min_radius: Optional[float] = None,
- assign_groups: bool = False,
-) -> Cell:
- """Reads SWC file into a `jx.Cell`.
-
- Jaxley assumes cylindrical compartments and therefore defines length and radius
- for every compartment. The surface area is then 2*pi*r*length. For branches
- consisting of a single traced point we assume for them to have area 4*pi*r*r.
- Therefore, in these cases, we set lenght=2*r.
-
- Args:
- fname: Path to the swc file.
- nseg: The number of compartments per branch.
- max_branch_len: If a branch is longer than this value it is split into two
- branches.
- min_radius: If the radius of a reconstruction is below this value it is clipped.
- assign_groups: If True, then the identity of reconstructed points in the SWC
- file will be used to generate groups `undefined`, `soma`, `axon`, `basal`,
- `apical`, `custom`. See here:
- http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
-
- Returns:
- A `jx.Cell` object.
- """
- parents, pathlengths, radius_fns, types, coords_of_branches = swc_to_jaxley(
- fname, max_branch_len=max_branch_len, sort=True, num_lines=None
- )
- nbranches = len(parents)
-
- comp = Compartment()
- branch = Branch([comp for _ in range(nseg)])
- cell = Cell(
- [branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches
- )
- # Also save the radius generating functions in case users post-hoc modify the number
- # of compartments with `.set_ncomp()`.
- cell._radius_generating_fns = radius_fns
-
- lengths_each = np.repeat(pathlengths, nseg) / nseg
- cell.set("length", lengths_each)
-
- radiuses_each = build_radiuses_from_xyzr(
- radius_fns,
- range(len(parents)),
- min_radius,
- nseg,
- )
- cell.set("radius", radiuses_each)
-
- # Description of SWC file format:
- # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
- ind_name_lookup = {
- 0: "undefined",
- 1: "soma",
- 2: "axon",
- 3: "basal",
- 4: "apical",
- 5: "custom",
- }
- types = np.asarray(types).astype(int)
- if assign_groups:
- for type_ind in np.unique(types):
- if type_ind < 5.5:
- name = ind_name_lookup[type_ind]
- else:
- name = f"custom{type_ind}"
- indices = np.where(types == type_ind)[0].tolist()
- if len(indices) > 0:
- cell.branch(indices).add_to_group(name)
- return cell
diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py
index 961b6533..d47ab6ea 100644
--- a/jaxley/utils/cell_utils.py
+++ b/jaxley/utils/cell_utils.py
@@ -2,7 +2,8 @@
# licensed under the Apache License Version 2.0, see
from math import pi
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
+from warnings import warn
import jax.numpy as jnp
import numpy as np
@@ -12,6 +13,277 @@
from jaxley.utils.misc_utils import cumsum_leading_zero
+def _split_into_branches_and_sort(
+ content: np.ndarray,
+ max_branch_len: float,
+ is_single_point_soma: bool,
+ sort: bool = True,
+) -> Tuple[np.ndarray, np.ndarray]:
+ branches, types = _split_into_branches(content, is_single_point_soma)
+ branches, types = _split_long_branches(
+ branches,
+ types,
+ content,
+ max_branch_len,
+ is_single_point_soma=is_single_point_soma,
+ )
+
+ if sort:
+ first_val = np.asarray([b[0] for b in branches])
+ sorting = np.argsort(first_val, kind="mergesort")
+ sorted_branches = [branches[s] for s in sorting]
+ sorted_types = [types[s] for s in sorting]
+ else:
+ sorted_branches = branches
+ sorted_types = types
+ return sorted_branches, sorted_types
+
+
+def _split_long_branches(
+ branches: np.ndarray,
+ types: np.ndarray,
+ content: np.ndarray,
+ max_branch_len: float,
+ is_single_point_soma: bool,
+) -> Tuple[np.ndarray, np.ndarray]:
+ pathlengths = _compute_pathlengths(
+ branches, content[:, 1:6], is_single_point_soma=is_single_point_soma
+ )
+ pathlengths = [np.sum(length_traced) for length_traced in pathlengths]
+ split_branches = []
+ split_types = []
+ for branch, type, length in zip(branches, types, pathlengths):
+ num_subbranches = 1
+ split_branch = [branch]
+ while length > max_branch_len:
+ num_subbranches += 1
+ split_branch = _split_branch_equally(branch, num_subbranches)
+ lengths_of_subbranches = _compute_pathlengths(
+ split_branch,
+ coords=content[:, 1:6],
+ is_single_point_soma=is_single_point_soma,
+ )
+ lengths_of_subbranches = [
+ np.sum(length_traced) for length_traced in lengths_of_subbranches
+ ]
+ length = max(lengths_of_subbranches)
+ if num_subbranches > 10:
+ warn(
+ """`num_subbranches > 10`, stopping to split. Most likely your
+ SWC reconstruction is not dense and some neighbouring traced
+ points are farther than `max_branch_len` apart."""
+ )
+ break
+ split_branches += split_branch
+ split_types += [type] * num_subbranches
+
+ return split_branches, split_types
+
+
+def _split_branch_equally(branch: np.ndarray, num_subbranches: int) -> List[np.ndarray]:
+ num_points_each = len(branch) // num_subbranches
+ branches = [branch[:num_points_each]]
+ for i in range(1, num_subbranches - 1):
+ branches.append(branch[i * num_points_each - 1 : (i + 1) * num_points_each])
+ branches.append(branch[(num_subbranches - 1) * num_points_each - 1 :])
+ return branches
+
+
+def _split_into_branches(
+ content: np.ndarray, is_single_point_soma: bool
+) -> Tuple[np.ndarray, np.ndarray]:
+ prev_ind = None
+ prev_type = None
+ n_branches = 0
+
+ # Branch inds will contain the row identifier at which a branch point occurs
+ # (i.e. the row of the parent of two branches).
+ branch_inds = []
+ for c in content:
+ current_ind = c[0]
+ current_parent = c[-1]
+ current_type = c[1]
+ if current_parent != prev_ind or current_type != prev_type:
+ branch_inds.append(int(current_parent))
+ n_branches += 1
+ prev_ind = current_ind
+ prev_type = current_type
+
+ all_branches = []
+ current_branch = []
+ all_types = []
+
+ # Loop over every line in the SWC file.
+ for c in content:
+ current_ind = c[0] # First col is row_identifier
+ current_parent = c[-1] # Last col is parent in SWC specification.
+ if current_parent == -1:
+ all_types.append(c[1])
+ else:
+ current_type = c[1]
+
+ if current_parent == -1 and is_single_point_soma and current_ind == 1:
+ all_branches.append([int(current_ind)])
+ all_types.append(int(current_type))
+
+ # Either append the current point to the branch, or add the branch to
+ # `all_branches`.
+ if current_parent in branch_inds[1:]:
+ if len(current_branch) > 1:
+ all_branches.append(current_branch)
+ all_types.append(current_type)
+ current_branch = [int(current_parent), int(current_ind)]
+ else:
+ current_branch.append(int(current_ind))
+
+ # Append the final branch (intermediate branches are already appended five lines
+ # above.)
+ all_branches.append(current_branch)
+ return all_branches, all_types
+
+
+def _build_parents(all_branches: List[np.ndarray]) -> List[int]:
+ parents = [None] * len(all_branches)
+ all_last_inds = [b[-1] for b in all_branches]
+ for i, branch in enumerate(all_branches):
+ parent_ind = branch[0]
+ ind = np.where(np.asarray(all_last_inds) == parent_ind)[0]
+ if len(ind) > 0 and ind != i:
+ parents[i] = ind[0]
+ else:
+ assert (
+ parent_ind == 1
+ ), """Trying to connect a segment to the beginning of
+ another segment. This is not allowed. Please create an issue on github."""
+ parents[i] = -1
+
+ return parents
+
+
+def _radius_generating_fns(
+ all_branches: np.ndarray,
+ radiuses: np.ndarray,
+ each_length: np.ndarray,
+ parents: np.ndarray,
+ types: np.ndarray,
+) -> List[Callable]:
+ """For all branches in a cell, returns callable that return radius given loc."""
+ radius_fns = []
+ for i, branch in enumerate(all_branches):
+ rads_in_branch = radiuses[np.asarray(branch) - 1]
+ if parents[i] > -1 and types[i] != types[parents[i]]:
+ # We do not want to linearly interpolate between the radius of the previous
+ # branch if a new type of neurite is found (e.g. switch from soma to
+ # apical). From looking at the SWC from n140.swc I believe that this is
+ # also what NEURON does.
+ rads_in_branch[0] = rads_in_branch[1]
+ radius_fn = _radius_generating_fn(
+ radiuses=rads_in_branch, each_length=each_length[i]
+ )
+ # Beause SWC starts counting at 1, but numpy counts from 0.
+ # ind_of_branch_endpoint = np.asarray(b) - 1
+ radius_fns.append(radius_fn)
+ return radius_fns
+
+
+def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable:
+ # Avoid division by 0 with the `summed_len` below.
+ each_length[each_length < 1e-8] = 1e-8
+ summed_len = np.sum(each_length)
+ cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len
+ cutoffs[0] -= 1e-8
+ cutoffs[-1] += 1e-8
+
+ # We have to linearly interpolate radiuses, therefore we need at least two radiuses.
+ # However, jaxley allows somata which consist of a single traced point (i.e.
+ # just one radius). Therefore, we just `tile` in order to generate an artificial
+ # endpoint and startpoint radius of the soma.
+ if len(radiuses) == 1:
+ radiuses = np.tile(radiuses, 2)
+
+ def radius(loc: float) -> float:
+ """Function which returns the radius via linear interpolation."""
+ index = np.digitize(loc, cutoffs, right=False)
+ left_rad = radiuses[index - 1]
+ right_rad = radiuses[index]
+ left_loc = cutoffs[index - 1]
+ right_loc = cutoffs[index]
+ loc_within_bin = (loc - left_loc) / (right_loc - left_loc)
+ return left_rad + (right_rad - left_rad) * loc_within_bin
+
+ return radius
+
+
+def _compute_pathlengths(
+ all_branches: np.ndarray, coords: np.ndarray, is_single_point_soma: bool
+) -> List[np.ndarray]:
+ """
+ Args:
+ coords: Has shape (num_traced_points, 5), where `5` is (type, x, y, z, radius).
+ """
+ branch_pathlengths = []
+ for b in all_branches:
+ coords_in_branch = coords[np.asarray(b) - 1]
+ if len(coords_in_branch) > 1:
+ # If the branch starts at a different neurite (e.g. the soma) then NEURON
+ # ignores the distance from that initial point. To reproduce, use the
+ # following SWC dummy file and read it in NEURON (and Jaxley):
+ # 1 1 0.00 0.0 0.0 6.0 -1
+ # 2 2 9.00 0.0 0.0 0.5 1
+ # 3 2 10.0 0.0 0.0 0.3 2
+ types = coords_in_branch[:, 0]
+ if int(types[0]) == 1 and int(types[1]) != 1 and is_single_point_soma:
+ coords_in_branch[0] = coords_in_branch[1]
+
+ # Compute distances between all traced points in a branch.
+ point_diffs = np.diff(coords_in_branch, axis=0)
+ dists = np.sqrt(
+ point_diffs[:, 1] ** 2 + point_diffs[:, 2] ** 2 + point_diffs[:, 3] ** 2
+ )
+ else:
+ # Jaxley uses length and radius for every compartment and assumes the
+ # surface area to be 2*pi*r*length. For branches consisting of a single
+ # traced point we assume for them to have area 4*pi*r*r. Therefore, we have
+ # to set length = 2*r.
+ radius = coords_in_branch[0, 4] # txyzr -> 4 is radius.
+ dists = np.asarray([2 * radius])
+ branch_pathlengths.append(dists)
+ return branch_pathlengths
+
+
+def build_radiuses_from_xyzr(
+ radius_fns: List[Callable],
+ branch_indices: List[int],
+ min_radius: Optional[float],
+ nseg: int,
+) -> jnp.ndarray:
+ """Return the radiuses of branches given SWC file xyzr.
+
+ Returns an array of shape `(num_branches, nseg)`.
+
+ Args:
+ radius_fns: Functions which, given compartment locations return the radius.
+ branch_indices: The indices of the branches for which to return the radiuses.
+ min_radius: If passed, the radiuses are clipped to be at least as large.
+ nseg: The number of compartments that every branch is discretized into.
+ """
+ # Compartment locations are at the center of the internal nodes.
+ non_split = 1 / nseg
+ range_ = np.linspace(non_split / 2, 1 - non_split / 2, nseg)
+
+ # Build radiuses.
+ radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])
+ radiuses_each = radiuses.ravel(order="C")
+ if min_radius is None:
+ assert np.all(
+ radiuses_each > 0.0
+ ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`."
+ else:
+ radiuses_each[radiuses_each < min_radius] = min_radius
+
+ return radiuses_each
+
+
def equal_segments(branch_property: list, nseg_per_branch: int):
"""Generates segments where some property is the same in each segment.
diff --git a/jaxley/utils/swc.py b/jaxley/utils/swc.py
deleted file mode 100644
index 8659d418..00000000
--- a/jaxley/utils/swc.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
-# licensed under the Apache License Version 2.0, see
-
-from copy import copy
-from typing import Callable, List, Optional, Tuple
-from warnings import warn
-
-import jax.numpy as jnp
-import numpy as np
-
-
-def swc_to_jaxley(
- fname: str,
- max_branch_len: float = 100.0,
- sort: bool = True,
- num_lines: Optional[int] = None,
-) -> Tuple[List[int], List[float], List[Callable], List[float], List[np.ndarray]]:
- """Read an SWC file and bring morphology into `jaxley` compatible formats.
-
- Args:
- fname: Path to swc file.
- max_branch_len: Maximal length of one branch. If a branch exceeds this length,
- it is split into equal parts such that each subbranch is below
- `max_branch_len`.
- num_lines: Number of lines of the SWC file to read.
- """
- content = np.loadtxt(fname)[:num_lines]
- types = content[:, 1]
- is_single_point_soma = types[0] == 1 and types[1] != 1
-
- if is_single_point_soma:
- # Warn here, but the conversion of the length happens in `_compute_pathlengths`.
- warn(
- "Found a soma which consists of a single traced point. `Jaxley` "
- "interprets this soma as a spherical compartment with radius "
- "specified in the SWC file, i.e. with surface area 4*pi*r*r."
- )
- sorted_branches, types = _split_into_branches_and_sort(
- content,
- max_branch_len=max_branch_len,
- is_single_point_soma=is_single_point_soma,
- sort=sort,
- )
-
- parents = _build_parents(sorted_branches)
- each_length = _compute_pathlengths(
- sorted_branches, content[:, 1:6], is_single_point_soma=is_single_point_soma
- )
- pathlengths = [np.sum(length_traced) for length_traced in each_length]
- for i, pathlen in enumerate(pathlengths):
- if pathlen == 0.0:
- warn("Found a segment with length 0. Clipping it to 1.0")
- pathlengths[i] = 1.0
- radius_fns = _radius_generating_fns(
- sorted_branches, content[:, 5], each_length, parents, types
- )
-
- if np.sum(np.asarray(parents) == -1) > 1.0:
- parents = np.asarray([-1] + parents)
- parents[1:] += 1
- parents = parents.tolist()
- pathlengths = [0.1] + pathlengths
- radius_fns = [lambda x: content[0, 5] * np.ones_like(x)] + radius_fns
- sorted_branches = [[0]] + sorted_branches
-
- # Type of padded section is assumed to be of `custom` type:
- # http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
- types = [5.0] + types
-
- all_coords_of_branches = []
- for i, branch in enumerate(sorted_branches):
- # Remove 1 because `content` is an array that is indexed from 0.
- branch = np.asarray(branch) - 1
-
- # Deal with additional branch that might have been added above in the lines
- # `if np.sum(np.asarray(parents) == -1) > 1.0:`
- branch[branch < 0] = 0
-
- # Get traced coordinates of the branch.
- coords_of_branch = content[branch, 2:6]
- all_coords_of_branches.append(coords_of_branch)
-
- return parents, pathlengths, radius_fns, types, all_coords_of_branches
-
-
-def _split_into_branches_and_sort(
- content: np.ndarray,
- max_branch_len: float,
- is_single_point_soma: bool,
- sort: bool = True,
-) -> Tuple[np.ndarray, np.ndarray]:
- branches, types = _split_into_branches(content, is_single_point_soma)
- branches, types = _split_long_branches(
- branches,
- types,
- content,
- max_branch_len,
- is_single_point_soma=is_single_point_soma,
- )
-
- if sort:
- first_val = np.asarray([b[0] for b in branches])
- sorting = np.argsort(first_val, kind="mergesort")
- sorted_branches = [branches[s] for s in sorting]
- sorted_types = [types[s] for s in sorting]
- else:
- sorted_branches = branches
- sorted_types = types
- return sorted_branches, sorted_types
-
-
-def _split_long_branches(
- branches: np.ndarray,
- types: np.ndarray,
- content: np.ndarray,
- max_branch_len: float,
- is_single_point_soma: bool,
-) -> Tuple[np.ndarray, np.ndarray]:
- pathlengths = _compute_pathlengths(
- branches, content[:, 1:6], is_single_point_soma=is_single_point_soma
- )
- pathlengths = [np.sum(length_traced) for length_traced in pathlengths]
- split_branches = []
- split_types = []
- for branch, type, length in zip(branches, types, pathlengths):
- num_subbranches = 1
- split_branch = [branch]
- while length > max_branch_len:
- num_subbranches += 1
- split_branch = _split_branch_equally(branch, num_subbranches)
- lengths_of_subbranches = _compute_pathlengths(
- split_branch,
- coords=content[:, 1:6],
- is_single_point_soma=is_single_point_soma,
- )
- lengths_of_subbranches = [
- np.sum(length_traced) for length_traced in lengths_of_subbranches
- ]
- length = max(lengths_of_subbranches)
- if num_subbranches > 10:
- warn(
- """`num_subbranches > 10`, stopping to split. Most likely your
- SWC reconstruction is not dense and some neighbouring traced
- points are farther than `max_branch_len` apart."""
- )
- break
- split_branches += split_branch
- split_types += [type] * num_subbranches
-
- return split_branches, split_types
-
-
-def _split_branch_equally(branch: np.ndarray, num_subbranches: int) -> List[np.ndarray]:
- num_points_each = len(branch) // num_subbranches
- branches = [branch[:num_points_each]]
- for i in range(1, num_subbranches - 1):
- branches.append(branch[i * num_points_each - 1 : (i + 1) * num_points_each])
- branches.append(branch[(num_subbranches - 1) * num_points_each - 1 :])
- return branches
-
-
-def _split_into_branches(
- content: np.ndarray, is_single_point_soma: bool
-) -> Tuple[np.ndarray, np.ndarray]:
- prev_ind = None
- prev_type = None
- n_branches = 0
-
- # Branch inds will contain the row identifier at which a branch point occurs
- # (i.e. the row of the parent of two branches).
- branch_inds = []
- for c in content:
- current_ind = c[0]
- current_parent = c[-1]
- current_type = c[1]
- if current_parent != prev_ind or current_type != prev_type:
- branch_inds.append(int(current_parent))
- n_branches += 1
- prev_ind = current_ind
- prev_type = current_type
-
- all_branches = []
- current_branch = []
- all_types = []
-
- # Loop over every line in the SWC file.
- for c in content:
- current_ind = c[0] # First col is row_identifier
- current_parent = c[-1] # Last col is parent in SWC specification.
- if current_parent == -1:
- all_types.append(c[1])
- else:
- current_type = c[1]
-
- if current_parent == -1 and is_single_point_soma and current_ind == 1:
- all_branches.append([int(current_ind)])
- all_types.append(int(current_type))
-
- # Either append the current point to the branch, or add the branch to
- # `all_branches`.
- if current_parent in branch_inds[1:]:
- if len(current_branch) > 1:
- all_branches.append(current_branch)
- all_types.append(current_type)
- current_branch = [int(current_parent), int(current_ind)]
- else:
- current_branch.append(int(current_ind))
-
- # Append the final branch (intermediate branches are already appended five lines
- # above.)
- all_branches.append(current_branch)
- return all_branches, all_types
-
-
-def _build_parents(all_branches: List[np.ndarray]) -> List[int]:
- parents = [None] * len(all_branches)
- all_last_inds = [b[-1] for b in all_branches]
- for i, branch in enumerate(all_branches):
- parent_ind = branch[0]
- ind = np.where(np.asarray(all_last_inds) == parent_ind)[0]
- if len(ind) > 0 and ind != i:
- parents[i] = ind[0]
- else:
- assert (
- parent_ind == 1
- ), """Trying to connect a segment to the beginning of
- another segment. This is not allowed. Please create an issue on github."""
- parents[i] = -1
-
- return parents
-
-
-def _radius_generating_fns(
- all_branches: np.ndarray,
- radiuses: np.ndarray,
- each_length: np.ndarray,
- parents: np.ndarray,
- types: np.ndarray,
-) -> List[Callable]:
- """For all branches in a cell, returns callable that return radius given loc."""
- radius_fns = []
- for i, branch in enumerate(all_branches):
- rads_in_branch = radiuses[np.asarray(branch) - 1]
- if parents[i] > -1 and types[i] != types[parents[i]]:
- # We do not want to linearly interpolate between the radius of the previous
- # branch if a new type of neurite is found (e.g. switch from soma to
- # apical). From looking at the SWC from n140.swc I believe that this is
- # also what NEURON does.
- rads_in_branch[0] = rads_in_branch[1]
- radius_fn = _radius_generating_fn(
- radiuses=rads_in_branch, each_length=each_length[i]
- )
- # Beause SWC starts counting at 1, but numpy counts from 0.
- # ind_of_branch_endpoint = np.asarray(b) - 1
- radius_fns.append(radius_fn)
- return radius_fns
-
-
-def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable:
- # Avoid division by 0 with the `summed_len` below.
- each_length[each_length < 1e-8] = 1e-8
- summed_len = np.sum(each_length)
- cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len
- cutoffs[0] -= 1e-8
- cutoffs[-1] += 1e-8
-
- # We have to linearly interpolate radiuses, therefore we need at least two radiuses.
- # However, jaxley allows somata which consist of a single traced point (i.e.
- # just one radius). Therefore, we just `tile` in order to generate an artificial
- # endpoint and startpoint radius of the soma.
- if len(radiuses) == 1:
- radiuses = np.tile(radiuses, 2)
-
- def radius(loc: float) -> float:
- """Function which returns the radius via linear interpolation."""
- index = np.digitize(loc, cutoffs, right=False)
- left_rad = radiuses[index - 1]
- right_rad = radiuses[index]
- left_loc = cutoffs[index - 1]
- right_loc = cutoffs[index]
- loc_within_bin = (loc - left_loc) / (right_loc - left_loc)
- return left_rad + (right_rad - left_rad) * loc_within_bin
-
- return radius
-
-
-def _compute_pathlengths(
- all_branches: np.ndarray, coords: np.ndarray, is_single_point_soma: bool
-) -> List[np.ndarray]:
- """
- Args:
- coords: Has shape (num_traced_points, 5), where `5` is (type, x, y, z, radius).
- """
- branch_pathlengths = []
- for b in all_branches:
- coords_in_branch = coords[np.asarray(b) - 1]
- if len(coords_in_branch) > 1:
- # If the branch starts at a different neurite (e.g. the soma) then NEURON
- # ignores the distance from that initial point. To reproduce, use the
- # following SWC dummy file and read it in NEURON (and Jaxley):
- # 1 1 0.00 0.0 0.0 6.0 -1
- # 2 2 9.00 0.0 0.0 0.5 1
- # 3 2 10.0 0.0 0.0 0.3 2
- types = coords_in_branch[:, 0]
- if int(types[0]) == 1 and int(types[1]) != 1 and is_single_point_soma:
- coords_in_branch[0] = coords_in_branch[1]
-
- # Compute distances between all traced points in a branch.
- point_diffs = np.diff(coords_in_branch, axis=0)
- dists = np.sqrt(
- point_diffs[:, 1] ** 2 + point_diffs[:, 2] ** 2 + point_diffs[:, 3] ** 2
- )
- else:
- # Jaxley uses length and radius for every compartment and assumes the
- # surface area to be 2*pi*r*length. For branches consisting of a single
- # traced point we assume for them to have area 4*pi*r*r. Therefore, we have
- # to set length = 2*r.
- radius = coords_in_branch[0, 4] # txyzr -> 4 is radius.
- dists = np.asarray([2 * radius])
- branch_pathlengths.append(dists)
- return branch_pathlengths
-
-
-def build_radiuses_from_xyzr(
- radius_fns: List[Callable],
- branch_indices: List[int],
- min_radius: Optional[float],
- nseg: int,
-) -> jnp.ndarray:
- """Return the radiuses of branches given SWC file xyzr.
-
- Returns an array of shape `(num_branches, nseg)`.
-
- Args:
- radius_fns: Functions which, given compartment locations return the radius.
- branch_indices: The indices of the branches for which to return the radiuses.
- min_radius: If passed, the radiuses are clipped to be at least as large.
- nseg: The number of compartments that every branch is discretized into.
- """
- # Compartment locations are at the center of the internal nodes.
- non_split = 1 / nseg
- range_ = np.linspace(non_split / 2, 1 - non_split / 2, nseg)
-
- # Build radiuses.
- radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])
- radiuses_each = radiuses.ravel(order="C")
- if min_radius is None:
- assert np.all(
- radiuses_each > 0.0
- ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`."
- else:
- radiuses_each[radiuses_each < min_radius] = min_radius
-
- return radiuses_each
diff --git a/tests/conftest.py b/tests/conftest.py
index 89621245..01a97976 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -194,7 +194,7 @@ def get_or_compute_swc2jaxley_params(
default_fname = os.path.join(dirname, "swc_files", "morph.swc")
fname = default_fname if fname is None else fname
if key := (fname, max_branch_len, sort) not in params or force_init:
- params[key] = jx.utils.swc.swc_to_jaxley(fname, max_branch_len, sort)
+ params[key] = jx.io.swc.swc_to_jaxley(fname, max_branch_len, sort)
return params[key]
yield get_or_compute_swc2jaxley_params