Skip to content

Commit

Permalink
_update_nodes_with_xyz for heterogenous nseg (#445)
Browse files Browse the repository at this point in the history
* fix: update_nodes_with_xyz now works with heterogenous nseg

* fix: rm cumsum_nseg, since not present in comp / branch
  • Loading branch information
jnsbck authored Oct 10, 2024
1 parent d169f2c commit ecfe7e8
Showing 1 changed file with 31 additions and 24 deletions.
55 changes: 31 additions & 24 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,35 +123,42 @@ def __init__(self):
def _update_nodes_with_xyz(self):
"""Add xyz coordinates of compartment centers to nodes.
Note: For sake of performance, interpolation is not done for each branch,
but once along a concatenated (and padded) array of all branches.
Centers are the midpoint between the comparment endpoints on the morphology
as defined by xyzr.
Note: For sake of performance, interpolation is not done for each branch
individually, but only once along a concatenated (and padded) array of all branches.
This means for nsegs = [2,4] and normalized cum_branch_lens of [[0,1],[0,1]] we would
interpolate xyz at the locations comp_ends = [[0,0.5,1], [0,0.25,0.5,0.75,1]],
where 0 is the start of the branch and 1 is the end point at the full branch_len.
To avoid do this in one go we set comp_ends = [0,0.5,1,2,2.25,2.5,2.75,3], and
norm_cum_branch_len = [0,1,2,3] incrememting and also padding them by 1 to
avoid overlapping branch_lens i.e. norm_cum_branch_len = [0,1,1,2] for only
incrementing.
"""
num_branches = len(self.xyzr)
comp_ends = (
np.linspace(0, 1, self.nseg + 1).reshape(1, -1).repeat(num_branches, 0)
nsegs = self.nodes.groupby("branch_index")["comp_index"].nunique().to_numpy()

comp_ends = np.hstack(
[np.linspace(0, 1, nseg + 1) + 2 * i for i, nseg in enumerate(nsegs)]
)
comp_ends = comp_ends + 2 * np.arange(num_branches).reshape(
-1, 1
) # inter-branch padding
comp_ends = comp_ends.reshape(-1)
branch_lens = []
cum_branch_lens = []
for i, xyzr in enumerate(self.xyzr):
branch_len = np.sqrt(
np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1)
).cumsum()
branch_len = np.hstack([np.array([0]), branch_len])
max_len = branch_len.max()
branch_len = (
branch_len / (max_len if max_len > 0 else 1) + 2 * i
) # add padding like above
branch_len[np.isnan(branch_len)] = 0
branch_lens.append(branch_len)
branch_lens = np.hstack(branch_lens)
branch_len = np.sqrt(np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1))
cum_branch_len = np.cumsum(np.concatenate([np.array([0]), branch_len]))
max_len = cum_branch_len.max()
# add padding like above
cum_branch_len = cum_branch_len / (max_len if max_len > 0 else 1) + 2 * i
cum_branch_len[np.isnan(cum_branch_len)] = 0
cum_branch_lens.append(cum_branch_len)
cum_branch_lens = np.hstack(cum_branch_lens)
xyz = np.vstack(self.xyzr)[:, :3]
xyz = v_interp(comp_ends, branch_lens, xyz).reshape(
3, num_branches, self.nseg + 1
)
centers = ((xyz[:, :, 1:] + xyz[:, :, :-1]) / 2).reshape(3, -1).T
xyz = v_interp(comp_ends, cum_branch_lens, xyz).T
centers = (xyz[:-1] + xyz[1:]) / 2 # unaware of inter vs intra comp centers
cum_nsegs = np.cumsum(nsegs)
# this means centers between comps have to be removed here
between_comp_inds = (cum_nsegs + np.arange(len(cum_nsegs)))[:-1]
centers = np.delete(centers, between_comp_inds, axis=0)
idcs = self.nodes["comp_index"]
self.nodes.loc[idcs, ["x", "y", "z"]] = centers
return centers, xyz
Expand Down

0 comments on commit ecfe7e8

Please sign in to comment.