Skip to content

Commit

Permalink
Allow task training by setting the current from outside
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 3, 2023
1 parent ea45ba8 commit 36d41c1
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 45 deletions.
2 changes: 1 addition & 1 deletion neurax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from neurax.integrate import integrate
from neurax.modules import *
from neurax.optimize import ParamTransform
from neurax.stimulus import Stimuli, Stimulus, step_current
from neurax.stimulus import step_current
6 changes: 3 additions & 3 deletions neurax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import jax.numpy as jnp

from neurax.modules import Module
from neurax.stimulus import Stimuli, Stimulus
from neurax.utils.cell_utils import index_of_loc
from neurax.utils.jax_utils import nested_checkpoint_scan


def integrate(
module: Module,
params: List[Dict[str, jnp.ndarray]] = [],
currents: Optional[jnp.ndarray] = None,
*,
t_max: Optional[float] = None,
delta_t: float = 0.025,
solver: str = "bwd_euler",
Expand Down Expand Up @@ -41,7 +41,7 @@ def integrate(

assert module.initialized, "Module is not initialized, run `.initialize()`."

i_current = module.currents.T
i_current = module.currents.T if currents is None else currents.T
i_inds = module.current_inds.comp_index.to_numpy()
rec_inds = module.recordings.comp_index.to_numpy()

Expand Down
4 changes: 3 additions & 1 deletion neurax/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ def _stimulate(self, current, view):
len(view) == 1
), "Can only stimulate compartments, not branches, cells, or networks."
if self.currents is not None:
self.currents = jnp.concatenate([self.currents, jnp.expand_dims(current, axis=0)])
self.currents = jnp.concatenate(
[self.currents, jnp.expand_dims(current, axis=0)]
)
else:
self.currents = jnp.expand_dims(current, axis=0)
self.current_inds = pd.concat([self.current_inds, view])
Expand Down
40 changes: 0 additions & 40 deletions neurax/stimulus.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,46 +7,6 @@
from neurax.utils.cell_utils import index_of_loc


class Stimulus:
"""A single stimulus to the network."""

def __init__(
self, cell_ind, branch_ind, loc, current: Optional[jnp.ndarray] = None
):
"""
Args:
current: Time series of the current.
"""
self.cell_ind = cell_ind
self.branch_ind = branch_ind
self.loc = loc
self.current = current


class Stimuli:
"""Several stimuli to the network.
Here, the properties of all individual stimuli already get vectorized and put
into arrays. This increases speed for big datasets consisting of dozens or hundreds
of stimuli.
"""

def __init__(
self, stims: List[Stimulus], nseg_per_branch: int, cumsum_nbranches: jnp.ndarray
):
self.comp_inds = jnp.asarray(
[index_of_loc(s.branch_ind, s.loc, nseg_per_branch) for s in stims]
)
cell_inds = jnp.asarray([s.cell_ind for s in stims])
self.branch_inds = cumsum_nbranches[cell_inds] * nseg_per_branch
self.currents = jnp.asarray([s.current for s in stims]).T # nA

def set_currents(self, currents: float):
"""Rescale the current of the stimulus with a constant value over time."""
self.currents = currents
return self


def step_current(
i_delay: float, i_dur: float, i_amp: float, time_vec: jnp.asarray, i_offset=0.0
):
Expand Down

0 comments on commit 36d41c1

Please sign in to comment.