Skip to content

Commit

Permalink
Add functionality to remove recordings
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 13, 2023
1 parent 54a8928 commit 1c06731
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
9 changes: 9 additions & 0 deletions neurax/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,10 @@ def _record(self, view):
), "Can only record from compartments, not branches, cells, or networks."
self.recordings = pd.concat([self.recordings, view])

def delete_recordings(self):
"""Removes all recordings from the module."""
self.recordings = pd.DataFrame().from_dict({})

def stimulate(self, current):
"""Insert a stimulus into the compartment."""
self._stimulate(current, self.nodes)
Expand All @@ -429,6 +433,11 @@ def _stimulate(self, current, view):
self.currents = jnp.expand_dims(current, axis=0)
self.current_inds = pd.concat([self.current_inds, view])

def delete_stimuli(self):
"""Removes all stimuli from the module."""
self.currents = None
self.current_inds = pd.DataFrame().from_dict({})

def insert(self, channel):
"""Insert a channel."""
self._insert(channel, self.nodes)
Expand Down
58 changes: 58 additions & 0 deletions tests/test_record_and_stimulate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp

import neurax as nx


def test_record_and_stimulate_api():
"""Test the API for recording and stimulating."""
nseg_per_branch = 2
depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = nx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()

cell.branch(0).comp(0.0).record()
cell.branch(1).comp(1.0).record()

current = nx.step_current(0.0, 1.0, 1.0, 0.025, 3.0)
cell.branch(1).comp(1.0).stimulate(current)

cell.delete_recordings()
cell.delete_stimuli()


def test_record_shape():
"""Test the API for recording and stimulating."""
nseg_per_branch = 2
depth = 2
parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)]
parents = jnp.asarray(parents)
num_branches = len(parents)

comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
cell = nx.Cell([branch for _ in range(num_branches)], parents=parents).initialize()

current = nx.step_current(0.0, 1.0, 1.0, 0.025, 3.0)
cell.branch(1).comp(1.0).stimulate(current)

cell.branch(0).comp(0.0).record()
cell.branch(1).comp(1.0).record()
cell.branch(0).comp(1.0).record()
cell.delete_recordings()
cell.branch(2).comp(0.5).record()
cell.branch(1).comp(0.1).record()

voltages = nx.integrate(cell)
assert (
voltages.shape[0] == 2
), f"Shape of recordings ({voltages.shape}) is not right."

0 comments on commit 1c06731

Please sign in to comment.