Skip to content

Commit

Permalink
fix: fix remaining indexing issues, tests passing (I think)
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 7, 2024
1 parent 91710a2 commit cdb2983
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 24 deletions.
24 changes: 8 additions & 16 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,9 +1428,7 @@ def init_states(self, delta_t: float = 0.025):

for channel in self.base.channels:
name = channel._name
channel_indices = channel_nodes.loc[channel_nodes[name]][
"global_comp_index"
].to_numpy()
channel_indices = self._mech_lookup_table[name + "_index"]
voltages = channel_nodes.loc[channel_indices, "v"].to_numpy()

channel_param_names = list(channel.params.keys())
Expand Down Expand Up @@ -1966,23 +1964,20 @@ def _step_channels_state(
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]

for channel in channels:
is_channel = channel_nodes[channel._name]
channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy()

query_channel = lambda d, names: query_states_and_params(d, names)
channel_param_names = list(channel.params) + morph_params
channel_params = query_channel(params, channel_param_names)
channel_params = query_states_and_params(params, channel_param_names)
channel_state_names = list(channel.states) + self.membrane_current_names
channel_states = query_channel(states, channel_state_names)
channel_states = query_states_and_params(states, channel_state_names)

# States updates.
channel_inds = self._mech_lookup_table[channel._name + "_index"]
states_updated = channel.update_states(
channel_states, delta_t, voltages[channel_inds], channel_params
)
# Rebuild state. This has to be done within the loop over channels to allow
# multiple channels which modify the same state.
for key, val in states_updated.items():
states[key] = states[key].at[channel_inds].set(val)
states[key] = states[key].at[:].set(val)

return states

Expand Down Expand Up @@ -2010,14 +2005,11 @@ def _channel_currents(

current_states = {name: zeros for name in self.membrane_current_names}
for channel in channels:
is_channel = channel_nodes[channel._name]
channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy()

query_channel = lambda d, names: query_states_and_params(d, names)
channel_param_names = list(channel.params) + morph_params
channel_params = query_channel(params, channel_param_names)
channel_states = query_channel(states, channel.states)
channel_params = query_states_and_params(params, channel_param_names)
channel_states = query_states_and_params(states, channel.states)

channel_inds = self._mech_lookup_table[channel._name + "_index"]
v_channel = voltages[channel_inds]
v_and_perturbed = jnp.array([v_channel, v_channel + diff])

Expand Down
15 changes: 7 additions & 8 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,9 @@ def _step_synapse_state(
synapse = syn_channels[i]
pre_inds = group["pre_global_comp_index"].to_numpy()
post_inds = group["post_global_comp_index"].to_numpy()
edge_inds = group.index.to_numpy()

query_syn = lambda d, names: query_states_and_params(d, names, edge_inds)
synapse_params = query_syn(params, synapse.params)
synapse_states = query_syn(states, synapse.states)
synapse_params = query_states_and_params(params, synapse.params)
synapse_states = query_states_and_params(states, synapse.states)

# State updates.
states_updated = synapse.update_states(
Expand All @@ -290,6 +288,7 @@ def _step_synapse_state(
# multiple channels which modify the same state.
for key, val in states_updated.items():
states[key] = states[key].at[:].set(val)

return states

def _synapse_currents(
Expand Down Expand Up @@ -356,10 +355,10 @@ def _synapse_currents(
.add(synapse_currents_dist[0])
)

# Copy the currents into the `state` dictionary such that they can be
# recorded and used by `Channel.update_states()`.
for name in [s._name for s in self.synapses]:
states[f"i_{name}"] = synapse_current_states[f"i_{name}"]
# Copy the currents into the `state` dictionary such that they can be
# recorded and used by `Channel.update_states()`.
for name in [s._name for s in self.synapses]:
states[f"i_{name}"] = synapse_current_states[f"i_{name}"]
return states, (syn_voltage_terms, syn_constant_terms)

def arrange_in_layers(
Expand Down

0 comments on commit cdb2983

Please sign in to comment.