Skip to content

Commit

Permalink
fix equidistant interpolation in xyz computation (#411)
Browse files Browse the repository at this point in the history
* fix: xp in interpolation was equidistant. changed to pathlengths. fixes #410

* fix: make clearer

* chore: ran black
  • Loading branch information
jnsbck authored Aug 28, 2024
1 parent 9bdeeff commit 19373f7
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ def remap_to_consecutive(arr):
return inverse_indices


v_interp = vmap(jnp.interp, in_axes=(None, None, 1))


def interpolate_xyz(loc: float, coords: np.ndarray):
"""Perform a linear interpolation between xyz-coordinates.
Expand All @@ -291,9 +294,11 @@ def interpolate_xyz(loc: float, coords: np.ndarray):
Return:
Interpolated xyz coordinate at `loc`, shape `(3,).
"""
return vmap(lambda x: jnp.interp(loc, jnp.linspace(0, 1, len(x)), x), in_axes=(1,))(
coords[:, :3]
)
dl = np.sqrt(np.sum(np.diff(coords[:, :3], axis=0) ** 2, axis=1))
pathlens = np.insert(np.cumsum(dl), 0, 0) # cummulative length of sections
norm_pathlens = pathlens / pathlens[-1] # path lengths normalized to [0,1]

return v_interp(loc, norm_pathlens, coords[:, :3])


def params_to_pstate(
Expand Down Expand Up @@ -391,6 +396,3 @@ def group_and_sum(
group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)

return group_sums


v_interp = jit(vmap(jnp.interp, in_axes=(None, None, 1)))

0 comments on commit 19373f7

Please sign in to comment.