Skip to content

Commit

Permalink
Make stimulus a mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 3, 2023
1 parent 9bee226 commit ea45ba8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 39 deletions.
39 changes: 2 additions & 37 deletions neurax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

def integrate(
module: Module,
stimuli: Union[List[Stimulus], Stimuli],
params: List[Dict[str, jnp.ndarray]] = [],
t_max: Optional[float] = None,
delta_t: float = 0.025,
Expand Down Expand Up @@ -42,7 +41,8 @@ def integrate(

assert module.initialized, "Module is not initialized, run `.initialize()`."

i_current, i_inds = prepare_stim(module, stimuli)
i_current = module.currents.T
i_inds = module.current_inds.comp_index.to_numpy()
rec_inds = module.recordings.comp_index.to_numpy()

# Shorten or pad stimulus depending on `t_max`.
Expand Down Expand Up @@ -101,38 +101,3 @@ def _body_fun(state, i_stim):
_body_fun, states, i_current, length=length, nested_lengths=checkpoint_lengths
)
return jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T


def prepare_stim(module, stimuli: Union[List[Stimulus], Stimuli]):
"""Prepare stimuli."""
nseg = module.nseg
cumsum_nbranches = module.cumsum_nbranches

if isinstance(stimuli, Stimuli):
# Indexing.
i_comp_inds = stimuli.comp_inds
i_branch_inds = stimuli.branch_inds

# Currents.
i_ext = stimuli.currents # nA
else:
for stim in stimuli:
assert stim.cell_ind < len(
module.nbranches_per_cell
), "stimulus.cell_ind is larger than the number of cells."
assert (
stim.branch_ind < module.nbranches_per_cell[stim.cell_ind]
), "stimulus.branch_ind is larger than the number of branches in the cell."
assert (
stim.loc <= 1.0 and stim.loc >= 0.0
), "stimulus.loc must be in [0, 1]."
# Indexing.
i_comp_inds = [index_of_loc(s.branch_ind, s.loc, nseg) for s in stimuli]
i_comp_inds = jnp.asarray(i_comp_inds)
i_branch_inds = jnp.asarray([s.cell_ind for s in stimuli])
i_branch_inds = cumsum_nbranches[i_branch_inds] * nseg

# Currents.
i_ext = jnp.asarray([s.current for s in stimuli]).T # nA

return i_ext, i_branch_inds + i_comp_inds
26 changes: 24 additions & 2 deletions neurax/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self):
self.syn_edges: Optional[pd.DataFrame] = None
self.branch_edges: Optional[pd.DataFrame] = None

self.cumsum_nbranches: jnp.ndarray = None
self.cumsum_nbranches: Optional[jnp.ndarray] = None

self.comb_parents: jnp.ndarray = jnp.asarray([-1])
self.comb_branches_in_each_level: List[jnp.ndarray] = [jnp.asarray([0])]
Expand All @@ -56,6 +56,10 @@ def __init__(self):
# For recordings.
self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})

# For stimuli.
self.currents: Optional[jnp.ndarray] = None
self.current_inds: pd.DataFrame = pd.DataFrame().from_dict({})

def __repr__(self):
return f"{type(self).__name__} with {len(self.channel_nodes)} different channels. Use `.show()` for details."

Expand Down Expand Up @@ -399,7 +403,7 @@ def initialize(self):
return self

def record(self):
"""Insert a recording into the given section."""
"""Insert a recording into the compartment."""
self._record(self.nodes)

def _record(self, view):
Expand All @@ -408,6 +412,20 @@ def _record(self, view):
), "Can only record from compartments, not branches, cells, or networks."
self.recordings = pd.concat([self.recordings, view])

def stimulate(self, current):
"""Insert a stimulus into the compartment."""
self._stimulate(self, current, self.nodes)

def _stimulate(self, current, view):
assert (
len(view) == 1
), "Can only stimulate compartments, not branches, cells, or networks."
if self.currents is not None:
self.currents = jnp.concatenate([self.currents, jnp.expand_dims(current, axis=0)])
else:
self.currents = jnp.expand_dims(current, axis=0)
self.current_inds = pd.concat([self.current_inds, view])

def insert(self, channel):
"""Insert a channel."""
self._insert(channel, self.nodes)
Expand Down Expand Up @@ -604,6 +622,10 @@ def record(self):
nodes = self.set_global_index_and_index(self.view)
self.pointer._record(nodes)

def stimulate(self, current):
nodes = self.set_global_index_and_index(self.view)
self.pointer._stimulate(current, nodes)

def set_params(self, key: str, val: float):
"""Set parameters of the pointer."""
self.pointer._set_params(key, val, self.view)
Expand Down

0 comments on commit ea45ba8

Please sign in to comment.