Skip to content

Commit

Permalink
data_clamp (#374)
Browse files Browse the repository at this point in the history
* Added data_clamp method for jitting simulations with clamps

* Added delete clamps method

* Added data_clamp to View

* Black

* Unified _data_clamp and _data_stimulate

* Review updates

* One more typehint

* Get rid of extraneously added null currents from integration

* Added more data_clamp tests and a check for same length externals in integration (padding only happens when t_max is specified)

* Minor revisions and formatting

* Black upgraded
  • Loading branch information
kyralianaka authored Oct 11, 2024
1 parent cd861e3 commit 3cdb490
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 53 deletions.
69 changes: 44 additions & 25 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def integrate(
*,
param_state: Optional[List[Dict]] = None,
data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,
data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,
t_max: Optional[float] = None,
delta_t: float = 0.025,
solver: str = "bwd_euler",
Expand All @@ -34,6 +35,8 @@ def integrate(
param_state: Parameters returned by `data_set`.
data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change
across function calls.
data_clamps: Outputs of `.data_clamp()`, only needed if clamps change across
function calls.
t_max: Duration of the simulation in milliseconds. If `t_max` is greater than
the length of the stimulus input, the stimulus will be padded at the end
with zeros. If `t_max` is smaller, then the stimulus with be truncated.
Expand Down Expand Up @@ -70,16 +73,25 @@ def integrate(
if "i" in module.externals.keys() or data_stimuli is not None:
if "i" in module.externals.keys():
if data_stimuli is not None:
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[0]])
externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]])
external_inds["i"] = jnp.concatenate(
[external_inds["i"], data_stimuli[1].comp_index.to_numpy()]
[external_inds["i"], data_stimuli[2].comp_index.to_numpy()]
)
else:
externals["i"] = data_stimuli[0]
external_inds["i"] = data_stimuli[1].comp_index.to_numpy()
else:
externals["i"] = jnp.asarray([[]]).astype("float")
external_inds["i"] = jnp.asarray([]).astype("int32")
externals["i"] = data_stimuli[1]
external_inds["i"] = data_stimuli[2].comp_index.to_numpy()

# If a clamp is inserted, add it to the external inputs.
if data_clamps is not None:
state_name, clamps, inds = data_clamps
if state_name in module.externals.keys():
externals[state_name] = jnp.concatenate([externals[state_name], clamps])
external_inds[state_name] = jnp.concatenate(
[external_inds[state_name], inds.comp_index.to_numpy()]
)
else:
externals[state_name] = clamps
external_inds[state_name] = inds.comp_index.to_numpy()

if not externals.keys():
# No stimulus was inserted and no clamp was set.
Expand All @@ -98,17 +110,17 @@ def integrate(
t_max_steps = int(t_max // delta_t + 1)

# Pad or truncate the stimulus.
if "i" in externals.keys() and t_max_steps > externals["i"].shape[0]:
pad = jnp.zeros(
(t_max_steps - externals["i"].shape[0], externals["i"].shape[1])
)
externals["i"] = jnp.concatenate((externals["i"], pad))

for key in externals.keys():
if t_max_steps > externals[key].shape[0]:
raise NotImplementedError(
"clamp must be at least as long as simulation."
)
if key == "i":
pad = jnp.zeros(
(t_max_steps - externals["i"].shape[0], externals["i"].shape[1])
)
externals["i"] = jnp.concatenate((externals["i"], pad))
else:
raise NotImplementedError(
"clamp must be at least as long as simulation."
)
else:
externals[key] = externals[key][:t_max_steps, :]

Expand Down Expand Up @@ -148,20 +160,27 @@ def _body_fun(state, externals):
# If necessary, pad the stimulus with zeros in order to simulate sufficiently long.
# The total simulation length will be `prod(checkpoint_lengths)`. At the end, we
# return only the first `nsteps_to_return` elements (plus the initial state).
example_key = list(externals.keys())[0]
nsteps_to_return = len(externals[example_key])
if externals:
example_key = list(externals.keys())[0]
nsteps_to_return = len(externals[example_key])
else:
nsteps_to_return = t_max_steps

if checkpoint_lengths is None:
checkpoint_lengths = [len(externals[example_key])]
length = len(externals[example_key])
checkpoint_lengths = [nsteps_to_return]
length = nsteps_to_return
else:
length = prod(checkpoint_lengths)
size_difference = length - len(externals[example_key])
dummy_external = jnp.zeros((size_difference, externals[example_key].shape[1]))
size_difference = length - nsteps_to_return
assert (
len(externals[example_key]) <= length
nsteps_to_return <= length
), "The desired simulation duration is longer than `prod(nested_length)`."
for key in externals.keys():
externals[key] = jnp.concatenate([externals[key], dummy_external])
if externals:
dummy_external = jnp.zeros(
(size_difference, externals[example_key].shape[1])
)
for key in externals.keys():
externals[key] = jnp.concatenate([externals[key], dummy_external])

# Record the initial state.
init_recs = jnp.asarray(
Expand Down
105 changes: 77 additions & 28 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,45 +996,79 @@ def data_stimulate(
verbose: Whether or not to print the number of inserted stimuli. `False`
by default because this method is meant to be jitted.
"""
return self._data_stimulate(current, data_stimuli, self.nodes, verbose=verbose)
return self._data_external_input(
"i", current, data_stimuli, self.nodes, verbose=verbose
)

def _data_stimulate(
def data_clamp(
self,
current: jnp.ndarray,
data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]],
state_name: str,
state_array: jnp.ndarray,
data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,
verbose: bool = False,
):
"""Insert a clamp into the module within jit (or grad).
Args:
state_name: Name of the state variable to set.
state_array: Time series of the state variable in the default Jaxley unit.
State array should be of shape (num_clamps, simulation_time) or
(simulation_time, ) for a single clamp.
verbose: Whether or not to print the number of inserted clamps. `False`
by default because this method is meant to be jitted.
"""
return self._data_external_input(
state_name, state_array, data_clamps, self.nodes, verbose=verbose
)

def _data_external_input(
self,
state_name: str,
state_array: jnp.ndarray,
data_external_input: Optional[Tuple[jnp.ndarray, pd.DataFrame]],
view: pd.DataFrame,
verbose: bool = False,
) -> Tuple[jnp.ndarray, pd.DataFrame]:
current = current if current.ndim == 2 else jnp.expand_dims(current, axis=0)
batch_size = current.shape[0]
):
state_array = (
state_array
if state_array.ndim == 2
else jnp.expand_dims(state_array, axis=0)
)
batch_size = state_array.shape[0]
is_multiple = len(view) == batch_size
current = current if is_multiple else jnp.repeat(current, len(view), axis=0)
assert batch_size in [1, len(view)], "Number of comps and stimuli do not match."
state_array = (
state_array if is_multiple else jnp.repeat(state_array, len(view), axis=0)
)
assert batch_size in [1, len(view)], "Number of comps and clamps do not match."

if data_stimuli is not None:
currents = data_stimuli[0]
inds = data_stimuli[1]
if data_external_input is not None:
external_input = data_external_input[1]
external_input = jnp.concatenate([external_input, state_array])
inds = data_external_input[2]
else:
currents = None
external_input = state_array
inds = pd.DataFrame().from_dict({})

# Same as in `.stimulate()`.
if currents is not None:
currents = jnp.concatenate([currents, current])
else:
currents = current
inds = pd.concat([inds, view])

if verbose:
print(f"Added {len(view)} stimuli.")
if state_name == "i":
print(f"Added {len(view)} stimuli.")
else:
print(f"Added {len(view)} clamps.")

return (currents, inds)
return (state_name, external_input, inds)

def delete_stimuli(self):
"""Removes all stimuli from the module."""
self.externals.pop("i", None)
self.external_inds.pop("i", None)

def delete_clamps(self, state_name: str):
"""Removes all clamps of the given state from the module."""
self.externals.pop(state_name, None)
self.external_inds.pop(state_name, None)

def insert(self, channel: Channel):
"""Insert a channel into the module.
Expand Down Expand Up @@ -1084,12 +1118,14 @@ def step(
voltages = u["v"]

# Extract the external inputs
has_current = "i" in externals.keys()
i_current = externals["i"] if has_current else jnp.asarray([]).astype("float")
i_inds = external_inds["i"] if has_current else jnp.asarray([]).astype("int32")
i_ext = self._get_external_input(
voltages, i_inds, i_current, params["radius"], params["length"]
)
if "i" in externals.keys():
i_current = externals["i"]
i_inds = external_inds["i"]
i_ext = self._get_external_input(
voltages, i_inds, i_current, params["radius"], params["length"]
)
else:
i_ext = 0.0

# Step of the channels.
u, (v_terms, const_terms) = self._step_channels(
Expand Down Expand Up @@ -1829,8 +1865,8 @@ def data_stimulate(
by default because this method is meant to be jitted.
"""
nodes = self.set_global_index_and_index(self.view)
return self.pointer._data_stimulate(
current, data_stimuli, nodes, verbose=verbose
return self.pointer._data_external_input(
"i", current, data_stimuli, nodes, verbose=verbose
)

def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):
Expand All @@ -1846,6 +1882,19 @@ def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True)
nodes = self.set_global_index_and_index(self.view)
self.pointer._external_input(state_name, state_array, nodes, verbose=verbose)

def data_clamp(
self,
state_name: str,
state_array: jnp.ndarray,
data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]],
verbose: bool = False,
):
"""Insert a clamp into the module within jit (or grad)."""
nodes = self.set_global_index_and_index(self.view)
return self.pointer._data_external_input(
state_name, state_array, data_clamps, nodes, verbose=verbose
)

def set(self, key: str, val: float):
"""Set parameters of the pointer."""
self.pointer._set(key, val, self.view, self.pointer.nodes)
Expand Down
111 changes: 111 additions & 0 deletions tests/test_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,114 @@ def test_clamp_and_stimulate_api():

vs2 = jnp.concatenate([vs21, vs22])
assert np.max(np.abs(vs1 - vs2)) < 1e-8


def test_data_clamp():
"""Data clamp with no stimuli or data_stimuli, and no t_max (should get defined by the clamp)."""
comp = jx.Compartment()
comp.insert(HH())
comp.record()
clamp = -50.0 * jnp.ones((1000,))

def provide_data(clamp):
data_clamps = comp.data_clamp("v", clamp)
return data_clamps

def simulate(clamp):
data_clamps = provide_data(clamp)
return jx.integrate(comp, data_clamps=data_clamps)

jitted_simulate = jax.jit(simulate)

s = jitted_simulate(clamp)
assert np.all(s[:, 1:] == -50.0)


def test_data_clamp_and_data_stimulate():
"""In theory people shouldn't use these two together, but at least it shouldn't break."""
comp = jx.Compartment()
comp.insert(HH())
comp.record()
clamp = -50.0 * jnp.ones((1000,))
stim = 0.1 * jnp.ones((1000,))

def provide_data(clamp, stim):
data_clamps = comp.data_clamp("v", clamp)
data_stims = comp.data_stimulate(stim)
return data_clamps, data_stims

def simulate(clamp, stim):
data_clamps, data_stims = provide_data(clamp, stim)
return jx.integrate(comp, data_clamps=data_clamps, data_stimuli=data_stims)

jitted_simulate = jax.jit(simulate)

s = jitted_simulate(clamp, stim)
assert np.all(s[:, 1:] == -50.0)


def test_data_clamp_and_stimulate():
"""Test that data clamp overrides a previously set stimulus."""
comp = jx.Compartment()
comp.insert(HH())
comp.record()
clamp = -50.0 * jnp.ones((1000,))
stim = 0.1 * jnp.ones((800,))
t_max = clamp.shape[0] * 0.025 # make sure the stimulus gets padded
comp.stimulate(stim)

def simulate(clamp):
data_clamps = comp.data_clamp("v", clamp) # should override the stimulation
return jx.integrate(comp, data_clamps=data_clamps, t_max=t_max)

jitted_simulate = jax.jit(simulate)

s = jitted_simulate(clamp)
assert np.all(s[:, 1:] == -50.0)


def test_data_clamp_and_clamp():
"""Test that data clamp can override (same loc.) and add (another loc.) to clamp."""
comp = jx.Compartment()
comp.insert(HH())
comp.record()
clamp1 = -50.0 * jnp.ones((1000,))
clamp2 = -60.0 * jnp.ones((1000,))
comp.clamp("v", clamp1)

def simulate(clamp):
data_clamps = comp.data_clamp(
"v", clamp, None
) # should override the first clamp
return jx.integrate(comp, data_clamps=data_clamps)

jitted_simulate = jax.jit(simulate)

# Clamp2 should override clamp1 here
s = jitted_simulate(clamp2)
assert np.all(s[:, 1:] == -60.0)

comp2 = jx.Compartment()
comp2.insert(HH())
branch1 = jx.Branch(comp, 4)
branch2 = jx.Branch(comp2, 4)
cell = jx.Cell([branch1, branch2], [-1, -1])

# Apply the clamp1 to the second branch via clamp
cell[1, 0].clamp("v", clamp1)

cell.delete_recordings()
cell.branch(0).comp(0).record()
cell.branch(1).comp(0).record()

def simulate(clamp):
data_clamps = cell.branch(0).comp(0).data_clamp("v", clamp, None)
return jx.integrate(cell, data_clamps=data_clamps)

jitted_simulate = jax.jit(simulate)

# Apply clamp2 to the first branch via data_clamp
s = jitted_simulate(clamp2)

assert np.all(s[0, 1:] == -60.0)
assert np.all(s[1, 1:] == -50.0)

0 comments on commit 3cdb490

Please sign in to comment.