Skip to content

Commit

Permalink
Initial states are trainable
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 18, 2023
1 parent b0edeb1 commit 73e9c0e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
9 changes: 9 additions & 0 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,18 @@ def _body_fun(state, i_stim):
for channel in module.channels:
for channel_states in list(channel.channel_states.keys()):
states[channel_states] = module.jaxnodes[channel_states]

# Override with the initial states set by `.make_trainable()`.
for inds, set_param in zip(module.indices_set_by_trainables, params):
for key in set_param.keys():
if key in list(states.keys()): # Only initial states, not parameters.
states[key] = states[key].at[inds].set(set_param[key])

# Write synaptic states. TODO move above when new interface for synapses.
for key in module.syn_states:
states[key] = module.syn_states[key]

# Run simulation.
_, recordings = nested_checkpoint_scan(
_body_fun, states, i_current, length=length, nested_lengths=checkpoint_lengths
)
Expand Down
3 changes: 2 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def get_all_parameters(self, trainable_params):
# Override with those parameters set by `.make_trainable()`.
for inds, set_param in zip(self.indices_set_by_trainables, trainable_params):
for key in set_param.keys():
params[key] = params[key].at[inds].set(set_param[key])
if key in list(params.keys()): # Only parameters, not initial states.
params[key] = params[key].at[inds].set(set_param[key])

# Compute conductance params and append them.
cond_params = self.init_conds(params)
Expand Down

0 comments on commit 73e9c0e

Please sign in to comment.