From de1e9f22994a3a2dfdbe35080e0b41cdf89bebbc Mon Sep 17 00:00:00 2001 From: Michael Deistler Date: Mon, 28 Oct 2024 17:21:35 +0100 Subject: [PATCH] Write trainables to module (#470) * Write trainables to module * Allow to write the trainables into the module --- jaxley/modules/base.py | 52 ++++++++++++++++++++++++- tests/test_make_trainable.py | 73 ++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 1 deletion(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index acf300e2..ee4bc998 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -30,6 +30,7 @@ convert_point_process_to_distributed, interpolate_xyz, loc_of_index, + params_to_pstate, query_channel_states_and_params, v_interp, ) @@ -681,8 +682,8 @@ def to_jax(self): self.base.jaxedges = {} edges = self.base.edges.to_dict(orient="list") for i, synapse in enumerate(self.base.synapses): + condition = np.asarray(edges["type_ind"]) == i for key in synapse.synapse_params: - condition = np.asarray(edges["type_ind"]) == i self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) for key in synapse.synapse_states: self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) @@ -1044,6 +1045,55 @@ def make_trainable( f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}" ) + def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): + """Write the trainables into `.nodes` and `.edges`. + + This allows to, e.g., visualize trained networks with `.vis()`. + + Args: + trainable_params: The trainable parameters returned by `get_parameters()`. + """ + # We do not support views. Why? `jaxedges` does not have any NaN + # elements, whereas edges does. Because of this, we already need special + # treatment to make this function work, and it would be an even bigger hassle + # if we wanted to support this. + assert self.__class__.__name__ in [ + "Compartment", + "Branch", + "Cell", + "Network", + ], "Only supports modules." + + # We could also implement this without casting the module to jax. + # However, I think it allows us to reuse as much code as possible and it avoids + # any kind of issues with indexing or parameter sharing (as this is fully + # taken care of by `get_all_parameters()`). + self.base.to_jax() + pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables) + all_params = self.base.get_all_parameters(pstate, voltage_solver="jaxley.stone") + + # The value for `delta_t` does not matter here because it is only used to + # compute the initial current. However, the initial current cannot be made + # trainable and so its value never gets used below. + all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025) + + # Loop only over the keys in `pstate` to avoid unnecessary computation. + for parameter in pstate: + key = parameter["key"] + if key in self.base.nodes.columns: + vals_to_set = all_params if key in all_params.keys() else all_states + self.base.nodes[key] = vals_to_set[key] + + # `jaxedges` contains only non-Nan elements. This is unlike the channels where + # we allow parameter sharing. + edges = self.base.edges.to_dict(orient="list") + for i, synapse in enumerate(self.base.synapses): + condition = np.asarray(edges["type_ind"]) == i + for key in list(synapse.synapse_params.keys()): + self.base.edges.loc[condition, key] = all_params[key] + for key in list(synapse.synapse_states.keys()): + self.base.edges.loc[condition, key] = all_states[key] + def distance(self, endpoint: "View") -> float: """Return the direct distance between two compartments. This does not compute the pathwise distance (which is currently not diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py index b69909ee..8574121b 100644 --- a/tests/test_make_trainable.py +++ b/tests/test_make_trainable.py @@ -5,9 +5,11 @@ jax.config.update("jax_enable_x64", True) jax.config.update("jax_platform_name", "cpu") +from copy import copy import jax.numpy as jnp import numpy as np +import pytest import jaxley as jx from jaxley.channels import HH, K, Na @@ -421,3 +423,74 @@ def simulate(params): parameters = net.get_parameters() v = simulate(parameters) assert np.invert(np.any(np.isnan(v))), "Found NaN in voltage." + + +def test_write_trainables(): + """Test whether `write_trainables()` gives the same result as using the trainables.""" + comp = jx.Compartment() + branch = jx.Branch(comp, 4) + cell = jx.Cell(branch, [-1, 0]) + net = jx.Network([cell for _ in range(2)]) + connect( + net.cell(0).branch(0).loc(0.9), + net.cell(1).branch(1).loc(0.1), + IonotropicSynapse(), + ) + connect( + net.cell(1).branch(0).loc(0.1), + net.cell(0).branch(1).loc(0.3), + TestSynapse(), + ) + connect( + net.cell(0).branch(0).loc(0.3), + net.cell(0).branch(1).loc(0.6), + TestSynapse(), + ) + connect( + net.cell(1).branch(0).loc(0.6), + net.cell(1).branch(1).loc(0.9), + IonotropicSynapse(), + ) + net.insert(HH()) + net.cell(0).branch(0).comp(0).record() + net.cell(1).branch(0).comp(0).record() + net.cell(0).branch(0).comp(0).stimulate(jx.step_current(0.1, 4.0, 0.1, 0.025, 5.0)) + + net.make_trainable("radius") + net.cell(0).make_trainable("length") + net.cell("all").make_trainable("axial_resistivity") + net.cell("all").branch("all").make_trainable("HH_gNa") + net.cell("all").branch("all").make_trainable("HH_m") + net.make_trainable("IonotropicSynapse_gS") + net.make_trainable("IonotropicSynapse_s") + net.select(edges="all").make_trainable("TestSynapse_gC") + net.select(edges="all").make_trainable("TestSynapse_c") + net.cell(0).branch(0).comp(0).make_trainable("radius") + + params = net.get_parameters() + + # Now, we manually modify the parameters. + for p in params: + for key in p: + p[key] = p[key].at[:].set(np.random.rand()) + + # Test whether voltages match. + v1 = jx.integrate(net, params=params) + + previous_nodes = copy(net.nodes) + previous_edges = copy(net.edges) + net.write_trainables(params) + v2 = jx.integrate(net) + assert np.max(np.abs(v1 - v2)) < 1e-8 + + # Test whether nodes and edges actually changed. + assert not net.nodes.equals(previous_nodes) + assert not net.edges.equals(previous_edges) + + # Test whether `View` raises with `write_trainables()`. + with pytest.raises(AssertionError): + net.cell(0).write_trainables(params) + + # Test whether synapse view raises an error. + with pytest.raises(AssertionError): + net.select(edges=[0, 2, 3]).write_trainables(params)