Skip to content

Commit

Permalink
Make recording a mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 3, 2023
1 parent 364df29 commit 0395800
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 76 deletions.
28 changes: 2 additions & 26 deletions neurax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
def integrate(
module: Module,
stimuli: Union[List[Stimulus], Stimuli],
recordings: List[Recording],
params: List[Dict[str, jnp.ndarray]] = [],
t_max: Optional[float] = None,
delta_t: float = 0.025,
Expand All @@ -25,9 +24,7 @@ def integrate(
Solves ODE and simulates neuron model.
Args:
t_max: Duration of the simulation in milliseconds. If `None`, the duration is
inferred from the duration of the stimulus. If it is larger than the
duration of the stimulus, the stimulus is padded with zeros at the end.
t_max: Duration of the simulation in milliseconds.
delta_t: Time step of the solver in milliseconds.
solver: Which ODE solver to use. Either of ["fwd_euler", "bwd_euler", "cranck"].
tridiag_solver: Algorithm to solve tridiagonal systems. The different options
Expand All @@ -47,7 +44,7 @@ def integrate(
assert module.initialized, "Module is not initialized, run `.initialize()`."

i_current, i_inds = prepare_stim(module, stimuli)
rec_inds = prepare_recs(module, recordings)
rec_inds = module.recordings.comp_index.to_numpy()

# Shorten or pad stimulus depending on `t_max`.
if t_max is not None:
Expand Down Expand Up @@ -107,27 +104,6 @@ def _body_fun(state, i_stim):
return jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T


def prepare_recs(module, recordings: List[Recording]):
"""Prepare recordings."""
nseg = module.nseg
cumsum_nbranches = module.cumsum_nbranches

for rec in recordings:
assert rec.cell_ind < len(
module.nbranches_per_cell
), "recording.cell_ind is larger than the number of cells."
assert (
rec.branch_ind < module.nbranches_per_cell[rec.cell_ind]
), "recording.branch_ind is larger than the number of branches in the cell."
assert rec.loc <= 1.0 and rec.loc >= 0.0, "recording.loc must be in [0, 1]."

rec_comp_inds = [index_of_loc(r.branch_ind, r.loc, nseg) for r in recordings]
rec_comp_inds = jnp.asarray(rec_comp_inds)
rec_branch_inds = jnp.asarray([r.cell_ind for r in recordings])
rec_branch_inds = nseg * cumsum_nbranches[rec_branch_inds]
return rec_branch_inds + rec_comp_inds


def prepare_stim(module, stimuli: Union[List[Stimulus], Stimuli]):
"""Prepare stimuli."""
nseg = module.nseg
Expand Down
57 changes: 40 additions & 17 deletions neurax/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def __init__(self):
self.conns: List[Synapse] = None
self.group_views = {}

self.nodes: pd.DataFrame = None
self.syn_edges: pd.DataFrame = None
self.branch_edges: pd.DataFrame = None
self.nodes: Optional[pd.DataFrame] = None
self.syn_edges: Optional[pd.DataFrame] = None
self.branch_edges: Optional[pd.DataFrame] = None

self.cumsum_nbranches: jnp.ndarray = None

Expand All @@ -53,6 +53,9 @@ def __init__(self):
self.trainable_params: List[Dict[str, jnp.ndarray]] = []
self.allow_make_trainable: bool = True

# For recordings.
self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})

def __repr__(self):
return f"{type(self).__name__} with {len(self.channel_nodes)} different channels. Use `.show()` for details."

Expand Down Expand Up @@ -395,6 +398,16 @@ def initialize(self):
self.init_syns()
return self

def record(self):
"""Insert a recording into the given section."""
self._record(self.nodes)

def _record(self, view):
assert (
len(view) == 1
), "Can only record from compartments, not branches, cells, or networks."
self.recordings = pd.concat([self.recordings, view])

def insert(self, channel):
"""Insert a channel."""
self._insert(channel, self.nodes)
Expand Down Expand Up @@ -551,15 +564,30 @@ def show(
states: bool = True,
):
if channel_name is None:
myview = self.view.drop("original_comp_index", axis=1)
myview = myview.drop("original_branch_index", axis=1)
myview = myview.drop("original_cell_index", axis=1)
myview = self.view.drop("global_comp_index", axis=1)
myview = myview.drop("global_branch_index", axis=1)
myview = myview.drop("global_cell_index", axis=1)
return self.pointer._show_base(myview, indices, params, states)
else:
return self.pointer._show_channel(
self.view, channel_name, indices, params, states
)

def set_global_index_and_index(nodes):
"""Use the global compartment, branch, and cell index as the index."""
nodes = nodes.drop("controlled_by_param", axis=1)
nodes = nodes.drop("comp_index", axis=1)
nodes = nodes.drop("branch_index", axis=1)
nodes = nodes.drop("cell_index", axis=1)
nodes = nodes.rename(
columns={
"global_comp_index": "comp_index",
"global_branch_index": "branch_index",
"global_cell_index": "cell_index",
}
)
return nodes

def insert(self, channel):
"""Insert a channel."""
assert not inspect.isclass(
Expand All @@ -568,19 +596,14 @@ def insert(self, channel):
Channel is a class, but it was not initialized. Use `.insert(Channel())`
instead of `.insert(Channel)`.
"""
nodes = self.view.drop("controlled_by_param", axis=1)
nodes = nodes.drop("comp_index", axis=1)
nodes = nodes.drop("branch_index", axis=1)
nodes = nodes.drop("cell_index", axis=1)
nodes = nodes.rename(
columns={
"original_comp_index": "comp_index",
"original_branch_index": "branch_index",
"original_cell_index": "cell_index",
}
)
nodes = self.set_global_index_and_index(self.view)
self.pointer._insert(channel, nodes)

def record(self):
"""Insert a channel."""
nodes = self.set_global_index_and_index(self.view)
self.pointer._record(nodes)

def set_params(self, key: str, val: float):
"""Set parameters of the pointer."""
self.pointer._set_params(key, val, self.view)
Expand Down
8 changes: 0 additions & 8 deletions neurax/recording.py

This file was deleted.

Loading

0 comments on commit 0395800

Please sign in to comment.