Skip to content

Commit

Permalink
wip: make get_all_params work with new indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 5, 2024
1 parent 658255b commit e3e2000
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,8 +737,8 @@ def to_jax(self):

for jax_arrays, data, mechs in zip(
[jaxnodes, jaxedges],
[self.nodes, self.edges],
[self.channels, self.synapses],
[self.base.nodes, self.base.edges],
[self.base.channels, self.base.synapses],
):
jax_arrays.update({"index": data.index.to_numpy()})
all_inds = jax_arrays["index"]
Expand All @@ -754,9 +754,9 @@ def to_jax(self):
jax_arrays.update(params.to_dict(orient="list"))

morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
jaxnodes.update(self.nodes[["v"]+morph_params].to_dict(orient="list"))
jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()}
jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()}
jaxnodes.update(self.nodes[["v"] + morph_params].to_dict(orient="list"))
self.base.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()}
self.base.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()}

def show(
self,
Expand Down Expand Up @@ -1260,12 +1260,17 @@ def get_all_parameters(
for key in ["v"] + morph_params:
params[key] = self.base.jaxnodes[key]

for channel in self.base.channels:
for channel_params in channel.params:
params[channel_params] = self.base.jaxnodes[channel_params]

for synapse_params in self.base.synapse_param_names:
params[synapse_params] = self.base.jaxedges[synapse_params]
for jax_arrays, data, mechs in zip(
[self.base.jaxnodes, self.base.jaxedges],
[self.base.nodes, self.base.edges],
[self.base.channels, self.base.synapses],
):
for mech in mechs:
inds = jax_arrays[mech._name]
for mech_param in mech.params:
params[mech_param] = data[mech_param].to_numpy()
params[mech_param][inds] = jax_arrays[mech_param]
params[mech_param] = jnp.asarray(params[mech_param])

# Override with those parameters set by `.make_trainable()`.
for parameter in pstate:
Expand Down

0 comments on commit e3e2000

Please sign in to comment.