From aa4ae5fe93183ac0efece826b4f53760cab76451 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Thu, 5 Dec 2024 14:20:03 +0100 Subject: [PATCH] wip: more tests passing, small refactor --- jaxley/modules/base.py | 112 +++++++++++++++++++++-------------------- 1 file changed, 58 insertions(+), 54 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 55c43a48..768f0e91 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -30,8 +30,7 @@ compute_axial_conductances, compute_current_density, compute_levels, - interpolate_xyz, - loc_of_index, + interpolate_xyzr, params_to_pstate, query_states_and_params, v_interp, @@ -228,6 +227,7 @@ def __getattr__(self, key): view._set_controlled_by_param(key) # overwrites param set by edge # Ensure synapse param sharing works with `edge` # `edge` will be removed as part of #463 + view.edges["local_edge_index"] = np.arange(len(view.edges)) return view def _childviews(self) -> List[str]: @@ -1198,9 +1198,9 @@ def _get_state_names(self) -> Tuple[List, List]: """Collect all recordable / clampable states in the membrane and synapses. Returns states seperated by comps and edges.""" - channel_states = [name for c in self.channels for name in c.channel_states] + channel_states = [name for c in self.channels for name in c.states] synapse_states = [ - name for s in self.synapses if s is not None for name in s.synapse_states + name for s in self.synapses if s is not None for name in s.states ] membrane_states = ["v", "i"] + self.membrane_current_names return ( @@ -1219,6 +1219,26 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """ return self.trainable_params + @only_allow_module + def _iter_states_or_params(self, type="states") -> Dict[str, jnp.ndarray]: + # TODO FROM #447: MAKE THIS WORK FOR VIEW? + """Return states as they are set in the `.nodes` and `.edges` tables.""" + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + global_states = ["v"] + global_states_or_params = morph_params if type == "params" else global_states + for key in global_states_or_params: + yield key, self.base.jaxnodes["index"], self.base.jaxnodes[key] + + # Join node and edge states into a single state dictionary. + for jax_arrays, mechs in zip( + [self.base.jaxnodes, self.base.jaxedges], + [self.base.channels, self.base.synapses], + ): + for mech in mechs: + mech_inds = jax_arrays[mech._name] + for key in mech.__dict__[type]: + yield key, mech_inds, jax_arrays[key] + @only_allow_module def get_all_parameters( self, pstate: List[Dict], voltage_solver: str @@ -1255,34 +1275,24 @@ def get_all_parameters( Returns: A dictionary of all module parameters. """ - params = {} - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - for key in ["v"] + morph_params: - params[key] = self.base.jaxnodes[key] + pstate_inds = {d["key"]: i for i, d in enumerate(pstate)} - for jax_arrays, data, mechs in zip( - [self.base.jaxnodes, self.base.jaxedges], - [self.base.nodes, self.base.edges], - [self.base.channels, self.base.synapses], - ): - for mech in mechs: - inds = jax_arrays[mech._name] - for mech_param in mech.params: - params[mech_param] = data[mech_param].to_numpy() - params[mech_param][inds] = jax_arrays[mech_param] - params[mech_param] = jnp.asarray(params[mech_param]) + params = {} + for key, mech_inds, jax_array in self._iter_states_or_params("params"): + params[key] = jax_array - # Override with those parameters set by `.make_trainable()`. - for parameter in pstate: - key = parameter["key"] - inds = parameter["indices"] - set_param = parameter["val"] + # Override with those parameters set by `.make_trainable()`. + if key in pstate_inds: + idx = pstate_inds[key] + key = pstate[idx]["key"] + inds = pstate[idx]["indices"] + set_param = pstate[idx]["val"] - if key in params: # Only parameters, not initial states. # `inds` is of shape `(num_params, num_comps_per_param)`. # `set_param` is of shape `(num_params,)` - # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the - # `.set()` to work. This is done with `[:, None]`. + # We need to unsqueeze `set_param` to make it `(num_params, 1)` + # for the `.set()` to work. This is done with `[:, None]`. + inds = np.searchsorted(mech_inds, inds) params[key] = params[key].at[inds].set(set_param[:, None]) # Compute conductance params and add them to the params dictionary. @@ -1291,20 +1301,6 @@ def get_all_parameters( ) return params - @only_allow_module - def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: - # TODO FROM #447: MAKE THIS WORK FOR VIEW? - """Return states as they are set in the `.nodes` and `.edges` tables.""" - self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. - states = {"v": self.base.jaxnodes["v"]} - # Join node and edge states into a single state dictionary. - for channel in self.base.channels: - for channel_states in channel.states: - states[channel_states] = self.base.jaxnodes[channel_states] - for synapse_states in self.base.synapse_state_names: - states[synapse_states] = self.base.jaxedges[synapse_states] - return states - @only_allow_module def get_all_states( self, pstate: List[Dict], all_params, delta_t: float @@ -1320,18 +1316,23 @@ def get_all_states( Returns: A dictionary of all states of the module. """ - states = self.base._get_states_from_nodes_and_edges() - - # Override with the initial states set by `.make_trainable()`. - for parameter in pstate: - key = parameter["key"] - inds = parameter["indices"] - set_param = parameter["val"] - if key in states: # Only initial states, not parameters. - # `inds` is of shape `(num_params, num_comps_per_param)`. - # `set_param` is of shape `(num_params,)` - # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the - # `.set()` to work. This is done with `[:, None]`. + pstate_inds = {d["key"]: i for i, d in enumerate(pstate)} + states = {} + for key, mech_inds, jax_array in self._iter_states_or_params("states"): + states[key] = jax_array + + # Override with those parameters set by `.make_trainable()`. + if key in pstate_inds: + idx = pstate_inds[key] + key = pstate[idx]["key"] + inds = pstate[idx]["indices"] + set_param = pstate[idx]["val"] + + # `inds` is of shape `(num_states, num_comps_per_param)`. + # `set_param` is of shape `(num_states,)` + # We need to unsqueeze `set_param` to make it `(num_states, 1)` + # for the `.set()` to work. This is done with `[:, None]`. + inds = np.searchsorted(mech_inds, inds) states[key] = states[key].at[inds].set(set_param[:, None]) # Add to the states the initial current through every channel. @@ -1366,8 +1367,11 @@ def init_states(self, delta_t: float = 0.025): delta_t: Passed on to `channel.init_state()`. """ # Update states of the channels. + self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. channel_nodes = self.base.nodes - states = self.base._get_states_from_nodes_and_edges() + states = {} + for key, _, jax_array in self._iter_states_or_params("states"): + states[key] = jax_array # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function.