diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b4fa8ed..54ef89fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,10 @@ net.vis() - changelog added to CI (#537, #558, @jnsbck) +### Code Health + +- Vectorize recording updates during `integrate()`. (#561, @ntolley) + # 0.5.0 ### API changes diff --git a/jaxley/integrate.py b/jaxley/integrate.py index c068ec15..f8c5b1a2 100644 --- a/jaxley/integrate.py +++ b/jaxley/integrate.py @@ -231,8 +231,13 @@ def integrate( if module.recordings.empty: raise ValueError("No recordings are set. Please set them.") - rec_inds = module.recordings.rec_index.to_numpy() - rec_states = module.recordings.state.to_numpy() + recording_df = module.recordings.reset_index(drop=True) + rec_states, rec_inds, group_inds = list(), list(), list() + for state, df_group in recording_df.groupby("state"): + rec_states.append(state) + rec_inds.append(df_group.rec_index.to_numpy()) + group_inds.extend(df_group.index.to_list()) + sort_inds = jnp.argsort(jnp.asarray(group_inds)) # Shorten or pad stimulus depending on `t_max`. if t_max is not None: @@ -260,7 +265,7 @@ def integrate( def _body_fun(state, externals): state = step_fn(state, all_params, externals, external_inds, delta_t) - recs = jnp.asarray( + recs = jnp.concatenate( [ state[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) @@ -294,7 +299,7 @@ def _body_fun(state, externals): externals[key] = jnp.concatenate([externals[key], dummy_external]) # Record the initial state. - init_recs = jnp.asarray( + init_recs = jnp.concatenate( [ all_states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) @@ -311,4 +316,5 @@ def _body_fun(state, externals): nested_lengths=checkpoint_lengths, ) recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T + recs = recs[sort_inds, :] # Sort recordings back to order that was set by user. return (recs, all_states) if return_states else recs