From b5c8a6b25ca24f5ffa47c704de9ca5623f596e23 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Fri, 6 Dec 2024 19:28:54 +0100 Subject: [PATCH] wip: more fixes --- jaxley/modules/base.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 7dafaf1d..56dd842c 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1242,14 +1242,17 @@ def _iter_states_params( for key in global_states_params: yield key, self.jaxnodes[key] + # for key in self.synapse_current_names: + # yield key, self.jaxedges[key] + # Join node and edge states into a single state dictionary. for jax_arrays, mechs in zip( [self.jaxnodes, self.jaxedges], [self.channels, self.synapses], ): for mech in mechs: - mech_params_states = mech.__dict__["params"] if params else {} - mech_params_states.update(mech.__dict__["states"] if states else {}) + mech_params_states = mech.params if params else {} + mech_params_states.update(mech.states if states else {}) for key in mech_params_states: yield key, jax_arrays[key] @@ -1293,8 +1296,8 @@ def _get_all_states_params( states=False, ) -> Dict[str, jnp.ndarray]: states_params = {} - for key, jax_array in self.base._iter_states_params(params, states): - states_params[key] = jax_array + for key, jax_arrays in self.base._iter_states_params(params, states): + states_params[key] = jax_arrays # Override with those parameters set by `.make_trainable()`. for parameter in pstate: @@ -1414,8 +1417,8 @@ def init_states(self, delta_t: float = 0.025): self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. channel_nodes = self.base.nodes states = {} - for key, jax_array in self.base._iter_states_params(states=True): - states[key] = jax_array + for key, jax_arrays in self.base._iter_states_params(states=True): + states[key] = jax_arrays # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. @@ -1784,8 +1787,8 @@ def delete_channel(self, channel: Channel): channel_names = [c._name for c in self.channels] all_channel_names = [c._name for c in self.base.channels] if name in channel_names: - channel_cols = list(channel.channel_params.keys()) - channel_cols += list(channel.channel_states.keys()) + channel_cols = list(channel.params.keys()) + channel_cols += list(channel.states.keys()) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") self.base.nodes.loc[self._nodes_in_view, name] = False @@ -2569,16 +2572,16 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]): elif v in mechs + ["v"] + morph_params: self._mech_lookup_table[k] = v - for jax_array, base_jax_array, viewed_inds in zip( + for jax_arrays, base_jax_arrays, viewed_inds in zip( [jaxnodes, jaxedges], [self.base.jaxnodes, self.base.jaxedges], [self._nodes_in_view, self._edges_in_view], ): - if base_jax_array is not None and len(viewed_inds) > 0: - for key, values in base_jax_array.items(): + if base_jax_arrays is not None and len(viewed_inds) > 0: + for key, values in base_jax_arrays.items(): mech, mech_inds = self.base._get_mech_inds_of_param_state(key) if mech is None or mech in mechs: - jax_array[key] = values[ + jax_arrays[key] = values.at[ a_intersects_b_at(mech_inds, viewed_inds) ]