Skip to content

Commit

Permalink
fix: fix circular imports
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Nov 21, 2024
1 parent a1d64a7 commit d55c87b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 279 deletions.
292 changes: 14 additions & 278 deletions jaxley/io/swc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
import jax.numpy as jnp
import numpy as np

import jaxley as jx
from jaxley.modules import Branch, Cell, Compartment
from jaxley.utils.cell_utils import (
_build_parents,
_compute_pathlengths,
_radius_generating_fns,
_split_into_branches_and_sort,
build_radiuses_from_xyzr,
)


def swc_to_jaxley(
Expand Down Expand Up @@ -85,285 +92,14 @@ def swc_to_jaxley(
return parents, pathlengths, radius_fns, types, all_coords_of_branches


def _split_into_branches_and_sort(
content: np.ndarray,
max_branch_len: float,
is_single_point_soma: bool,
sort: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
branches, types = _split_into_branches(content, is_single_point_soma)
branches, types = _split_long_branches(
branches,
types,
content,
max_branch_len,
is_single_point_soma=is_single_point_soma,
)

if sort:
first_val = np.asarray([b[0] for b in branches])
sorting = np.argsort(first_val, kind="mergesort")
sorted_branches = [branches[s] for s in sorting]
sorted_types = [types[s] for s in sorting]
else:
sorted_branches = branches
sorted_types = types
return sorted_branches, sorted_types


def _split_long_branches(
branches: np.ndarray,
types: np.ndarray,
content: np.ndarray,
max_branch_len: float,
is_single_point_soma: bool,
) -> Tuple[np.ndarray, np.ndarray]:
pathlengths = _compute_pathlengths(
branches, content[:, 1:6], is_single_point_soma=is_single_point_soma
)
pathlengths = [np.sum(length_traced) for length_traced in pathlengths]
split_branches = []
split_types = []
for branch, type, length in zip(branches, types, pathlengths):
num_subbranches = 1
split_branch = [branch]
while length > max_branch_len:
num_subbranches += 1
split_branch = _split_branch_equally(branch, num_subbranches)
lengths_of_subbranches = _compute_pathlengths(
split_branch,
coords=content[:, 1:6],
is_single_point_soma=is_single_point_soma,
)
lengths_of_subbranches = [
np.sum(length_traced) for length_traced in lengths_of_subbranches
]
length = max(lengths_of_subbranches)
if num_subbranches > 10:
warn(
"""`num_subbranches > 10`, stopping to split. Most likely your
SWC reconstruction is not dense and some neighbouring traced
points are farther than `max_branch_len` apart."""
)
break
split_branches += split_branch
split_types += [type] * num_subbranches

return split_branches, split_types


def _split_branch_equally(branch: np.ndarray, num_subbranches: int) -> List[np.ndarray]:
num_points_each = len(branch) // num_subbranches
branches = [branch[:num_points_each]]
for i in range(1, num_subbranches - 1):
branches.append(branch[i * num_points_each - 1 : (i + 1) * num_points_each])
branches.append(branch[(num_subbranches - 1) * num_points_each - 1 :])
return branches


def _split_into_branches(
content: np.ndarray, is_single_point_soma: bool
) -> Tuple[np.ndarray, np.ndarray]:
prev_ind = None
prev_type = None
n_branches = 0

# Branch inds will contain the row identifier at which a branch point occurs
# (i.e. the row of the parent of two branches).
branch_inds = []
for c in content:
current_ind = c[0]
current_parent = c[-1]
current_type = c[1]
if current_parent != prev_ind or current_type != prev_type:
branch_inds.append(int(current_parent))
n_branches += 1
prev_ind = current_ind
prev_type = current_type

all_branches = []
current_branch = []
all_types = []

# Loop over every line in the SWC file.
for c in content:
current_ind = c[0] # First col is row_identifier
current_parent = c[-1] # Last col is parent in SWC specification.
if current_parent == -1:
all_types.append(c[1])
else:
current_type = c[1]

if current_parent == -1 and is_single_point_soma and current_ind == 1:
all_branches.append([int(current_ind)])
all_types.append(int(current_type))

# Either append the current point to the branch, or add the branch to
# `all_branches`.
if current_parent in branch_inds[1:]:
if len(current_branch) > 1:
all_branches.append(current_branch)
all_types.append(current_type)
current_branch = [int(current_parent), int(current_ind)]
else:
current_branch.append(int(current_ind))

# Append the final branch (intermediate branches are already appended five lines
# above.)
all_branches.append(current_branch)
return all_branches, all_types


def _build_parents(all_branches: List[np.ndarray]) -> List[int]:
parents = [None] * len(all_branches)
all_last_inds = [b[-1] for b in all_branches]
for i, branch in enumerate(all_branches):
parent_ind = branch[0]
ind = np.where(np.asarray(all_last_inds) == parent_ind)[0]
if len(ind) > 0 and ind != i:
parents[i] = ind[0]
else:
assert (
parent_ind == 1
), """Trying to connect a segment to the beginning of
another segment. This is not allowed. Please create an issue on github."""
parents[i] = -1

return parents


def _radius_generating_fns(
all_branches: np.ndarray,
radiuses: np.ndarray,
each_length: np.ndarray,
parents: np.ndarray,
types: np.ndarray,
) -> List[Callable]:
"""For all branches in a cell, returns callable that return radius given loc."""
radius_fns = []
for i, branch in enumerate(all_branches):
rads_in_branch = radiuses[np.asarray(branch) - 1]
if parents[i] > -1 and types[i] != types[parents[i]]:
# We do not want to linearly interpolate between the radius of the previous
# branch if a new type of neurite is found (e.g. switch from soma to
# apical). From looking at the SWC from n140.swc I believe that this is
# also what NEURON does.
rads_in_branch[0] = rads_in_branch[1]
radius_fn = _radius_generating_fn(
radiuses=rads_in_branch, each_length=each_length[i]
)
# Beause SWC starts counting at 1, but numpy counts from 0.
# ind_of_branch_endpoint = np.asarray(b) - 1
radius_fns.append(radius_fn)
return radius_fns


def _radius_generating_fn(radiuses: np.ndarray, each_length: np.ndarray) -> Callable:
# Avoid division by 0 with the `summed_len` below.
each_length[each_length < 1e-8] = 1e-8
summed_len = np.sum(each_length)
cutoffs = np.cumsum(np.concatenate([np.asarray([0]), each_length])) / summed_len
cutoffs[0] -= 1e-8
cutoffs[-1] += 1e-8

# We have to linearly interpolate radiuses, therefore we need at least two radiuses.
# However, jaxley allows somata which consist of a single traced point (i.e.
# just one radius). Therefore, we just `tile` in order to generate an artificial
# endpoint and startpoint radius of the soma.
if len(radiuses) == 1:
radiuses = np.tile(radiuses, 2)

def radius(loc: float) -> float:
"""Function which returns the radius via linear interpolation."""
index = np.digitize(loc, cutoffs, right=False)
left_rad = radiuses[index - 1]
right_rad = radiuses[index]
left_loc = cutoffs[index - 1]
right_loc = cutoffs[index]
loc_within_bin = (loc - left_loc) / (right_loc - left_loc)
return left_rad + (right_rad - left_rad) * loc_within_bin

return radius


def _compute_pathlengths(
all_branches: np.ndarray, coords: np.ndarray, is_single_point_soma: bool
) -> List[np.ndarray]:
"""
Args:
coords: Has shape (num_traced_points, 5), where `5` is (type, x, y, z, radius).
"""
branch_pathlengths = []
for b in all_branches:
coords_in_branch = coords[np.asarray(b) - 1]
if len(coords_in_branch) > 1:
# If the branch starts at a different neurite (e.g. the soma) then NEURON
# ignores the distance from that initial point. To reproduce, use the
# following SWC dummy file and read it in NEURON (and Jaxley):
# 1 1 0.00 0.0 0.0 6.0 -1
# 2 2 9.00 0.0 0.0 0.5 1
# 3 2 10.0 0.0 0.0 0.3 2
types = coords_in_branch[:, 0]
if int(types[0]) == 1 and int(types[1]) != 1 and is_single_point_soma:
coords_in_branch[0] = coords_in_branch[1]

# Compute distances between all traced points in a branch.
point_diffs = np.diff(coords_in_branch, axis=0)
dists = np.sqrt(
point_diffs[:, 1] ** 2 + point_diffs[:, 2] ** 2 + point_diffs[:, 3] ** 2
)
else:
# Jaxley uses length and radius for every compartment and assumes the
# surface area to be 2*pi*r*length. For branches consisting of a single
# traced point we assume for them to have area 4*pi*r*r. Therefore, we have
# to set length = 2*r.
radius = coords_in_branch[0, 4] # txyzr -> 4 is radius.
dists = np.asarray([2 * radius])
branch_pathlengths.append(dists)
return branch_pathlengths


def build_radiuses_from_xyzr(
radius_fns: List[Callable],
branch_indices: List[int],
min_radius: Optional[float],
nseg: int,
) -> jnp.ndarray:
"""Return the radiuses of branches given SWC file xyzr.
Returns an array of shape `(num_branches, nseg)`.
Args:
radius_fns: Functions which, given compartment locations return the radius.
branch_indices: The indices of the branches for which to return the radiuses.
min_radius: If passed, the radiuses are clipped to be at least as large.
nseg: The number of compartments that every branch is discretized into.
"""
# Compartment locations are at the center of the internal nodes.
non_split = 1 / nseg
range_ = np.linspace(non_split / 2, 1 - non_split / 2, nseg)

# Build radiuses.
radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])
radiuses_each = radiuses.ravel(order="C")
if min_radius is None:
assert np.all(
radiuses_each > 0.0
), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`."
else:
radiuses_each[radiuses_each < min_radius] = min_radius

return radiuses_each


def read_swc(
fname: str,
nseg: int,
max_branch_len: float = 300.0,
min_radius: Optional[float] = None,
assign_groups: bool = False,
) -> jx.Cell:
"""Reads SWC file into a `jx.Cell`.
) -> Cell:
"""Reads SWC file into a `Cell`.
Jaxley assumes cylindrical compartments and therefore defines length and radius
for every compartment. The surface area is then 2*pi*r*length. For branches
Expand All @@ -382,16 +118,16 @@ def read_swc(
http://www.neuronland.org/NLMorphologyConverter/MorphologyFormats/SWC/Spec.html
Returns:
A `jx.Cell` object.
A `Cell` object.
"""
parents, pathlengths, radius_fns, types, coords_of_branches = swc_to_jaxley(
fname, max_branch_len=max_branch_len, sort=True, num_lines=None
)
nbranches = len(parents)

comp = jx.Compartment()
branch = jx.Branch([comp for _ in range(nseg)])
cell = jx.Cell(
comp = Compartment()
branch = Branch([comp for _ in range(nseg)])
cell = Cell(
[branch for _ in range(nbranches)], parents=parents, xyzr=coords_of_branches
)
# Also save the radius generating functions in case users post-hoc modify the number
Expand Down
1 change: 0 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from matplotlib.axes import Axes

from jaxley.channels import Channel
from jaxley.io.swc import build_radiuses_from_xyzr
from jaxley.solver_voltage import (
step_voltage_explicit,
step_voltage_implicit_with_jax_spsolve,
Expand Down

0 comments on commit d55c87b

Please sign in to comment.