[MRG] Vectorize recording during integrate
#561
base: main
Here's some code to see how recordings impact speed. The comparison is pretty extreme, but it gets the point across:

```python
# Set JAX options before importing jaxley so they apply from the start.
from jax import config

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import time

import jaxley as jx
from jaxley.channels import Na
from jaxley.connect import fully_connect
from jaxley.synapses import IonotropicSynapse

cell = jx.Cell()
cell.insert(Na())

net = jx.Network([cell for _ in range(100)])
fully_connect(net, net, IonotropicSynapse())
params = net.get_parameters()

# Small number of recordings (4)
net.delete_recordings()
net.cell(range(2)).record('i_IonotropicSynapse')
start_time = time.time()
v = jx.integrate(net, params=params, t_max=10.0, delta_t=0.025)
simulate_time = time.time() - start_time
print(simulate_time)

# Huge number of recordings (10,000)
net.delete_recordings()
net.record('i_IonotropicSynapse')
start_time = time.time()
v = jx.integrate(net, params=params, t_max=10.0, delta_t=0.025)
simulate_time = time.time() - start_time
print(simulate_time)
```
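When comparing branches with a benchmark like this, two JAX details can skew wall-clock numbers: a first call may include tracing and compilation, and JAX dispatches work asynchronously, so a timer can stop before the computation has finished. A minimal sketch of a helper that warms up and then blocks on the result is below; `time_call` and `_block` are hypothetical names, not part of Jaxley.

```python
import time


def _block(out):
    """Wait for JAX's asynchronous dispatch to finish if `out` is a JAX array."""
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    return out


def time_call(fn, *args, n_warmup=1, n_repeat=3, **kwargs):
    """Return the fastest wall-clock time (seconds) of fn(*args, **kwargs)."""
    # Warm-up calls are meant to absorb one-off tracing/compilation cost.
    for _ in range(n_warmup):
        _block(fn(*args, **kwargs))
    timings = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        _block(fn(*args, **kwargs))
        timings.append(time.perf_counter() - start)
    # The minimum is the least noisy estimate of the intrinsic runtime.
    return min(timings)
```

For example, `time_call(jx.integrate, net, params=params, t_max=10.0, delta_t=0.025)` would report the fastest of three runs after one warm-up. Whether the warm-up actually removes compilation cost depends on whether `jx.integrate` reuses cached compiled code between calls, which this sketch does not assume.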
And finally, here are my timing results for the two branches. There's still a slight slowdown with more recordings, but recording from 10,000 synapses is an order of magnitude faster, so I'd say it's an improvement 😄 Here are the results for a 10 ms simulation:
[Figure: `integrate` timing comparison between the two branches, 10 ms simulation]
@michaeldeistler @jnsbck Since this is mainly a performance boost, I'm not sure how it should be tested. Testing execution time directly seems like it could be very brittle, especially when running tests locally. Unless you have some ideas, perhaps a test isn't necessary?
Unfortunately, that performance hit does scale with simulation time. Here are the results for a 100 ms simulation.
So there are still some optimizations to be made...
While testing recording from many, many states, I noticed serious performance hits during simulation. @jnsbck suggested that this may be due to a for loop over recordings that runs during the `integrate` call. This PR is an attempt to vectorize that update by indexing each unique state only once.
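The idea of indexing each unique state once can be illustrated with a small NumPy sketch (`jax.numpy` supports the same fancy indexing, `concatenate` and precomputed index arrays). This is not the code in this PR: it assumes states live in a dict mapping state names to 1D arrays and that each recording is a (state name, index) pair, and `record_loop`, `precompute_groups` and `record_vectorized` are hypothetical names. The grouping and the reordering permutation are computed once, before the time loop, so each step does one gather per unique state plus one reorder instead of one indexing operation per recording.

```python
import numpy as np


def record_loop(states, rec_states, rec_inds):
    """Baseline: one indexing operation per recording."""
    return np.asarray([states[name][idx] for name, idx in zip(rec_states, rec_inds)])


def precompute_groups(rec_states, rec_inds):
    """Group recordings by state name (done once, outside the time loop).

    Returns (state_name, indices) pairs plus a permutation that restores
    the original recording order after the grouped gathers.
    """
    rec_states = np.asarray(rec_states)
    rec_inds = np.asarray(rec_inds)
    groups, order = [], []
    for name in np.unique(rec_states):
        positions = np.nonzero(rec_states == name)[0]  # output slots for this state
        groups.append((str(name), rec_inds[positions]))
        order.append(positions)
    # argsort of a permutation is its inverse: maps output slot -> gathered slot.
    inverse = np.argsort(np.concatenate(order))
    return groups, inverse


def record_vectorized(states, groups, inverse):
    """One gather per unique state, then a single reorder."""
    gathered = np.concatenate([states[name][inds] for name, inds in groups])
    return gathered[inverse]


# Toy example: two state arrays, four recordings in arbitrary order.
states = {"v": np.array([-70.0, -65.0, -60.0]), "i_syn": np.array([0.1, 0.2])}
rec_states = ["v", "i_syn", "v", "i_syn"]
rec_inds = [2, 0, 0, 1]
groups, inverse = precompute_groups(rec_states, rec_inds)
recorded = record_vectorized(states, groups, inverse)  # values: -60.0, 0.1, -70.0, 0.2
```

The final permutation is there so the output keeps the same order as the per-recording loop; inside a jitted step the precomputed `groups` and `inverse` would simply be constant index arrays.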