diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index abd41a81..55c43a48 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -737,8 +737,8 @@ def to_jax(self): for jax_arrays, data, mechs in zip( [jaxnodes, jaxedges], - [self.nodes, self.edges], - [self.channels, self.synapses], + [self.base.nodes, self.base.edges], + [self.base.channels, self.base.synapses], ): jax_arrays.update({"index": data.index.to_numpy()}) all_inds = jax_arrays["index"] @@ -754,9 +754,9 @@ def to_jax(self): jax_arrays.update(params.to_dict(orient="list")) morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - jaxnodes.update(self.nodes[["v"]+morph_params].to_dict(orient="list")) - jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} - jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} + jaxnodes.update(self.nodes[["v"] + morph_params].to_dict(orient="list")) + self.base.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} + self.base.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} def show( self, @@ -1260,12 +1260,17 @@ def get_all_parameters( for key in ["v"] + morph_params: params[key] = self.base.jaxnodes[key] - for channel in self.base.channels: - for channel_params in channel.params: - params[channel_params] = self.base.jaxnodes[channel_params] - - for synapse_params in self.base.synapse_param_names: - params[synapse_params] = self.base.jaxedges[synapse_params] + for jax_arrays, data, mechs in zip( + [self.base.jaxnodes, self.base.jaxedges], + [self.base.nodes, self.base.edges], + [self.base.channels, self.base.synapses], + ): + for mech in mechs: + inds = jax_arrays[mech._name] + for mech_param in mech.params: + params[mech_param] = data[mech_param].to_numpy() + params[mech_param][inds] = jax_arrays[mech_param] + params[mech_param] = jnp.asarray(params[mech_param]) # Override with those parameters set by `.make_trainable()`. for parameter in pstate: