Skip to content

Commit

Permalink
wip: save wip, bug hunting in _synapse_current voltages
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 6, 2024
1 parent b5c8a6b commit ead8d3f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def to_jax(self):
inds = data.index[data["type"] == mech._name]
states_params = list(mech.params) + list(mech.states)
params = data[states_params].loc[inds]
jax_arrays.update({mech._name: inds})
# jax_arrays.update({mech._name: inds})
jax_arrays.update(params.to_dict(orient="list"))

morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
Expand Down
17 changes: 8 additions & 9 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def _step_synapse(
) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:
"""Perform one step of the synapses and obtain their currents."""
states = self._step_synapse_state(states, syn_channels, params, delta_t, edges)
# import jax; jax.debug.print("1 {}", states["TestSynapse_c"])
states, current_terms = self._synapse_currents(
states, syn_channels, params, delta_t, edges
)
Expand Down Expand Up @@ -288,8 +289,7 @@ def _step_synapse_state(
# Rebuild state. This has to be done within the loop over channels to allow
# multiple channels which modify the same state.
for key, val in states_updated.items():
states[key] = states[key].at[group.index.to_numpy()].set(val)

states[key] = states[key].at[:].set(val)
return states

def _synapse_currents(
Expand All @@ -309,7 +309,7 @@ def _synapse_currents(
# offset.
diff = 1e-3

synapse_current_states = {f"{s._name}_current": zeros for s in syn_channels}
synapse_current_states = {f"i_{s._name}": zeros for s in syn_channels}
for i, group in edges.groupby("type_ind"):
synapse = syn_channels[i]
pre_inds = group["pre_global_comp_index"].to_numpy()
Expand All @@ -320,6 +320,7 @@ def _synapse_currents(
synapse_params = query_syn(params, synapse.params)
synapse_states = query_syn(states, synapse.states)

num_comp = len(voltages)
v_pre, v_post = voltages[pre_inds], voltages[post_inds]
pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff])
post_v_and_perturbed = jnp.array([v_post, v_post + diff])
Expand All @@ -344,24 +345,22 @@ def _synapse_currents(
syn_voltages = voltage_term, constant_term

# Gather slope and offset for every postsynaptic compartment.
num_comp = len(voltages)
gathered_syn_currents = gather_synapes(num_comp, post_inds, *syn_voltages)

syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0])
syn_constant_terms = syn_constant_terms.at[:].add(-gathered_syn_currents[1])

# Save the current (for the unperturbed voltage) as a state that will
# also be passed to the state update.
synapse_current_states[f"{synapse._name}_current"] = (
synapse_current_states[f"{synapse._name}_current"]
.at[edge_inds]
synapse_current_states[f"i_{synapse._name}"] = (
synapse_current_states[f"i_{synapse._name}"]
.at[post_inds]
.add(synapse_currents_dist[0])
)

# Copy the currents into the `state` dictionary such that they can be
# recorded and used by `Channel.update_states()`.
for name in [s._name for s in self.synapses]:
states[f"{name}_current"] = synapse_current_states[f"{name}_current"]
states[f"i_{name}"] = synapse_current_states[f"i_{name}"]
return states, (syn_voltage_terms, syn_constant_terms)

def arrange_in_layers(
Expand Down

0 comments on commit ead8d3f

Please sign in to comment.