diff --git a/jaxley/integrate.py b/jaxley/integrate.py index c068ec15..06f0454a 100644 --- a/jaxley/integrate.py +++ b/jaxley/integrate.py @@ -231,8 +231,12 @@ def integrate( if module.recordings.empty: raise ValueError("No recordings are set. Please set them.") - rec_inds = module.recordings.rec_index.to_numpy() - rec_states = module.recordings.state.to_numpy() + recording_df = module.recordings.reset_index(drop=True) + rec_states, rec_inds, sort_inds = list(), list(), list() + for state, df_group in recording_df.groupby("state"): + rec_states.append(state) + rec_inds.append(df_group.rec_index.to_numpy()) + sort_inds.extend(df_group.index.to_list()) # Shorten or pad stimulus depending on `t_max`. if t_max is not None: @@ -260,7 +264,7 @@ def integrate( def _body_fun(state, externals): state = step_fn(state, all_params, externals, external_inds, delta_t) - recs = jnp.asarray( + recs = jnp.concatenate( [ state[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) @@ -294,7 +298,7 @@ def _body_fun(state, externals): externals[key] = jnp.concatenate([externals[key], dummy_external]) # Record the initial state. - init_recs = jnp.asarray( + init_recs = jnp.concatenate( [ all_states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) @@ -311,4 +315,5 @@ def _body_fun(state, externals): nested_lengths=checkpoint_lengths, ) recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T + # recs = recs[sort_inds, :] # Sort recordings back to order that was set by user. return (recs, all_states) if return_states else recs