diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e421d4a..21020bed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +# 0.6.0 + +### New Features + +- Add ability to record synaptic currents (#523, @ntolley). Recordings can be turned on with +```python +net.record("i_IonotropicSynapse") +``` + + # 0.5.0 ### API changes diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 90d48f5d..5fa50cc0 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -148,6 +148,7 @@ def __init__(self): self.synapse_param_names = [] self.synapse_state_names = [] self.synapse_names = [] + self.synapse_current_names: List[str] = [] # List of types of all `jx.Channel`s. self.channels: List[Channel] = [] @@ -1207,9 +1208,14 @@ def _get_state_names(self) -> Tuple[List, List]: Returns states seperated by comps and edges.""" channel_states = [name for c in self.channels for name in c.channel_states] - synapse_states = [name for s in self.synapses for name in s.synapse_states] + synapse_states = [ + name for s in self.synapses if s is not None for name in s.synapse_states + ] membrane_states = ["v", "i"] + self.membrane_current_names - return channel_states + membrane_states, synapse_states + return ( + channel_states + membrane_states, + synapse_states + self.synapse_current_names, + ) def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """Get all trainable parameters. @@ -2444,7 +2450,8 @@ def __init__( ) self.channels = self._channels_in_view(pointer) - self.membrane_current_names = [c._name for c in self.channels] + self.membrane_current_names = [c.current_name for c in self.channels] + self.synapse_current_names = pointer.synapse_current_names self._set_trainables_in_view() # run after synapses and channels self.num_trainable_params = ( np.sum([len(inds) for inds in self.indices_set_by_trainables]) @@ -2635,7 +2642,7 @@ def _set_synapses_in_view(self, pointer: Union[Module, View]): viewed_synapses = [] viewed_params = [] viewed_states = [] - if not pointer.synapses is None: + if pointer.synapses is not None: for syn in pointer.synapses: if syn is not None: # needed for recurive viewing in_view = syn._name in self.synapse_names diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 62d74045..0a65c58c 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -379,7 +379,7 @@ def _synapse_currents( # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are # compartments in the network. # `[0]` because we only use the non-perturbed voltage. - states[f"{synapse_type._name}_current"] = synapse_currents[0] + states[f"i_{synapse_type._name}"] = synapse_currents[0] return states, (syn_voltage_terms, syn_constant_terms) @@ -565,9 +565,11 @@ def _update_synapse_state_names(self, synapse_type): def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): # Add synapse types to the module and infer their unique identifier. synapse_name = synapse_type._name + synapse_current_name = f"i_{synapse_name}" type_ind, is_new = self._infer_synapse_type_ind(synapse_name) if is_new: # synapse is not known self._update_synapse_state_names(synapse_type) + self.base.synapse_current_names.append(synapse_current_name) index = len(self.base.edges) indices = [idx for idx in range(index, index + len(pre_nodes))] diff --git a/tests/test_record_and_stimulate.py b/tests/test_record_and_stimulate.py index 151c3474..8e7b5a2d 100644 --- a/tests/test_record_and_stimulate.py +++ b/tests/test_record_and_stimulate.py @@ -75,16 +75,30 @@ def test_record_synaptic_and_membrane_states(SimpleNet): ) net.cell(0).branch(0).loc(0.0).stimulate(current) + # Invoke recording of voltage and synaptic states. net.cell(2).branch(0).loc(0.0).record("v") net.IonotropicSynapse.edge(1).record("IonotropicSynapse_s") net.cell(2).branch(0).loc(0.0).record("HH_m") net.cell(1).branch(0).loc(0.0).record("v") net.TestSynapse.edge(0).record("TestSynapse_c") net.cell(1).branch(0).loc(0.0).record("HH_m") + net.cell(1).branch(0).loc(0.0).record("i_HH") + net.IonotropicSynapse.edge(1).record("i_IonotropicSynapse") + + # Advanced synapse indexing for recording. + net.copy_node_property_to_edges("global_cell_index") + # Record currents from specific post synaptic cells. + df = net.edges + df = df.query("pre_global_cell_index in [0, 1]") + net.select(edges=df.index).record("i_IonotropicSynapse") + # Record currents from specific synapse types + df = net.edges + df = df.query("type == 'TestSynapse'") + net.select(edges=df.index).record("i_TestSynapse") recs = jx.integrate(net) - # Loop over first two recorings and then the second two recordings. + # Loop over first two recordings and then the second two recordings. for index in [0, 3]: # Local maxima of voltage trace. y = recs[index]