Skip to content

Commit

Permalink
Write trainables to module (#470)
Browse files Browse the repository at this point in the history
* Write trainables to module

* Allow to write the trainables into the module
  • Loading branch information
michaeldeistler authored Oct 28, 2024
1 parent 3bf221c commit de1e9f2
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 1 deletion.
52 changes: 51 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
convert_point_process_to_distributed,
interpolate_xyz,
loc_of_index,
params_to_pstate,
query_channel_states_and_params,
v_interp,
)
Expand Down Expand Up @@ -681,8 +682,8 @@ def to_jax(self):
self.base.jaxedges = {}
edges = self.base.edges.to_dict(orient="list")
for i, synapse in enumerate(self.base.synapses):
condition = np.asarray(edges["type_ind"]) == i
for key in synapse.synapse_params:
condition = np.asarray(edges["type_ind"]) == i
self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])
for key in synapse.synapse_states:
self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])
Expand Down Expand Up @@ -1044,6 +1045,55 @@ def make_trainable(
f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}"
)

def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):
"""Write the trainables into `.nodes` and `.edges`.
This allows to, e.g., visualize trained networks with `.vis()`.
Args:
trainable_params: The trainable parameters returned by `get_parameters()`.
"""
# We do not support views. Why? `jaxedges` does not have any NaN
# elements, whereas edges does. Because of this, we already need special
# treatment to make this function work, and it would be an even bigger hassle
# if we wanted to support this.
assert self.__class__.__name__ in [
"Compartment",
"Branch",
"Cell",
"Network",
], "Only supports modules."

# We could also implement this without casting the module to jax.
# However, I think it allows us to reuse as much code as possible and it avoids
# any kind of issues with indexing or parameter sharing (as this is fully
# taken care of by `get_all_parameters()`).
self.base.to_jax()
pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)
all_params = self.base.get_all_parameters(pstate, voltage_solver="jaxley.stone")

# The value for `delta_t` does not matter here because it is only used to
# compute the initial current. However, the initial current cannot be made
# trainable and so its value never gets used below.
all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)

# Loop only over the keys in `pstate` to avoid unnecessary computation.
for parameter in pstate:
key = parameter["key"]
if key in self.base.nodes.columns:
vals_to_set = all_params if key in all_params.keys() else all_states
self.base.nodes[key] = vals_to_set[key]

# `jaxedges` contains only non-Nan elements. This is unlike the channels where
# we allow parameter sharing.
edges = self.base.edges.to_dict(orient="list")
for i, synapse in enumerate(self.base.synapses):
condition = np.asarray(edges["type_ind"]) == i
for key in list(synapse.synapse_params.keys()):
self.base.edges.loc[condition, key] = all_params[key]
for key in list(synapse.synapse_states.keys()):
self.base.edges.loc[condition, key] = all_states[key]

def distance(self, endpoint: "View") -> float:
"""Return the direct distance between two compartments.
This does not compute the pathwise distance (which is currently not
Expand Down
73 changes: 73 additions & 0 deletions tests/test_make_trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
from copy import copy

import jax.numpy as jnp
import numpy as np
import pytest

import jaxley as jx
from jaxley.channels import HH, K, Na
Expand Down Expand Up @@ -421,3 +423,74 @@ def simulate(params):
parameters = net.get_parameters()
v = simulate(parameters)
assert np.invert(np.any(np.isnan(v))), "Found NaN in voltage."


def test_write_trainables():
"""Test whether `write_trainables()` gives the same result as using the trainables."""
comp = jx.Compartment()
branch = jx.Branch(comp, 4)
cell = jx.Cell(branch, [-1, 0])
net = jx.Network([cell for _ in range(2)])
connect(
net.cell(0).branch(0).loc(0.9),
net.cell(1).branch(1).loc(0.1),
IonotropicSynapse(),
)
connect(
net.cell(1).branch(0).loc(0.1),
net.cell(0).branch(1).loc(0.3),
TestSynapse(),
)
connect(
net.cell(0).branch(0).loc(0.3),
net.cell(0).branch(1).loc(0.6),
TestSynapse(),
)
connect(
net.cell(1).branch(0).loc(0.6),
net.cell(1).branch(1).loc(0.9),
IonotropicSynapse(),
)
net.insert(HH())
net.cell(0).branch(0).comp(0).record()
net.cell(1).branch(0).comp(0).record()
net.cell(0).branch(0).comp(0).stimulate(jx.step_current(0.1, 4.0, 0.1, 0.025, 5.0))

net.make_trainable("radius")
net.cell(0).make_trainable("length")
net.cell("all").make_trainable("axial_resistivity")
net.cell("all").branch("all").make_trainable("HH_gNa")
net.cell("all").branch("all").make_trainable("HH_m")
net.make_trainable("IonotropicSynapse_gS")
net.make_trainable("IonotropicSynapse_s")
net.select(edges="all").make_trainable("TestSynapse_gC")
net.select(edges="all").make_trainable("TestSynapse_c")
net.cell(0).branch(0).comp(0).make_trainable("radius")

params = net.get_parameters()

# Now, we manually modify the parameters.
for p in params:
for key in p:
p[key] = p[key].at[:].set(np.random.rand())

# Test whether voltages match.
v1 = jx.integrate(net, params=params)

previous_nodes = copy(net.nodes)
previous_edges = copy(net.edges)
net.write_trainables(params)
v2 = jx.integrate(net)
assert np.max(np.abs(v1 - v2)) < 1e-8

# Test whether nodes and edges actually changed.
assert not net.nodes.equals(previous_nodes)
assert not net.edges.equals(previous_edges)

# Test whether `View` raises with `write_trainables()`.
with pytest.raises(AssertionError):
net.cell(0).write_trainables(params)

# Test whether synapse view raises an error.
with pytest.raises(AssertionError):
net.select(edges=[0, 2, 3]).write_trainables(params)

0 comments on commit de1e9f2

Please sign in to comment.