diff --git a/tests/test_grad.py b/tests/test_grad.py
new file mode 100644
index 000000000..a76b4a336
--- /dev/null
+++ b/tests/test_grad.py
@@ -0,0 +1,41 @@
+import jax
+
+jax.config.update("jax_enable_x64", True)
+jax.config.update("jax_platform_name", "cpu")
+
+import jax.numpy as jnp
+from jax import value_and_grad
+
+import jaxley as jx
+from jaxley.channels import HH
+
+
+def test_grad_against_finite_diff_initial_state():
+    comp = jx.Compartment()
+    comp.insert(HH())
+    comp.record()
+    comp.stimulate(jx.step_current(0.1, 0.2, 0.1, 0.025, 5.0))
+
+    def simulate():
+        return jnp.sum(jx.integrate(comp))
+
+    # Finite differences.
+    comp.set("HH_m", 0.2)
+    r1 = simulate()
+    comp.set("HH_m", 0.21)
+    r2 = simulate()
+    finitediff_grad = (r2 - r1) / 0.01
+
+    # Autodiff gradient.
+    def simulate(params):
+        return jnp.sum(jx.integrate(comp, params=params))
+
+    grad_fn = value_and_grad(simulate)
+    comp.set("HH_m", 0.2)
+    comp.make_trainable("HH_m")
+    params = comp.get_parameters()
+    v, g = grad_fn(params)
+    autodiff_grad = g[0]["HH_m"]
+
+    # Less than 5% error on the gradient difference.
+    assert jnp.abs(autodiff_grad - finitediff_grad) / autodiff_grad < 0.05
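
Note (not part of the patch): the test above validates JAX autodiff against a forward finite-difference approximation of the gradient of the summed voltage trace with respect to the initial `HH_m` value. Below is a minimal, self-contained sketch of that same check on a toy scalar function, so the pattern can be seen without a Jaxley simulation; the toy function `f`, the step size of 0.01, and the 5% tolerance are illustrative choices, not values prescribed by the library.

```python
# Illustrative sketch: forward finite difference vs. jax.value_and_grad
# on a stand-in for the simulation (any smooth scalar function works).
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(x):
    # Stand-in for jnp.sum(jx.integrate(...)): smooth scalar output.
    return jnp.sum(jnp.sin(x) * jnp.exp(-x))


x0 = 0.2
step = 0.01

# Forward finite difference: (f(x0 + h) - f(x0)) / h.
finitediff_grad = (f(x0 + step) - f(x0)) / step

# Autodiff value and gradient at the same point.
value, autodiff_grad = jax.value_and_grad(f)(x0)

# Relative error below 5%, mirroring the assertion in the test
# (absolute value in the denominator guards against a negative gradient).
assert jnp.abs(autodiff_grad - finitediff_grad) / jnp.abs(autodiff_grad) < 0.05
```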