Skip to content

Commit

Permalink
fix: fix param sharing
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 23, 2024
1 parent efee504 commit b760e90
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
compute_axial_conductances,
compute_current_density,
compute_levels,
index_of_a_in_b,
interpolate_xyzr,
params_to_pstate,
v_interp,
Expand Down Expand Up @@ -1355,7 +1356,7 @@ def _get_all_states_params(
# `set_param` is of shape `(num_params,)`
# We need to unsqueeze `set_param` to make it `(num_params, 1)`
# for the `.set()` to work. This is done with `[:, None]`.
inds = jnp.searchsorted(param_state_inds, inds)
inds = index_of_a_in_b(inds, param_state_inds)
states_params[key] = states_params[key].at[inds].set(set_param[:, None])

if params:
Expand Down
29 changes: 29 additions & 0 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,3 +772,32 @@ def dtype_aware_concat(dfs):
concat_df[col] = concat_df[col].astype(df[col].dtype)
break # first match is sufficient
return concat_df


def index_of_a_in_b(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray:
"""Replace values in A with the indices of the corresponding values in B.
Mainly used to determine the indices of parameters in jaxnodes based on the global
indices of the parameters in the cell. All values in A that are not in B are
replaced with -1.
Example:
- indices_of_gNa = [5,6,7,8,9]
- indices_to_change = [6,7]
- index_of_a_in_b(indices_to_change, indices_of_gNa) -> [1,2]
Args:
A: Array of shape (N, M).
B: Array of shape (N, K).
Returns:
Array of shape of A with the indices of the values of A in B."""
matches = A[:, :, None] == B
# Get mask for values that exist in B
exists_in_B = matches.any(axis=-1)
# Get indices where matches occur
indices = jnp.where(matches, jnp.arange(len(B))[None, None, :], 0)
# Sum along last axis to get the indices
result = jnp.sum(indices, axis=-1)
# Replace values not in B with -1
return jnp.where(exists_in_B, result, -1)

0 comments on commit b760e90

Please sign in to comment.