Skip to content

Commit

Permalink
update all test scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 3, 2023
1 parent 36d41c1 commit 37b2e46
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 34 deletions.
21 changes: 10 additions & 11 deletions tests/neurax_vs_neuron/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
branch.set_states("n", 0.3644787002343737)
branch.set_states("voltages", -62.0)

stims = [nx.Stimulus(0, 0, 0.0, nx.step_current(i_delay, i_dur, i_amp, time_vec))]
recs = [nx.Recording(0, 0, 0.0), nx.Recording(0, 0, 1.0)]

voltages = nx.integrate(branch, stims, recs, delta_t=dt)
branch.comp(0.0).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
branch.comp(0.0).record()
branch.comp(1.0).record()

voltages = nx.integrate(branch, delta_t=dt)

return voltages

Expand Down Expand Up @@ -191,14 +192,12 @@ def _neurax_complex(i_delay, i_dur, i_amp, dt, t_max, diams):
branch = branch.initialize()

# 0.02 is fine here because nseg=8 for NEURON, but nseg=16 for neurax.
stims = [nx.Stimulus(0, 0, 0.02, nx.step_current(i_delay, i_dur, i_amp, time_vec))]
recs = [
nx.Recording(0, 0, 0.02),
nx.Recording(0, 0, 0.52),
nx.Recording(0, 0, 0.98),
]
branch.comp(0.02).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
branch.comp(0.02).record()
branch.comp(0.52).record()
branch.comp(0.98).record()

s = nx.integrate(branch, stims, recs, delta_t=dt, tridiag_solver="thomas")
s = nx.integrate(branch, delta_t=dt, tridiag_solver="thomas")
return s


Expand Down
10 changes: 6 additions & 4 deletions tests/neurax_vs_neuron/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
cell.set_states("n", 0.3644787002343737)
cell.set_states("voltages", -62.0)

stims = [nx.Stimulus(0, 0, 0.0, nx.step_current(i_delay, i_dur, i_amp, time_vec))]
recs = [nx.Recording(0, 0, 0.0), nx.Recording(0, 1, 1.0), nx.Recording(0, 2, 1.0)]

voltages = nx.integrate(cell, stims, recs, delta_t=dt)
cell.branch(0).comp(0.0).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
cell.branch(0).comp(0.0).record()
cell.branch(1).comp(1.0).record()
cell.branch(2).comp(1.0).record()

voltages = nx.integrate(cell, delta_t=dt)
return voltages


Expand Down
10 changes: 5 additions & 5 deletions tests/neurax_vs_neuron/test_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_similarity():
t_max = 10.0 # ms

voltages_neurax = _run_neurax(i_delay, i_dur, i_amp, dt, t_max)
voltages_neuron = _run_neurax(i_delay, i_dur, i_amp, dt, t_max)
voltages_neuron = _run_neuron(i_delay, i_dur, i_amp, dt, t_max)

assert np.mean(np.abs(voltages_neurax - voltages_neuron)) < 1.0

Expand All @@ -51,14 +51,14 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
comp.set_states("n", 0.3644787002343737)
comp.set_states("voltages", -62.0)

stims = [nx.Stimulus(0, 0, 0.0, nx.step_current(i_delay, i_dur, i_amp, time_vec))]
recs = [nx.Recording(0, 0, 0.0)]
comp.stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
comp.record()

voltages = nx.integrate(comp, stims, recs, delta_t=dt)
voltages = nx.integrate(comp, delta_t=dt)
return voltages


def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
def _run_neuron(i_delay, i_dur, i_amp, dt, t_max):
h.dt = dt

for sec in h.allsec():
Expand Down
16 changes: 6 additions & 10 deletions tests/test_cell_matches_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@ def _run_long_branch(time_vec):
branch.comp("all").make_trainable("radius", 1.0)
params = branch.get_parameters()

stims = [
nx.Stimulus(0, 0, 0.0, nx.step_current(0.5, 5.0, 0.1, time_vec)),
]
recs = [nx.Recording(0, 0, 0.0)]
branch.comp(0.0).record()
branch.comp(0.0).stimulate(nx.step_current(0.5, 5.0, 0.1, time_vec))

def loss(params):
s = nx.integrate(branch, stims, recs, params=params)
s = nx.integrate(branch, params=params)
return s[0, -1]

jitted_loss_grad = jit(value_and_grad(loss))
Expand All @@ -49,13 +47,11 @@ def _run_short_branches(time_vec):
cell.branch("all").comp("all").make_trainable("radius", 1.0)
params = cell.get_parameters()

stims = [
nx.Stimulus(0, 0, 0.0, nx.step_current(0.5, 5.0, 0.1, time_vec)),
]
recs = [nx.Recording(0, 0, 0.0)]
cell.branch(0).comp(0.0).record()
cell.branch(0).comp(0.0).stimulate(nx.step_current(0.5, 5.0, 0.1, time_vec))

def loss(params):
s = nx.integrate(cell, stims, recs, params=params)
s = nx.integrate(cell, params=params)
return s[0, -1]

jitted_loss_grad = jit(value_and_grad(loss))
Expand Down
9 changes: 5 additions & 4 deletions tests/test_swc.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,11 @@ def test_swc_voltages():
cell.set_states("h", 0.4889)
cell.set_states("n", 0.3644787)

stims = [nx.Stimulus(0, 1, 0.05, nx.step_current(i_delay, i_dur, i_amp, time_vec))]
recs = [nx.Recording(0, i, 0.05) for i in trunk_inds + tuft_inds + basal_inds]

voltages_neurax = nx.integrate(cell, stims, recs, delta_t=dt)
cell.branch(1).comp(0.05).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
for i in trunk_inds + tuft_inds + basal_inds:
cell.branch(i).comp(0.05).record()

voltages_neurax = nx.integrate(cell, delta_t=dt)

################### NEURON #################
stim = h.IClamp(h.soma[0](0.1))
Expand Down

0 comments on commit 37b2e46

Please sign in to comment.