Skip to content

Commit

Permalink
attempt at empty stimuli
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 22, 2023
1 parent c5fb06a commit bea31fb
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 33 deletions.
47 changes: 33 additions & 14 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,36 @@ def integrate(

assert module.initialized, "Module is not initialized, run `.initialize()`."

i_current = module.currents.T if currents is None else currents.T
i_inds = module.current_inds.comp_index.to_numpy()
if module.currents is not None:
# At least one stimulus was inserted.
if currents is not None:
i_current = currents.T
else:
i_current = module.currents.T
i_inds = module.current_inds.comp_index.to_numpy()
else:
# No stimulus was inserted.
i_current = None
i_inds = None
assert (
t_max is not None
), "If no stimulus is inserted that you have to specify the simulation duration at `jx.integrate(..., t_max=)`."

# Deal with recording.
rec_inds = module.recordings.comp_index.to_numpy()

# Shorten or pad stimulus depending on `t_max`.
if t_max is not None:
t_max_steps = int(t_max // delta_t + 1)
if t_max_steps > i_current.shape[0]:
i_current = jnp.zeros((t_max_steps, i_current.shape[1]))
else:
i_current = i_current[:t_max_steps, :]
nsteps_to_return = int(t_max // delta_t + 1)
if i_current is not None:
if nsteps_to_return > i_current.shape[0]:
num_additional_steps = nsteps_to_return - i_current.shape[0]
i_pad = jnp.zeros((num_additional_steps, i_current.shape[1]))
i_current = jnp.concatenate([i_current, i_pad], axis=0)
else:
i_current = i_current[:nsteps_to_return, :]
else:
nsteps_to_return = len(i_current)

# Run `init_conds()` and return every parameter that is needed to solve the ODE.
# This includes conductances, radiuses, lenghts, axial_resistivities, but also
Expand All @@ -70,23 +89,23 @@ def _body_fun(state, i_stim):
)
return state, state["voltages"][rec_inds]

nsteps_to_return = len(i_current)
init_recording = jnp.expand_dims(module.states["voltages"][rec_inds], axis=0)

# If necessary, pad the stimulus with zeros in order to simulate sufficiently long.
# The total simulation length will be `prod(checkpoint_lengths)`. At the end, we
# return only the first `nsteps_to_return` elements (plus the initial state).
if checkpoint_lengths is None:
checkpoint_lengths = [len(i_current)]
length = len(i_current)
checkpoint_lengths = [nsteps_to_return]
length = nsteps_to_return
else:
length = prod(checkpoint_lengths)
assert (
len(i_current) <= length
nsteps_to_return <= length
), "The desired simulation duration is longer than `prod(nested_length)`."
size_difference = length - len(i_current)
dummy_stimulus = jnp.zeros((size_difference, i_current.shape[1]))
i_current = jnp.concatenate([i_current, dummy_stimulus])
size_difference = length - nsteps_to_return
if i_current is not None:
dummy_stimulus = jnp.zeros((size_difference, i_current.shape[1]))
i_current = jnp.concatenate([i_current, dummy_stimulus])

# Join node and edge states.
states = {}
Expand Down
40 changes: 21 additions & 19 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,8 @@ def step(
self,
u,
delta_t,
i_inds,
i_current,
i_inds: Optional[jnp.ndarray],
i_current: Optional[jnp.ndarray],
params: Dict[str, jnp.ndarray],
solver: str = "bwd_euler",
tridiag_solver: str = "stone",
Expand Down Expand Up @@ -596,30 +596,32 @@ def _step_synapse(
@staticmethod
def get_external_input(
voltages: jnp.ndarray,
i_inds: jnp.ndarray,
i_stim: jnp.ndarray,
i_inds: Optional[jnp.ndarray],
i_stim: Optional[jnp.ndarray],
radius: float,
length_single_compartment: float,
):
"""
Return external input to each compartment in uA / cm^2.
"""
zero_vec = jnp.zeros_like(voltages)
# `radius`: um
# `length_single_compartment`: um
# `i_stim`: nA
current = (
i_stim / 2 / pi / radius[i_inds] / length_single_compartment[i_inds]
) # nA / um^2
current *= 100_000 # Convert (nA / um^2) to (uA / cm^2)

dnums = ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,),
)
stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)
return stim_at_timestep
if i_stim is not None:
# `radius`: um
# `length_single_compartment`: um
# `i_stim`: nA
current = (
i_stim / 2 / pi / radius[i_inds] / length_single_compartment[i_inds]
) # nA / um^2
current *= 100_000 # Convert (nA / um^2) to (uA / cm^2)

dnums = ScatterDimensionNumbers(
update_window_dims=(),
inserted_window_dims=(0,),
scatter_dims_to_operand_dims=(0,),
)
return scatter_add(zero_vec, i_inds[:, None], current, dnums)
else:
return zero_vec


class View:
Expand Down

0 comments on commit bea31fb

Please sign in to comment.