Skip to content

Commit

Permalink
wip: more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 6, 2024
1 parent c057a32 commit b5c8a6b
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1242,14 +1242,17 @@ def _iter_states_params(
for key in global_states_params:
yield key, self.jaxnodes[key]

# for key in self.synapse_current_names:
# yield key, self.jaxedges[key]

# Join node and edge states into a single state dictionary.
for jax_arrays, mechs in zip(
[self.jaxnodes, self.jaxedges],
[self.channels, self.synapses],
):
for mech in mechs:
mech_params_states = mech.__dict__["params"] if params else {}
mech_params_states.update(mech.__dict__["states"] if states else {})
mech_params_states = mech.params if params else {}
mech_params_states.update(mech.states if states else {})
for key in mech_params_states:
yield key, jax_arrays[key]

Expand Down Expand Up @@ -1293,8 +1296,8 @@ def _get_all_states_params(
states=False,
) -> Dict[str, jnp.ndarray]:
states_params = {}
for key, jax_array in self.base._iter_states_params(params, states):
states_params[key] = jax_array
for key, jax_arrays in self.base._iter_states_params(params, states):
states_params[key] = jax_arrays

# Override with those parameters set by `.make_trainable()`.
for parameter in pstate:
Expand Down Expand Up @@ -1414,8 +1417,8 @@ def init_states(self, delta_t: float = 0.025):
self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
channel_nodes = self.base.nodes
states = {}
for key, jax_array in self.base._iter_states_params(states=True):
states[key] = jax_array
for key, jax_arrays in self.base._iter_states_params(states=True):
states[key] = jax_arrays

# We do not use any `pstate` for initializing. In principle, we could change
# that by allowing an input `params` and `pstate` to this function.
Expand Down Expand Up @@ -1784,8 +1787,8 @@ def delete_channel(self, channel: Channel):
channel_names = [c._name for c in self.channels]
all_channel_names = [c._name for c in self.base.channels]
if name in channel_names:
channel_cols = list(channel.channel_params.keys())
channel_cols += list(channel.channel_states.keys())
channel_cols = list(channel.params.keys())
channel_cols += list(channel.states.keys())
self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan")
self.base.nodes.loc[self._nodes_in_view, name] = False

Expand Down Expand Up @@ -2569,16 +2572,16 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]):
elif v in mechs + ["v"] + morph_params:
self._mech_lookup_table[k] = v

for jax_array, base_jax_array, viewed_inds in zip(
for jax_arrays, base_jax_arrays, viewed_inds in zip(
[jaxnodes, jaxedges],
[self.base.jaxnodes, self.base.jaxedges],
[self._nodes_in_view, self._edges_in_view],
):
if base_jax_array is not None and len(viewed_inds) > 0:
for key, values in base_jax_array.items():
if base_jax_arrays is not None and len(viewed_inds) > 0:
for key, values in base_jax_arrays.items():
mech, mech_inds = self.base._get_mech_inds_of_param_state(key)
if mech is None or mech in mechs:
jax_array[key] = values[
jax_arrays[key] = values.at[
a_intersects_b_at(mech_inds, viewed_inds)
]

Expand Down

0 comments on commit b5c8a6b

Please sign in to comment.