diff --git a/tests/neurax_vs_neuron/test_branch.py b/tests/neurax_vs_neuron/test_branch.py index d73ac5d0..56249b0f 100644 --- a/tests/neurax_vs_neuron/test_branch.py +++ b/tests/neurax_vs_neuron/test_branch.py @@ -60,10 +60,11 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max): branch.set_states("n", 0.3644787002343737) branch.set_states("voltages", -62.0) - stims = [nx.Stimulus(0, 0, 0.0, nx.step_current(i_delay, i_dur, i_amp, time_vec))] - recs = [nx.Recording(0, 0, 0.0), nx.Recording(0, 0, 1.0)] - - voltages = nx.integrate(branch, stims, recs, delta_t=dt) + branch.comp(0.0).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec)) + branch.comp(0.0).record() + branch.comp(1.0).record() + + voltages = nx.integrate(branch, delta_t=dt) return voltages @@ -191,14 +192,12 @@ def _neurax_complex(i_delay, i_dur, i_amp, dt, t_max, diams): branch = branch.initialize() # 0.02 is fine here because nseg=8 for NEURON, but nseg=16 for neurax. - stims = [nx.Stimulus(0, 0, 0.02, nx.step_current(i_delay, i_dur, i_amp, time_vec))] - recs = [ - nx.Recording(0, 0, 0.02), - nx.Recording(0, 0, 0.52), - nx.Recording(0, 0, 0.98), - ] + branch.comp(0.02).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec)) + branch.comp(0.02).record() + branch.comp(0.52).record() + branch.comp(0.98).record() - s = nx.integrate(branch, stims, recs, delta_t=dt, tridiag_solver="thomas") + s = nx.integrate(branch, delta_t=dt, tridiag_solver="thomas") return s diff --git a/tests/neurax_vs_neuron/test_cell.py b/tests/neurax_vs_neuron/test_cell.py index 31737cf4..4afad2d5 100644 --- a/tests/neurax_vs_neuron/test_cell.py +++ b/tests/neurax_vs_neuron/test_cell.py @@ -55,10 +55,12 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max): cell.set_states("n", 0.3644787002343737) cell.set_states("voltages", -62.0) - stims = [nx.Stimulus(0, 0, 0.0, nx.step_current(i_delay, i_dur, i_amp, time_vec))] - recs = [nx.Recording(0, 0, 0.0), nx.Recording(0, 1, 1.0), nx.Recording(0, 2, 1.0)] - - voltages = nx.integrate(cell, stims, recs, delta_t=dt) + cell.branch(0).comp(0.0).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec)) + cell.branch(0).comp(0.0).record() + cell.branch(1).comp(1.0).record() + cell.branch(2).comp(1.0).record() + + voltages = nx.integrate(cell, delta_t=dt) return voltages diff --git a/tests/neurax_vs_neuron/test_comp.py b/tests/neurax_vs_neuron/test_comp.py index 8e322d3d..343d2711 100644 --- a/tests/neurax_vs_neuron/test_comp.py +++ b/tests/neurax_vs_neuron/test_comp.py @@ -29,7 +29,7 @@ def test_similarity(): t_max = 10.0 # ms voltages_neurax = _run_neurax(i_delay, i_dur, i_amp, dt, t_max) - voltages_neuron = _run_neurax(i_delay, i_dur, i_amp, dt, t_max) + voltages_neuron = _run_neuron(i_delay, i_dur, i_amp, dt, t_max) assert np.mean(np.abs(voltages_neurax - voltages_neuron)) < 1.0 @@ -51,14 +51,14 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max): comp.set_states("n", 0.3644787002343737) comp.set_states("voltages", -62.0) - stims = [nx.Stimulus(0, 0, 0.0, nx.step_current(i_delay, i_dur, i_amp, time_vec))] - recs = [nx.Recording(0, 0, 0.0)] + comp.stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec)) + comp.record() - voltages = nx.integrate(comp, stims, recs, delta_t=dt) + voltages = nx.integrate(comp, delta_t=dt) return voltages -def _run_neurax(i_delay, i_dur, i_amp, dt, t_max): +def _run_neuron(i_delay, i_dur, i_amp, dt, t_max): h.dt = dt for sec in h.allsec(): diff --git a/tests/test_cell_matches_branch.py b/tests/test_cell_matches_branch.py index b9cdbfeb..603bb3f8 100644 --- a/tests/test_cell_matches_branch.py +++ b/tests/test_cell_matches_branch.py @@ -22,13 +22,11 @@ def _run_long_branch(time_vec): branch.comp("all").make_trainable("radius", 1.0) params = branch.get_parameters() - stims = [ - nx.Stimulus(0, 0, 0.0, nx.step_current(0.5, 5.0, 0.1, time_vec)), - ] - recs = [nx.Recording(0, 0, 0.0)] + branch.comp(0.0).record() + branch.comp(0.0).stimulate(nx.step_current(0.5, 5.0, 0.1, time_vec)) def loss(params): - s = nx.integrate(branch, stims, recs, params=params) + s = nx.integrate(branch, params=params) return s[0, -1] jitted_loss_grad = jit(value_and_grad(loss)) @@ -49,13 +47,11 @@ def _run_short_branches(time_vec): cell.branch("all").comp("all").make_trainable("radius", 1.0) params = cell.get_parameters() - stims = [ - nx.Stimulus(0, 0, 0.0, nx.step_current(0.5, 5.0, 0.1, time_vec)), - ] - recs = [nx.Recording(0, 0, 0.0)] + cell.branch(0).comp(0.0).record() + cell.branch(0).comp(0.0).stimulate(nx.step_current(0.5, 5.0, 0.1, time_vec)) def loss(params): - s = nx.integrate(cell, stims, recs, params=params) + s = nx.integrate(cell, params=params) return s[0, -1] jitted_loss_grad = jit(value_and_grad(loss)) diff --git a/tests/test_swc.py b/tests/test_swc.py index 45d2a468..6500eeb2 100644 --- a/tests/test_swc.py +++ b/tests/test_swc.py @@ -160,10 +160,11 @@ def test_swc_voltages(): cell.set_states("h", 0.4889) cell.set_states("n", 0.3644787) - stims = [nx.Stimulus(0, 1, 0.05, nx.step_current(i_delay, i_dur, i_amp, time_vec))] - recs = [nx.Recording(0, i, 0.05) for i in trunk_inds + tuft_inds + basal_inds] - - voltages_neurax = nx.integrate(cell, stims, recs, delta_t=dt) + cell.branch(1).comp(0.05).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec)) + for i in trunk_inds + tuft_inds + basal_inds: + cell.branch(i).comp(0.05).record() + + voltages_neurax = nx.integrate(cell, delta_t=dt) ################### NEURON ################# stim = h.IClamp(h.soma[0](0.1))