Skip to content

Commit

Permalink
fix all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Nov 9, 2023
1 parent d0cada4 commit 440abca
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 20 deletions.
8 changes: 2 additions & 6 deletions tests/neurax_vs_neuron/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def test_similarity():


def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
time_vec = jnp.arange(0.0, t_max + dt, dt)

nseg_per_branch = 8
comp = nx.Compartment().initialize()
branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize()
Expand All @@ -60,7 +58,7 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
branch.set_states("n", 0.3644787002343737)
branch.set_states("voltages", -62.0)

branch.comp(0.0).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
branch.comp(0.0).stimulate(nx.step_current(i_delay, i_dur, i_amp, dt, t_max))
branch.comp(0.0).record()
branch.comp(1.0).record()

Expand Down Expand Up @@ -159,8 +157,6 @@ def test_similarity_complex():


def _neurax_complex(i_delay, i_dur, i_amp, dt, t_max, diams):
time_vec = np.arange(0, t_max + dt, dt)

nseg = 16
comp = nx.Compartment()
branch = nx.Branch([comp for _ in range(nseg)])
Expand Down Expand Up @@ -192,7 +188,7 @@ def _neurax_complex(i_delay, i_dur, i_amp, dt, t_max, diams):
branch = branch.initialize()

# 0.02 is fine here because nseg=8 for NEURON, but nseg=16 for neurax.
branch.comp(0.02).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
branch.comp(0.02).stimulate(nx.step_current(i_delay, i_dur, i_amp, dt, t_max))
branch.comp(0.02).record()
branch.comp(0.52).record()
branch.comp(0.98).record()
Expand Down
3 changes: 1 addition & 2 deletions tests/neurax_vs_neuron/test_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_similarity():


def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
time_vec = jnp.arange(0.0, t_max + dt, dt)

nseg_per_branch = 8
comp = nx.Compartment().initialize()
Expand All @@ -55,7 +54,7 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
cell.set_states("n", 0.3644787002343737)
cell.set_states("voltages", -62.0)

cell.branch(0).comp(0.0).stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
cell.branch(0).comp(0.0).stimulate(nx.step_current(i_delay, i_dur, i_amp, dt, t_max))
cell.branch(0).comp(0.0).record()
cell.branch(1).comp(1.0).record()
cell.branch(2).comp(1.0).record()
Expand Down
3 changes: 1 addition & 2 deletions tests/neurax_vs_neuron/test_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def test_similarity():


def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
time_vec = jnp.arange(0.0, t_max + dt, dt)

comp = nx.Compartment().initialize()
comp.insert(HHChannel())
Expand All @@ -51,7 +50,7 @@ def _run_neurax(i_delay, i_dur, i_amp, dt, t_max):
comp.set_states("n", 0.3644787002343737)
comp.set_states("voltages", -62.0)

comp.stimulate(nx.step_current(i_delay, i_dur, i_amp, time_vec))
comp.stimulate(nx.step_current(i_delay, i_dur, i_amp, dt, t_max))
comp.record()

voltages = nx.integrate(comp, delta_t=dt)
Expand Down
13 changes: 6 additions & 7 deletions tests/test_cell_matches_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from neurax.channels import HHChannel


def _run_long_branch(time_vec):
def _run_long_branch(dt, t_max):
nseg_per_branch = 8

comp = nx.Compartment()
Expand All @@ -23,7 +23,7 @@ def _run_long_branch(time_vec):
params = branch.get_parameters()

branch.comp(0.0).record()
branch.comp(0.0).stimulate(nx.step_current(0.5, 5.0, 0.1, time_vec))
branch.comp(0.0).stimulate(nx.step_current(0.5, 5.0, 0.1, dt, t_max))

def loss(params):
s = nx.integrate(branch, params=params)
Expand All @@ -35,7 +35,7 @@ def loss(params):
return l, g


def _run_short_branches(time_vec):
def _run_short_branches(dt, t_max):
nseg_per_branch = 4
parents = jnp.asarray([-1, 0])

Expand All @@ -48,7 +48,7 @@ def _run_short_branches(time_vec):
params = cell.get_parameters()

cell.branch(0).comp(0.0).record()
cell.branch(0).comp(0.0).stimulate(nx.step_current(0.5, 5.0, 0.1, time_vec))
cell.branch(0).comp(0.0).stimulate(nx.step_current(0.5, 5.0, 0.1, dt, t_max))

def loss(params):
s = nx.integrate(cell, params=params)
Expand All @@ -64,9 +64,8 @@ def test_equivalence():
"""Test whether a single long branch matches a cell of two shorter branches."""
dt = 0.025
t_max = 5.0 # ms
time_vec = jnp.arange(0.0, t_max + dt, dt)
l1, g1 = _run_long_branch(time_vec)
l2, g2 = _run_short_branches(time_vec)
l1, g1 = _run_long_branch(dt, t_max)
l2, g2 = _run_short_branches(dt, t_max)

assert np.allclose(l1, l2), "Losses do not match."

Expand Down
4 changes: 1 addition & 3 deletions tests/test_swc.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def test_swc_voltages():
t_max = 20.0
dt = 0.025

time_vec = np.arange(0, t_max + dt, dt)

nseg_per_branch = 8

##################### NEURON ##################
Expand Down Expand Up @@ -161,7 +159,7 @@ def test_swc_voltages():
cell.set_states("n", 0.3644787)

cell.branch(1).comp(0.05).stimulate(
nx.step_current(i_delay, i_dur, i_amp, time_vec)
nx.step_current(i_delay, i_dur, i_amp, dt, t_max)
)
for i in trunk_inds + tuft_inds + basal_inds:
cell.branch(i).comp(0.05).record()
Expand Down

0 comments on commit 440abca

Please sign in to comment.