Skip to content

Commit

Permalink
test for make trainable and synapses
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Dec 11, 2023
1 parent 47ea4b2 commit 4a64cf3
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
jax.config.update("jax_platform_name", "cpu")

import jax.numpy as jnp
import numpy as np

import jaxley as jx
from jaxley.channels import HHChannel
from jaxley.synapses import GlutamateSynapse
from jaxley.synapses import GlutamateSynapse, TestSynapse


def test_make_trainable():
Expand Down Expand Up @@ -70,3 +71,50 @@ def test_make_trainable_network():
cell.get_parameters()
net.GlutamateSynapse.set_params("gS", 0.1)
assert cell.num_trainable_params == 8 # `set_params()` is ignored.


def test_diverse_synapse_types():
"""Runs `.get_all_parameters()` and checks if the output is as expected."""
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=1)
cell = jx.Cell(branch, parents=[-1])

net = jx.Network([cell for _ in range(4)])
for pre_ind in [0, 1]:
for post_ind, syn in zip([2, 3], [GlutamateSynapse(), TestSynapse()]):
pre = net.cell(pre_ind).branch(0).comp(0.0)
post = net.cell(post_ind).branch(0).comp(0.0)
pre.connect(post, syn)

net.make_trainable("gS")
net.TestSynapse([0, 1]).make_trainable("gC")
assert net.num_trainable_params == 3

params = net.get_parameters()

# Modify the trainable parameters.
params[0]["gS"] = params[0]["gS"].at[:].set(2.2)
params[1]["gC"] = params[1]["gC"].at[0].set(3.3)
params[1]["gC"] = params[1]["gC"].at[1].set(4.4)
all_parameters = net.get_all_parameters(params)

assert np.all(all_parameters["radius"] == 1.0)
assert np.all(all_parameters["length"] == 10.0)
assert np.all(all_parameters["axial_resistivity"] == 5000.0)
assert np.all(all_parameters["gS"][0] == 2.2)
assert np.all(all_parameters["gS"][1] == 2.2)
assert np.all(all_parameters["gC"][0] == 3.3)
assert np.all(all_parameters["gC"][1] == 4.4)

# Add another trainable parameter and test again.
net.GlutamateSynapse(1).make_trainable("gS")
assert net.num_trainable_params == 4

params = net.get_parameters()

# Modify the trainable parameters.
params[2]["gS"] = params[2]["gS"].at[:].set(5.5)
all_parameters = net.get_all_parameters(params)
assert np.all(all_parameters["gS"][0] == 2.2)
assert np.all(all_parameters["gS"][1] == 5.5)

0 comments on commit 4a64cf3

Please sign in to comment.