Skip to content

Commit

Permalink
vectorized recording updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley committed Dec 20, 2024
1 parent ac5026d commit 44a3e55
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,12 @@ def integrate(

if module.recordings.empty:
raise ValueError("No recordings are set. Please set them.")
rec_inds = module.recordings.rec_index.to_numpy()
rec_states = module.recordings.state.to_numpy()
recording_df = module.recordings.reset_index(drop=True)
rec_states, rec_inds, sort_inds = list(), list(), list()
for state, df_group in recording_df.groupby("state"):
rec_states.append(state)
rec_inds.append(df_group.rec_index.to_numpy())
sort_inds.extend(df_group.index.to_list())

# Shorten or pad stimulus depending on `t_max`.
if t_max is not None:
Expand Down Expand Up @@ -260,7 +264,7 @@ def integrate(

def _body_fun(state, externals):
state = step_fn(state, all_params, externals, external_inds, delta_t)
recs = jnp.asarray(
recs = jnp.concatenate(
[
state[rec_state][rec_ind]
for rec_state, rec_ind in zip(rec_states, rec_inds)
Expand Down Expand Up @@ -294,7 +298,7 @@ def _body_fun(state, externals):
externals[key] = jnp.concatenate([externals[key], dummy_external])

# Record the initial state.
init_recs = jnp.asarray(
init_recs = jnp.concatenate(
[
all_states[rec_state][rec_ind]
for rec_state, rec_ind in zip(rec_states, rec_inds)
Expand All @@ -311,4 +315,5 @@ def _body_fun(state, externals):
nested_lengths=checkpoint_lengths,
)
recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T
# recs = recs[sort_inds, :] # Sort recordings back to order that was set by user.
return (recs, all_states) if return_states else recs

0 comments on commit 44a3e55

Please sign in to comment.