From e17f92544f1f4af4f9f83dfd2bb634ad2d450876 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 17 Dec 2024 12:48:18 +0100 Subject: [PATCH] fix: fix param sharing --- jaxley/modules/base.py | 3 ++- jaxley/utils/cell_utils.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index cf9392dd..f49182e2 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -31,6 +31,7 @@ compute_axial_conductances, compute_current_density, compute_levels, + index_of_a_in_b, interpolate_xyzr, params_to_pstate, v_interp, @@ -1355,7 +1356,7 @@ def _get_all_states_params( # `set_param` is of shape `(num_params,)` # We need to unsqueeze `set_param` to make it `(num_params, 1)` # for the `.set()` to work. This is done with `[:, None]`. - inds = jnp.searchsorted(param_state_inds, inds) + inds = index_of_a_in_b(inds, param_state_inds) states_params[key] = states_params[key].at[inds].set(set_param[:, None]) if params: diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index 27e337ab..ccbe5833 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -772,3 +772,32 @@ def dtype_aware_concat(dfs): concat_df[col] = concat_df[col].astype(df[col].dtype) break # first match is sufficient return concat_df + + +def index_of_a_in_b(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray: + """Replace values in A with the indices of the corresponding values in B. + + Mainly used to determine the indices of parameters in jaxnodes based on the global + indices of the parameters in the cell. All values in A that are not in B are + replaced with -1. + + Example: + - indices_of_gNa = [5,6,7,8,9] + - indices_to_change = [6,7] + - index_of_a_in_b(indices_to_change, indices_of_gNa) -> [1,2] + + Args: + A: Array of shape (N, M). + B: Array of shape (N, K). + + Returns: + Array of shape of A with the indices of the values of A in B.""" + matches = A[:, :, None] == B + # Get mask for values that exist in B + exists_in_B = matches.any(axis=-1) + # Get indices where matches occur + indices = jnp.where(matches, jnp.arange(len(B))[None, None, :], 0) + # Sum along last axis to get the indices + result = jnp.sum(indices, axis=-1) + # Replace values not in B with -1 + return jnp.where(exists_in_B, result, -1)