Skip to content

Commit

Permalink
[MRG] Add ability to record synaptic currents (#523)
Browse files Browse the repository at this point in the history
* Add test for recording currents

* Add synapse current attribute for recordings

* black formatting

* Fix channel current recording and add tests

* possible fix to edge dataframe

* Add tests for advanced synapse indexing, only store edges of post synaptic cell

* update changelog, revert edge view filtering
  • Loading branch information
ntolley authored Nov 28, 2024
1 parent 6474ac2 commit d89fe7b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 6 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
# 0.6.0

### New Features

- Add ability to record synaptic currents (#523, @ntolley). Recordings can be turned on with
```python
net.record("i_IonotropicSynapse")
```


# 0.5.0

### API changes
Expand Down
15 changes: 11 additions & 4 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def __init__(self):
self.synapse_param_names = []
self.synapse_state_names = []
self.synapse_names = []
self.synapse_current_names: List[str] = []

# List of types of all `jx.Channel`s.
self.channels: List[Channel] = []
Expand Down Expand Up @@ -1207,9 +1208,14 @@ def _get_state_names(self) -> Tuple[List, List]:
Returns states seperated by comps and edges."""
channel_states = [name for c in self.channels for name in c.channel_states]
synapse_states = [name for s in self.synapses for name in s.synapse_states]
synapse_states = [
name for s in self.synapses if s is not None for name in s.synapse_states
]
membrane_states = ["v", "i"] + self.membrane_current_names
return channel_states + membrane_states, synapse_states
return (
channel_states + membrane_states,
synapse_states + self.synapse_current_names,
)

def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
"""Get all trainable parameters.
Expand Down Expand Up @@ -2444,7 +2450,8 @@ def __init__(
)

self.channels = self._channels_in_view(pointer)
self.membrane_current_names = [c._name for c in self.channels]
self.membrane_current_names = [c.current_name for c in self.channels]
self.synapse_current_names = pointer.synapse_current_names
self._set_trainables_in_view() # run after synapses and channels
self.num_trainable_params = (
np.sum([len(inds) for inds in self.indices_set_by_trainables])
Expand Down Expand Up @@ -2635,7 +2642,7 @@ def _set_synapses_in_view(self, pointer: Union[Module, View]):
viewed_synapses = []
viewed_params = []
viewed_states = []
if not pointer.synapses is None:
if pointer.synapses is not None:
for syn in pointer.synapses:
if syn is not None: # needed for recurive viewing
in_view = syn._name in self.synapse_names
Expand Down
4 changes: 3 additions & 1 deletion jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _synapse_currents(
# `post_syn_currents` is a `jnp.ndarray` of as many elements as there are
# compartments in the network.
# `[0]` because we only use the non-perturbed voltage.
states[f"{synapse_type._name}_current"] = synapse_currents[0]
states[f"i_{synapse_type._name}"] = synapse_currents[0]

return states, (syn_voltage_terms, syn_constant_terms)

Expand Down Expand Up @@ -565,9 +565,11 @@ def _update_synapse_state_names(self, synapse_type):
def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):
# Add synapse types to the module and infer their unique identifier.
synapse_name = synapse_type._name
synapse_current_name = f"i_{synapse_name}"
type_ind, is_new = self._infer_synapse_type_ind(synapse_name)
if is_new: # synapse is not known
self._update_synapse_state_names(synapse_type)
self.base.synapse_current_names.append(synapse_current_name)

index = len(self.base.edges)
indices = [idx for idx in range(index, index + len(pre_nodes))]
Expand Down
16 changes: 15 additions & 1 deletion tests/test_record_and_stimulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,30 @@ def test_record_synaptic_and_membrane_states(SimpleNet):
)
net.cell(0).branch(0).loc(0.0).stimulate(current)

# Invoke recording of voltage and synaptic states.
net.cell(2).branch(0).loc(0.0).record("v")
net.IonotropicSynapse.edge(1).record("IonotropicSynapse_s")
net.cell(2).branch(0).loc(0.0).record("HH_m")
net.cell(1).branch(0).loc(0.0).record("v")
net.TestSynapse.edge(0).record("TestSynapse_c")
net.cell(1).branch(0).loc(0.0).record("HH_m")
net.cell(1).branch(0).loc(0.0).record("i_HH")
net.IonotropicSynapse.edge(1).record("i_IonotropicSynapse")

# Advanced synapse indexing for recording.
net.copy_node_property_to_edges("global_cell_index")
# Record currents from specific post synaptic cells.
df = net.edges
df = df.query("pre_global_cell_index in [0, 1]")
net.select(edges=df.index).record("i_IonotropicSynapse")
# Record currents from specific synapse types
df = net.edges
df = df.query("type == 'TestSynapse'")
net.select(edges=df.index).record("i_TestSynapse")

recs = jx.integrate(net)

# Loop over first two recorings and then the second two recordings.
# Loop over first two recordings and then the second two recordings.
for index in [0, 3]:
# Local maxima of voltage trace.
y = recs[index]
Expand Down

0 comments on commit d89fe7b

Please sign in to comment.