From 657a9010e6e8c03f645fd1ce75af585a9009fb57 Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Tue, 14 Nov 2023 09:41:25 +0100 Subject: [PATCH] Add test for equaivalence of APIs --- tests/test_api_equivalence.py | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/test_api_equivalence.py diff --git a/tests/test_api_equivalence.py b/tests/test_api_equivalence.py new file mode 100644 index 000000000..5ab3ea0f2 --- /dev/null +++ b/tests/test_api_equivalence.py @@ -0,0 +1,38 @@ +import jax + +jax.config.update("jax_enable_x64", True) +jax.config.update("jax_platform_name", "cpu") + +import jax.numpy as jnp + +import neurax as nx + + +def test_api_equivalence(): + """Test the API for recording and stimulating.""" + nseg_per_branch = 2 + depth = 2 + dt = 0.025 + + parents = [-1] + [b // 2 for b in range(0, 2**depth - 2)] + parents = jnp.asarray(parents) + num_branches = len(parents) + + comp = nx.Compartment().initialize() + + branch1 = nx.Branch([comp for _ in range(nseg_per_branch)]).initialize() + cell1 = nx.Cell([branch1 for _ in range(num_branches)], parents=parents).initialize() + + branch2 = nx.Branch(comp, nseg=nseg_per_branch).initialize() + cell2 = nx.Cell(branch2, parents=parents).initialize() + + cell1.branch(2).comp(0.4).record() + cell2.branch(2).comp(0.4).record() + + current = nx.step_current(0.5, 1.0, 1.0, dt, 3.0) + cell1.branch(1).comp(1.0).stimulate(current) + cell2.branch(1).comp(1.0).stimulate(current) + + voltages1 = nx.integrate(cell1, delta_t=dt) + voltages2 = nx.integrate(cell2, delta_t=dt) + assert jnp.max(jnp.abs(voltages1 - voltages2)) < 1e-8, "Voltages do not match between APIs."