diff --git a/jaxley/integrate.py b/jaxley/integrate.py index f221fc25..df63813c 100644 --- a/jaxley/integrate.py +++ b/jaxley/integrate.py @@ -41,8 +41,18 @@ def integrate( assert module.initialized, "Module is not initialized, run `.initialize()`." - i_current = module.currents.T if currents is None else currents.T - i_inds = module.current_inds.comp_index.to_numpy() + if module.currents is not None: + # At least one stimulus was inserted. + i_current = currents.T if currents is not None else module.currents.T + i_inds = module.current_inds.comp_index.to_numpy() + else: + # No stimulus was inserted. + i_current = jnp.asarray([[]]).astype("int") + i_inds = jnp.asarray([]).astype("int") + assert ( + t_max is not None + ), "If no stimulus is inserted that you have to specify the simulation duration at `jx.integrate(..., t_max=)`." + rec_inds = module.recordings.comp_index.to_numpy() # Shorten or pad stimulus depending on `t_max`.