Skip to content

Commit

Permalink
fix: fixed indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 7, 2024
1 parent ead8d3f commit 91710a2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 23 deletions.
16 changes: 4 additions & 12 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,12 +1435,8 @@ def init_states(self, delta_t: float = 0.025):

channel_param_names = list(channel.params.keys())
channel_state_names = list(channel.states.keys())
channel_states = query_states_and_params(
states, channel_state_names, channel_indices
)
channel_params = query_states_and_params(
params, channel_param_names, channel_indices
)
channel_states = query_states_and_params(states, channel_state_names)
channel_params = query_states_and_params(params, channel_param_names)

init_state = channel.init_state(
channel_states, voltages, channel_params, delta_t
Expand Down Expand Up @@ -1973,9 +1969,7 @@ def _step_channels_state(
is_channel = channel_nodes[channel._name]
channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy()

query_channel = lambda d, names: query_states_and_params(
d, names, channel_inds
)
query_channel = lambda d, names: query_states_and_params(d, names)
channel_param_names = list(channel.params) + morph_params
channel_params = query_channel(params, channel_param_names)
channel_state_names = list(channel.states) + self.membrane_current_names
Expand Down Expand Up @@ -2019,9 +2013,7 @@ def _channel_currents(
is_channel = channel_nodes[channel._name]
channel_inds = channel_nodes.loc[is_channel, "global_comp_index"].to_numpy()

query_channel = lambda d, names: query_states_and_params(
d, names, channel_inds
)
query_channel = lambda d, names: query_states_and_params(d, names)
channel_param_names = list(channel.params) + morph_params
channel_params = query_channel(params, channel_param_names)
channel_states = query_channel(states, channel.states)
Expand Down
17 changes: 8 additions & 9 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,18 +309,16 @@ def _synapse_currents(
# offset.
diff = 1e-3

num_comp = len(voltages)
synapse_current_states = {f"i_{s._name}": zeros for s in syn_channels}
for i, group in edges.groupby("type_ind"):
synapse = syn_channels[i]
pre_inds = group["pre_global_comp_index"].to_numpy()
post_inds = group["post_global_comp_index"].to_numpy()
edge_inds = group.index.to_numpy()

query_syn = lambda d, names: query_states_and_params(d, names, edge_inds)
synapse_params = query_syn(params, synapse.params)
synapse_states = query_syn(states, synapse.states)
synapse_params = query_states_and_params(params, synapse.params)
synapse_states = query_states_and_params(states, synapse.states)

num_comp = len(voltages)
v_pre, v_post = voltages[pre_inds], voltages[post_inds]
pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff])
post_v_and_perturbed = jnp.array([v_post, v_post + diff])
Expand All @@ -345,6 +343,7 @@ def _synapse_currents(
syn_voltages = voltage_term, constant_term

# Gather slope and offset for every postsynaptic compartment.
# import jax; jax.debug.print("{}", synapse_params)
gathered_syn_currents = gather_synapes(num_comp, post_inds, *syn_voltages)

syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0])
Expand All @@ -357,10 +356,10 @@ def _synapse_currents(
.add(synapse_currents_dist[0])
)

# Copy the currents into the `state` dictionary such that they can be
# recorded and used by `Channel.update_states()`.
for name in [s._name for s in self.synapses]:
states[f"i_{name}"] = synapse_current_states[f"i_{name}"]
# Copy the currents into the `state` dictionary such that they can be
# recorded and used by `Channel.update_states()`.
for name in [s._name for s in self.synapses]:
states[f"i_{name}"] = synapse_current_states[f"i_{name}"]
return states, (syn_voltage_terms, syn_constant_terms)

def arrange_in_layers(
Expand Down
4 changes: 2 additions & 2 deletions jaxley/utils/cell_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def group_and_sum(
return group_sums


def query_states_and_params(d, keys, idcs):
def query_states_and_params(d, keys, idcs=None):
"""Get dict with subset of keys and values from d.
This is used to restrict a dict where every item contains __all__ states to only
Expand All @@ -698,7 +698,7 @@ def query_states_and_params(d, keys, idcs):
```states = {'eCa': Array([ 0., 0.]}```
Only loops over necessary keys, as opposed to looping over `d.items()`."""
return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))
return dict(zip(keys, (v if idcs is None else v[idcs] for v in map(d.get, keys))))


def compute_axial_conductances(
Expand Down

0 comments on commit 91710a2

Please sign in to comment.