diff --git a/neurax/modules/base.py b/neurax/modules/base.py index e35e0109..c778e265 100644 --- a/neurax/modules/base.py +++ b/neurax/modules/base.py @@ -413,6 +413,10 @@ def _record(self, view): ), "Can only record from compartments, not branches, cells, or networks." self.recordings = pd.concat([self.recordings, view]) + def delete_recordings(self): + """Removes all recordings from the module.""" + self.recordings = pd.DataFrame().from_dict({}) + def stimulate(self, current): """Insert a stimulus into the compartment.""" self._stimulate(current, self.nodes) @@ -429,6 +433,11 @@ def _stimulate(self, current, view): self.currents = jnp.expand_dims(current, axis=0) self.current_inds = pd.concat([self.current_inds, view]) + def delete_stimuli(self): + """Removes all stimuli from the module.""" + self.currents = None + self.current_inds = pd.DataFrame().from_dict({}) + def insert(self, channel): """Insert a channel.""" self._insert(channel, self.nodes) diff --git a/tests/test_record_and_stimulate.py b/tests/test_record_and_stimulate.py new file mode 100644 index 00000000..8f61d4c6 --- /dev/null +++ b/tests/test_record_and_stimulate.py @@ -0,0 +1,58 @@ +import jax + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") + +import jax.numpy as jnp + +import neurax as nx + + +def test_record_and_stimulate_api(): + """Test the API for recording and stimulating.""" + nseg_per_branch = 2 + depth = 2 + parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] + parents = jnp.asarray(parents) + num_branches = len(parents) + + comp = nx.Compartment().initialize() + branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize() + cell = nx.Cell([branch for _ in range(num_branches)], parents=parents).initialize() + + cell.branch(0).comp(0.0).record() + cell.branch(1).comp(1.0).record() + + current = nx.step_current(0.0, 1.0, 1.0, 0.025, 3.0) + cell.branch(1).comp(1.0).stimulate(current) + + cell.delete_recordings() + cell.delete_stimuli() + + +def test_record_shape(): + """Test the API for recording and stimulating.""" + nseg_per_branch = 2 + depth = 2 + parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] + parents = jnp.asarray(parents) + num_branches = len(parents) + + comp = nx.Compartment().initialize() + branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize() + cell = nx.Cell([branch for _ in range(num_branches)], parents=parents).initialize() + + current = nx.step_current(0.0, 1.0, 1.0, 0.025, 3.0) + cell.branch(1).comp(1.0).stimulate(current) + + cell.branch(0).comp(0.0).record() + cell.branch(1).comp(1.0).record() + cell.branch(0).comp(1.0).record() + cell.delete_recordings() + cell.branch(2).comp(0.5).record() + cell.branch(1).comp(0.1).record() + + voltages = nx.integrate(cell) + assert ( + voltages.shape[0] == 2 + ), f"Shape of recordings ({voltages.shape}) is not right."