
Commit

Add test for autodiff on initial states
michaeldeistler committed Dec 18, 2023
1 parent 24d095f commit f9c217f
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/test_grad.py
@@ -0,0 +1,42 @@
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp
from jax import value_and_grad

import jaxley as jx
from jaxley.channels import HH


def test_grad_against_finite_diff_initial_state():
    comp = jx.Compartment()
    comp.insert(HH())
    comp.record()
    # Step-current stimulus; the positional arguments are presumably delay, duration,
    # amplitude, time step, and total simulated time.
    comp.stimulate(jx.step_current(0.1, 0.2, 0.1, 0.025, 5.0))

    def simulate():
        # Simulates with whatever values are currently set on the compartment.
        return jnp.sum(jx.integrate(comp))

    # Finite-difference gradient: forward difference with step 0.01 in the
    # initial value of the gating variable HH_m.
    comp.set("HH_m", 0.2)
    r1 = simulate()
    comp.set("HH_m", 0.21)
    r2 = simulate()
    finitediff_grad = (r2 - r1) / 0.01

    # Autodiff gradient of the same quantity, taken w.r.t. the trainable parameters.
    def simulate_with_params(params):
        return jnp.sum(jx.integrate(comp, params=params))

    grad_fn = value_and_grad(simulate_with_params)
    comp.set("HH_m", 0.2)
    comp.make_trainable("HH_m")
    # The trainables come back as a list of dicts; the gradient mirrors that
    # structure, hence g[0]["HH_m"] below.
    params = comp.get_parameters()
    v, g = grad_fn(params)
    autodiff_grad = g[0]["HH_m"]

    # Less than 10% relative error between autodiff and finite-difference gradients.
    assert jnp.abs(autodiff_grad - finitediff_grad) / jnp.abs(autodiff_grad) < 0.1
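
For reference, the same autodiff-versus-finite-differences cross-check can be written in plain JAX without any Jaxley objects. The sketch below is illustrative only: the loss function, the array x, the step size eps, and the tolerance are assumptions made for this example, not part of the commit.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def loss(x):
    # Stand-in scalar objective, playing the role of jnp.sum(jx.integrate(comp)).
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.2, 0.4])
eps = 1e-5  # central-difference step size (illustrative choice)

# Autodiff gradient.
value, grad = jax.value_and_grad(loss)(x)

# Central finite differences, one coordinate at a time.
fd_grad = jnp.array(
    [(loss(x.at[i].add(eps)) - loss(x.at[i].add(-eps))) / (2 * eps) for i in range(x.shape[0])]
)

assert jnp.allclose(grad, fd_grad, rtol=1e-4)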

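Assuming a standard pytest setup in the repository, the new test can be run on its own with something like:

pytest tests/test_grad.py -k test_grad_against_finite_diff_initial_state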