
Increasing cell count in network increases time to convert recorded arrays #559

Closed
ntolley opened this issue Dec 18, 2024 · 5 comments

@ntolley
Contributor

ntolley commented Dec 18, 2024

I've found some peculiar behavior when simulating networks with a large number of cells. While increasing the cell count in the network has a negligible impact on the simulation time, it seems that internally the recordings need to be "converted" before further analysis.

This is seen whenever you use a function that accesses data in the recorded output (like plt.plot() or np.array()). Even stranger, it does not depend on the size of the recording. In the example below I simulate networks with increasing numbers of neurons, but only record the voltage of the first neuron.

import time

import numpy as np
from jax import config, jit

import jaxley as jx
from jaxley.channels import Na

config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

cell = jx.Cell()
cell.insert(Na())

# Store simulation times, and output conversion times
sim_time_list, array_time_list = list(), list()
cell_num_sweep = np.arange(10, 1000, 100)
for num_cells in cell_num_sweep:
    net = jx.Network([cell for _ in range(num_cells)])

    def simulate(params):
        return jx.integrate(net, params=params, t_max=1_000.0, delta_t=0.025)

    jitted_simulate = jit(simulate)
    params = net.get_parameters()

    # Only record 1 variable
    net.delete_recordings()
    net.cell(0).record('v')

    start_time = time.time()
    v = jitted_simulate(params)
    simulate_time = time.time() - start_time
    print(f"Simulate time {simulate_time}")

    start_time = time.time()
    v_new = np.array(v)
    array_time = time.time() - start_time
    print(f"Array convert time {array_time}")

    sim_time_list.append(simulate_time)
    array_time_list.append(array_time)

It seems the time to convert the simulated output from a jnp.Array to an np.ndarray increases linearly with the number of cells, even though the size of the recorded vector is the same and the simulation time is not impacted:

import matplotlib.pyplot as plt

plt.plot(cell_num_sweep, array_time_list, label='Array Convert')
plt.plot(cell_num_sweep, sim_time_list, label='Sim Time')
plt.xlabel('Num Cells', fontsize=15)
plt.ylabel('Time (s)', fontsize=15)
plt.legend(fontsize=12)

[Figure: array-convert time grows linearly with cell count while simulation time stays flat]

@ntolley
Contributor Author

ntolley commented Dec 18, 2024

Also, this is only the case when increasing the number of cells! You can run the same test increasing the number of branches instead and see that there is no impact...

@michaeldeistler
Contributor

Hi Nick,

thanks for reporting. This is super interesting, in part because the timings look completely different on my laptop (Macbook pro M3):
[Figure: timings on a MacBook Pro M3, showing simulation time growing linearly instead]

I think there are two interesting things:

  1. Overall, my simulation time goes up linearly, whereas for you it is the array convert time. Could you modify your code as follows and try again:
v = jitted_simulate(params).block_until_ready()

I hope that this will make the weird behavior go away. See here for why you are observing this behavior.

  2. Overall, your timings are more than one order of magnitude higher than mine. This might just be a difference in our machines, but I find the extent of the difference quite striking.
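For background, asynchronous dispatch means a jitted call returns a future-like array immediately, and the cost of the computation is only paid when something forces the result (like np.array()). A minimal sketch of the effect in plain JAX (generic example, not jaxley-specific; the function and sizes here are made up for illustration):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def heavy(x):
    # Repeated matmuls so that dispatch and completion times differ measurably.
    for _ in range(20):
        x = jnp.tanh(x @ x)
    return x


x = jnp.ones((400, 400)) * 0.01
heavy(x).block_until_ready()  # warm-up call: excludes compilation time

t0 = time.time()
y = heavy(x)  # returns almost immediately; the work continues in the background
dispatch_time = time.time() - t0

t0 = time.time()
y = heavy(x).block_until_ready()  # waits until the computation has actually finished
blocked_time = time.time() - t0

print(f"dispatch-only: {dispatch_time:.4f}s, blocked: {blocked_time:.4f}s")
```

Without block_until_ready(), the "conversion" step in the script above silently absorbs the waiting time, which is why it appeared to scale with network size.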

Michael

@ntolley
Contributor Author

ntolley commented Dec 19, 2024

Really helpful explanation @michaeldeistler! This was definitely an issue with asynchronous dispatch; after adding block_until_ready(), it is the simulation time that increases linearly.

In terms of different machines, I've noticed with NEURON simulations that the newest M1+ chips just destroy any other machine in terms of CPU processing speed. The results above are from an HPC node, but I get similar timings on my local computer with an Intel i5 CPU.

Not a real issue on my end as I've got GPU access, but I totally agree that an order of magnitude difference is striking.

@ntolley
Contributor Author

ntolley commented Dec 19, 2024

Also, I'm profiling the performance hit when increasing the number of recordings, and I'll put together an example soon. Feel free to either close this issue, or I can continue the discussion here (and possibly adjust the title).

@ntolley
Contributor Author

ntolley commented Dec 20, 2024

Actually, I just opened a PR addressing the recording issue (#561), so I think I'll go ahead and close this issue since it was mainly my own confusion about how JAX works. Thanks a bunch for walking me through it!

@ntolley ntolley closed this as completed Dec 20, 2024