From ecfe7e8f47287989dc108081ca06837a6ec7334c Mon Sep 17 00:00:00 2001 From: jnsbck <65561470+jnsbck@users.noreply.github.com> Date: Thu, 10 Oct 2024 21:15:32 +0200 Subject: [PATCH] `_update_nodes_with_xyz` for heterogenous nseg (#445) * fix: update_nodes_with_xyz now works with heterogenous nseg * fix: rm cumsum_nseg, since not present in comp / branch --- jaxley/modules/base.py | 55 ++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 91901fc4..a359d30c 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -123,35 +123,42 @@ def __init__(self): def _update_nodes_with_xyz(self): """Add xyz coordinates of compartment centers to nodes. - Note: For sake of performance, interpolation is not done for each branch, - but once along a concatenated (and padded) array of all branches. + Centers are the midpoint between the comparment endpoints on the morphology + as defined by xyzr. + + Note: For sake of performance, interpolation is not done for each branch + individually, but only once along a concatenated (and padded) array of all branches. + This means for nsegs = [2,4] and normalized cum_branch_lens of [[0,1],[0,1]] we would + interpolate xyz at the locations comp_ends = [[0,0.5,1], [0,0.25,0.5,0.75,1]], + where 0 is the start of the branch and 1 is the end point at the full branch_len. + To avoid do this in one go we set comp_ends = [0,0.5,1,2,2.25,2.5,2.75,3], and + norm_cum_branch_len = [0,1,2,3] incrememting and also padding them by 1 to + avoid overlapping branch_lens i.e. norm_cum_branch_len = [0,1,1,2] for only + incrementing. """ - num_branches = len(self.xyzr) - comp_ends = ( - np.linspace(0, 1, self.nseg + 1).reshape(1, -1).repeat(num_branches, 0) + nsegs = self.nodes.groupby("branch_index")["comp_index"].nunique().to_numpy() + + comp_ends = np.hstack( + [np.linspace(0, 1, nseg + 1) + 2 * i for i, nseg in enumerate(nsegs)] ) - comp_ends = comp_ends + 2 * np.arange(num_branches).reshape( - -1, 1 - ) # inter-branch padding comp_ends = comp_ends.reshape(-1) - branch_lens = [] + cum_branch_lens = [] for i, xyzr in enumerate(self.xyzr): - branch_len = np.sqrt( - np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1) - ).cumsum() - branch_len = np.hstack([np.array([0]), branch_len]) - max_len = branch_len.max() - branch_len = ( - branch_len / (max_len if max_len > 0 else 1) + 2 * i - ) # add padding like above - branch_len[np.isnan(branch_len)] = 0 - branch_lens.append(branch_len) - branch_lens = np.hstack(branch_lens) + branch_len = np.sqrt(np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1)) + cum_branch_len = np.cumsum(np.concatenate([np.array([0]), branch_len])) + max_len = cum_branch_len.max() + # add padding like above + cum_branch_len = cum_branch_len / (max_len if max_len > 0 else 1) + 2 * i + cum_branch_len[np.isnan(cum_branch_len)] = 0 + cum_branch_lens.append(cum_branch_len) + cum_branch_lens = np.hstack(cum_branch_lens) xyz = np.vstack(self.xyzr)[:, :3] - xyz = v_interp(comp_ends, branch_lens, xyz).reshape( - 3, num_branches, self.nseg + 1 - ) - centers = ((xyz[:, :, 1:] + xyz[:, :, :-1]) / 2).reshape(3, -1).T + xyz = v_interp(comp_ends, cum_branch_lens, xyz).T + centers = (xyz[:-1] + xyz[1:]) / 2 # unaware of inter vs intra comp centers + cum_nsegs = np.cumsum(nsegs) + # this means centers between comps have to be removed here + between_comp_inds = (cum_nsegs + np.arange(len(cum_nsegs)))[:-1] + centers = np.delete(centers, between_comp_inds, axis=0) idcs = self.nodes["comp_index"] self.nodes.loc[idcs, ["x", "y", "z"]] = centers return centers, xyz