From 44a3e5510852d05da456f906cab658e538d685bc Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Fri, 20 Dec 2024 13:53:25 -0500 Subject: [PATCH 1/3] vectorized recording updates --- jaxley/integrate.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/jaxley/integrate.py b/jaxley/integrate.py index c068ec15..06f0454a 100644 --- a/jaxley/integrate.py +++ b/jaxley/integrate.py @@ -231,8 +231,12 @@ def integrate( if module.recordings.empty: raise ValueError("No recordings are set. Please set them.") - rec_inds = module.recordings.rec_index.to_numpy() - rec_states = module.recordings.state.to_numpy() + recording_df = module.recordings.reset_index(drop=True) + rec_states, rec_inds, sort_inds = list(), list(), list() + for state, df_group in recording_df.groupby("state"): + rec_states.append(state) + rec_inds.append(df_group.rec_index.to_numpy()) + sort_inds.extend(df_group.index.to_list()) # Shorten or pad stimulus depending on `t_max`. if t_max is not None: @@ -260,7 +264,7 @@ def integrate( def _body_fun(state, externals): state = step_fn(state, all_params, externals, external_inds, delta_t) - recs = jnp.asarray( + recs = jnp.concatenate( [ state[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) @@ -294,7 +298,7 @@ def _body_fun(state, externals): externals[key] = jnp.concatenate([externals[key], dummy_external]) # Record the initial state. - init_recs = jnp.asarray( + init_recs = jnp.concatenate( [ all_states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) @@ -311,4 +315,5 @@ def _body_fun(state, externals): nested_lengths=checkpoint_lengths, ) recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T + # recs = recs[sort_inds, :] # Sort recordings back to order that was set by user. return (recs, all_states) if return_states else recs From 1aa2523ae05abf0072c05c638bb7ad275c7fc53d Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Fri, 20 Dec 2024 14:11:23 -0500 Subject: [PATCH 2/3] update changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b4fa8ed..54ef89fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,10 @@ net.vis() - changelog added to CI (#537, #558, @jnsbck) +### Code Health + +- Vectorize recording updates during `integrate()`. (#561, @ntolley) + # 0.5.0 ### API changes From 9595a213869d94a01e919369138f8f1c79cdf18a Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Fri, 20 Dec 2024 14:56:00 -0500 Subject: [PATCH 3/3] fix sorting --- jaxley/integrate.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/jaxley/integrate.py b/jaxley/integrate.py index 06f0454a..f8c5b1a2 100644 --- a/jaxley/integrate.py +++ b/jaxley/integrate.py @@ -232,11 +232,12 @@ def integrate( if module.recordings.empty: raise ValueError("No recordings are set. Please set them.") recording_df = module.recordings.reset_index(drop=True) - rec_states, rec_inds, sort_inds = list(), list(), list() + rec_states, rec_inds, group_inds = list(), list(), list() for state, df_group in recording_df.groupby("state"): rec_states.append(state) rec_inds.append(df_group.rec_index.to_numpy()) - sort_inds.extend(df_group.index.to_list()) + group_inds.extend(df_group.index.to_list()) + sort_inds = jnp.argsort(jnp.asarray(group_inds)) # Shorten or pad stimulus depending on `t_max`. if t_max is not None: @@ -315,5 +316,5 @@ def _body_fun(state, externals): nested_lengths=checkpoint_lengths, ) recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T - # recs = recs[sort_inds, :] # Sort recordings back to order that was set by user. + recs = recs[sort_inds, :] # Sort recordings back to order that was set by user. return (recs, all_states) if return_states else recs