Skip to content

Commit

Permalink
a test for the groupviews
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 19, 2023
1 parent c1f29b7 commit a6744cb
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,32 @@ def test_make_trainable_corresponds_to_set_pospischil():
voltages1 = jx.integrate(net1)
voltages2 = jx.integrate(net2)
assert np.max(np.abs(voltages1 - voltages2)) < 1e-8


def test_group_trainable_corresponds_to_set():
"""Use `GroupView` and make it trainable; test if it gives the same as `set`."""

def build_net():
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=4)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])
net = jx.Network([cell for _ in range(4)])
net.cell(0).add_to_group("test")
net.cell(1).branch(2).add_to_group("test")
return net

net1 = build_net()

net1.test.make_trainable("radius")
params = net1.get_parameters()
params[0]["radius"] = params[0]["radius"].at[:].set(2.5)
net1.to_jax()
all_parameters1 = net1.get_all_parameters(params)

net2 = build_net()
net2.test.set("radius", 2.5)
params = net2.get_parameters()
net2.to_jax()
all_parameters2 = net2.get_all_parameters(params)

assert np.allclose(all_parameters1["radius"], all_parameters2["radius"])

0 comments on commit a6744cb

Please sign in to comment.