diff --git a/neurax/integrate.py b/neurax/integrate.py index 29c10e97..001f8e7f 100644 --- a/neurax/integrate.py +++ b/neurax/integrate.py @@ -11,7 +11,6 @@ def integrate( module: Module, - stimuli: Union[List[Stimulus], Stimuli], params: List[Dict[str, jnp.ndarray]] = [], t_max: Optional[float] = None, delta_t: float = 0.025, @@ -42,7 +41,8 @@ def integrate( assert module.initialized, "Module is not initialized, run `.initialize()`." - i_current, i_inds = prepare_stim(module, stimuli) + i_current = module.currents.T + i_inds = module.current_inds.comp_index.to_numpy() rec_inds = module.recordings.comp_index.to_numpy() # Shorten or pad stimulus depending on `t_max`. @@ -101,38 +101,3 @@ def _body_fun(state, i_stim): _body_fun, states, i_current, length=length, nested_lengths=checkpoint_lengths ) return jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T - - -def prepare_stim(module, stimuli: Union[List[Stimulus], Stimuli]): - """Prepare stimuli.""" - nseg = module.nseg - cumsum_nbranches = module.cumsum_nbranches - - if isinstance(stimuli, Stimuli): - # Indexing. - i_comp_inds = stimuli.comp_inds - i_branch_inds = stimuli.branch_inds - - # Currents. - i_ext = stimuli.currents # nA - else: - for stim in stimuli: - assert stim.cell_ind < len( - module.nbranches_per_cell - ), "stimulus.cell_ind is larger than the number of cells." - assert ( - stim.branch_ind < module.nbranches_per_cell[stim.cell_ind] - ), "stimulus.branch_ind is larger than the number of branches in the cell." - assert ( - stim.loc <= 1.0 and stim.loc >= 0.0 - ), "stimulus.loc must be in [0, 1]." - # Indexing. - i_comp_inds = [index_of_loc(s.branch_ind, s.loc, nseg) for s in stimuli] - i_comp_inds = jnp.asarray(i_comp_inds) - i_branch_inds = jnp.asarray([s.cell_ind for s in stimuli]) - i_branch_inds = cumsum_nbranches[i_branch_inds] * nseg - - # Currents. - i_ext = jnp.asarray([s.current for s in stimuli]).T # nA - - return i_ext, i_branch_inds + i_comp_inds diff --git a/neurax/modules/base.py b/neurax/modules/base.py index cd03922b..a4707824 100644 --- a/neurax/modules/base.py +++ b/neurax/modules/base.py @@ -29,7 +29,7 @@ def __init__(self): self.syn_edges: Optional[pd.DataFrame] = None self.branch_edges: Optional[pd.DataFrame] = None - self.cumsum_nbranches: jnp.ndarray = None + self.cumsum_nbranches: Optional[jnp.ndarray] = None self.comb_parents: jnp.ndarray = jnp.asarray([-1]) self.comb_branches_in_each_level: List[jnp.ndarray] = [jnp.asarray([0])] @@ -56,6 +56,10 @@ def __init__(self): # For recordings. self.recordings: pd.DataFrame = pd.DataFrame().from_dict({}) + # For stimuli. + self.currents: Optional[jnp.ndarray] = None + self.current_inds: pd.DataFrame = pd.DataFrame().from_dict({}) + def __repr__(self): return f"{type(self).__name__} with {len(self.channel_nodes)} different channels. Use `.show()` for details." @@ -399,7 +403,7 @@ def initialize(self): return self def record(self): - """Insert a recording into the given section.""" + """Insert a recording into the compartment.""" self._record(self.nodes) def _record(self, view): @@ -408,6 +412,20 @@ def _record(self, view): ), "Can only record from compartments, not branches, cells, or networks." self.recordings = pd.concat([self.recordings, view]) + def stimulate(self, current): + """Insert a stimulus into the compartment.""" + self._stimulate(self, current, self.nodes) + + def _stimulate(self, current, view): + assert ( + len(view) == 1 + ), "Can only stimulate compartments, not branches, cells, or networks." + if self.currents is not None: + self.currents = jnp.concatenate([self.currents, jnp.expand_dims(current, axis=0)]) + else: + self.currents = jnp.expand_dims(current, axis=0) + self.current_inds = pd.concat([self.current_inds, view]) + def insert(self, channel): """Insert a channel.""" self._insert(channel, self.nodes) @@ -604,6 +622,10 @@ def record(self): nodes = self.set_global_index_and_index(self.view) self.pointer._record(nodes) + def stimulate(self, current): + nodes = self.set_global_index_and_index(self.view) + self.pointer._stimulate(current, nodes) + def set_params(self, key: str, val: float): """Set parameters of the pointer.""" self.pointer._set_params(key, val, self.view)