From a23c586e42d77942d1bbbbb526c83cd05df5bc8b Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Thu, 9 Nov 2023 10:07:55 +0100 Subject: [PATCH] Change API of some tests --- tests/neurax_identical/test_basic_modules.py | 16 ++++------------ tests/neurax_identical/test_radius_and_length.py | 16 ++++------------ tests/neurax_identical/test_swc.py | 6 ++---- 3 files changed, 10 insertions(+), 28 deletions(-) diff --git a/tests/neurax_identical/test_basic_modules.py b/tests/neurax_identical/test_basic_modules.py index 3664c513..de7278db 100644 --- a/tests/neurax_identical/test_basic_modules.py +++ b/tests/neurax_identical/test_basic_modules.py @@ -18,9 +18,7 @@ def test_compartment(): dt = 0.025 # ms t_max = 5.0 # ms - - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) comp = nx.Compartment().initialize() comp.insert(HHChannel()) @@ -55,9 +53,7 @@ def test_branch(): nseg_per_branch = 2 dt = 0.025 # ms t_max = 5.0 # ms - - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) comp = nx.Compartment().initialize() branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize() @@ -93,9 +89,7 @@ def test_cell(): nseg_per_branch = 2 dt = 0.025 # ms t_max = 5.0 # ms - - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) depth = 2 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] @@ -135,9 +129,7 @@ def test_net(): nseg_per_branch = 2 dt = 0.025 # ms t_max = 5.0 # ms - - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) depth = 2 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] diff --git a/tests/neurax_identical/test_radius_and_length.py b/tests/neurax_identical/test_radius_and_length.py index 748787d7..4056e565 100644 --- a/tests/neurax_identical/test_radius_and_length.py +++ b/tests/neurax_identical/test_radius_and_length.py @@ -18,9 +18,7 @@ def test_radius_and_length_compartment(): dt = 0.025 # ms t_max = 5.0 # ms - - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) comp = nx.Compartment().initialize() @@ -60,9 +58,7 @@ def test_radius_and_length_branch(): nseg_per_branch = 2 dt = 0.025 # ms t_max = 5.0 # ms - - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) comp = nx.Compartment().initialize() branch = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize() @@ -103,9 +99,7 @@ def test_radius_and_length_cell(): nseg_per_branch = 2 dt = 0.025 # ms t_max = 5.0 # ms - - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) depth = 2 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] @@ -151,9 +145,7 @@ def test_radius_and_length_net(): nseg_per_branch = 2 dt = 0.025 # ms t_max = 5.0 # ms - - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.02, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) depth = 2 parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] diff --git a/tests/neurax_identical/test_swc.py b/tests/neurax_identical/test_swc.py index f49c1405..f76521f6 100644 --- a/tests/neurax_identical/test_swc.py +++ b/tests/neurax_identical/test_swc.py @@ -18,8 +18,7 @@ def test_swc_cell(): dt = 0.025 # ms t_max = 5.0 # ms - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.2, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) cell = nx.read_swc("../morph.swc", nseg=2, max_branch_len=300.0) cell.insert(HHChannel()) @@ -53,8 +52,7 @@ def test_swc_cell(): def test_swc_net(): dt = 0.025 # ms t_max = 5.0 # ms - time_vec = jnp.arange(0.0, t_max + dt, dt) - current = nx.step_current(0.5, 1.0, 0.2, time_vec) + current = nx.step_current(0.5, 1.0, 0.02, dt, t_max) cell1 = nx.read_swc("../morph.swc", nseg=2, max_branch_len=300.0) cell2 = nx.read_swc("../morph.swc", nseg=2, max_branch_len=300.0)