Skip to content

Commit

Permalink
fix: all tests finally passing
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 16, 2024
1 parent b652e4a commit 1e76daf
Showing 1 changed file with 87 additions and 62 deletions.
149 changes: 87 additions & 62 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -743,11 +744,10 @@ def to_jax(self):
nodes = self.nodes.to_dict(orient="list")
edges = self.edges.to_dict(orient="list")

for key, inds in self._inds_of_state_param.items():
for key, inds in self._iter_states_params(states=True, params=True):
data = nodes if key in self.nodes.columns else edges
jax_arrays = jaxnodes if key in self.nodes.columns else jaxedges

inds = self._inds_of_state_param[key]
values = jnp.asarray(data[key])[inds]
jax_arrays.update({key: values})

Expand Down Expand Up @@ -1136,13 +1136,13 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):
)

# Loop only over the keys in `pstate` to avoid unnecessary computation.
for parameter in pstate:
key = parameter["key"]
mech_inds = self._inds_of_state_param[key]
for p in pstate:
key, inds = p["key"], p["indices"]
inds = np.array(inds.reshape(-1))
data = (
self.base.nodes if key in self.base.nodes.columns else self.base.edges
)
data.loc[mech_inds, key] = all_params_states[key]
data.loc[inds, key] = all_params_states[key][inds]

def distance(self, endpoint: "View") -> float:
"""Return the direct distance between two compartments.
Expand Down Expand Up @@ -1214,49 +1214,75 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
return self.trainable_params

def _iter_states_params(
self, params=False, states=False
self, params=False, states=False, currents=False
) -> Tuple[str, jnp.ndarray]:
# TODO FROM #447: MAKE THIS WORK FOR VIEW?

# assert that either params or states is True
assert params or states, "Either params or states must be True."
assert params or states or currents, "Select either params / states / currents."
global_params = ["radius", "length", "axial_resistivity", "capacitance"]
global_states = ["v"]
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]

current_names = self.membrane_current_names + self.synapse_current_names
channel_currents = [c.current_name for c in self.channels]

all_mechs = self.channels + self.synapses
all_states = sum([list(m.states) for m in all_mechs], []) + global_states
all_params = sum([list(m.params) for m in all_mechs], []) + morph_params
all_states_params = all_states if states else []
all_states_params += all_params if params else []
all_params = sum([list(m.params) for m in all_mechs], []) + global_params

# Join node and edge states into a single state dictionary.
for key in all_states_params:
jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges
yield key, jax_arrays[key], self._inds_of_state_param[key]
if params:
for key in all_states:
yield key, self._inds_of_state_param(key)

if states:
for key in all_params:
yield key, self._inds_of_state_param(key)

if currents:
for key in current_names + channel_currents:
yield key, self._inds_of_state_param(key)

def _prepare_for_jax(self):
# prepare lookup of indices of states, parameters and mechanisms
global_params = ["radius", "length", "axial_resistivity", "capacitance"]
global_states = ["v"]

def inds_of_key(key):
data = self.nodes if key in self.nodes.columns else self.edges
return data.index[~data[key].isna()].to_numpy()
current_names = self.membrane_current_names + self.synapse_current_names
global_states_params = global_states + global_params + current_names
node_attrs = self.nodes.columns.to_list() + current_names + channel_names

self._inds_of_state_param = {
k: inds_of_key(k) for k in global_states + global_params
}
channel_names = [c._name for c in self.channels]
syn_names = [s._name for s in self.synapses]

def inds_of_key(key: str) -> np.ndarray:
"""Return the indices for params, states, mechanisms and currents."""
data = self.nodes if key in node_attrs else pd.DataFrame()
data = self.edges if key in self.edges.columns or key in syn_names else data

if key in channel_names + syn_names:
where = data["type"] == key if key in syn_names else data[key]
elif key in data.columns:
where = ~data[key].isna()
elif key in global_states_params:
where = pd.Index([True] * len(data))
else:
raise ValueError(f"Key '{key}' not found in nodes or edges")
return data.index[where].to_numpy()

# expose the lookup function to the class with precomputed attrs in scope
self._inds_of_state_param = inds_of_key

# add index attrs to mechansisms (i.e. where was it inserted) and also keep track
# of states / parameters that are also shared by other mechanisms.
for mech in self.channels + self.synapses:
is_channel = isinstance(mech, Channel)
data = self.nodes if is_channel else self.edges
cond = data[mech._name] if is_channel else data["type"] == mech._name
inds = data.index[cond].to_numpy()
mech.indices = jnp.asarray(inds)
mech.indices = self._inds_of_state_param(mech._name)
mech._jax_inds = {}
currents = {mech.current_name: None} if isinstance(mech, Channel) else {}

for key in list(mech.params) + list(mech.states):
is_global = mech._name not in key
param_state_inds = inds_of_key(key) if is_global else inds
self._inds_of_state_param[key] = jnp.asarray(param_state_inds)
for param_state in {**mech.params, **mech.states, **currents}:
is_global = not param_state.startswith(f"{mech._name}_")
if is_global:
global_inds = self._inds_of_state_param(param_state)
local_inds = np.where(np.isin(global_inds, mech.indices))[0]
mech._jax_inds[param_state] = local_inds

def _get_all_states_params(
self,
Expand All @@ -1268,20 +1294,22 @@ def _get_all_states_params(
states=False,
) -> Dict[str, jnp.ndarray]:
states_params = {}
for key, jax_arrays, _ in self._iter_states_params(params, states):
states_params[key] = jax_arrays
pkeys = {}
for i, p in enumerate(pstate):
pkeys[p["key"]] = pkeys[p["key"]] + [i] if p["key"] in pkeys else [i]

# Override with those parameters set by `.make_trainable()`.
for p in pstate:
key, inds, set_param = p["key"], p["indices"], p["val"]

if key in states_params:
for key, param_state_inds in self._iter_states_params(params, states):
jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges
states_params[key] = jax_arrays[key]
# Override with those parameters set by `.make_trainable()`.
for i in pkeys.get(key, []):
p = pstate[i]
key, inds, set_param = p["key"], p["indices"], p["val"]
# `inds` is of shape `(num_params, num_comps_per_param)`.
# `set_param` is of shape `(num_params,)`
# We need to unsqueeze `set_param` to make it `(num_params, 1)`
# for the `.set()` to work. This is done with `[:, None]`.
mech_inds = self._inds_of_state_param[key]
inds = jnp.searchsorted(mech_inds, inds)
inds = jnp.searchsorted(param_state_inds, inds)
states_params[key] = states_params[key].at[inds].set(set_param[:, None])

if params:
Expand Down Expand Up @@ -1384,8 +1412,9 @@ def init_states(self, delta_t: float = 0.025):
# Update states of the channels.
self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.
states = {}
for key, jax_arrays, _ in self._iter_states_params(states=True):
states[key] = jax_arrays
for key, _ in self._iter_states_params(states=True):
jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges
states[key] = jax_arrays[key]

# We do not use any `pstate` for initializing. In principle, we could change
# that by allowing an input `params` and `pstate` to this function.
Expand All @@ -1395,8 +1424,8 @@ def init_states(self, delta_t: float = 0.025):
voltages = self.nodes["v"].to_numpy()

for channel in self.channels:
params = self._filter_global_params_states(params, channel)
states = self._filter_global_params_states(states, channel)
states = self._filter_params_states(states, channel._jax_inds)
params = self._filter_params_states(params, channel._jax_inds)

init_state = channel.init_state(
states, voltages[channel.indices], params, delta_t
Expand Down Expand Up @@ -1743,7 +1772,7 @@ def delete_channel(self, channel: Channel):
channel_names = [c._name for c in self.channels]
all_channel_names = [c._name for c in self.base.channels]
if name in channel_names:
channel_cols = list(channel.params) + list(channel.states)
channel_cols = list({**channel.params, **channel.states}.keys())
self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan")
self.base.nodes.loc[self._nodes_in_view, name] = False

Expand All @@ -1755,6 +1784,12 @@ def delete_channel(self, channel: Channel):
else:
raise ValueError(f"Channel {name} not found in the module.")

def _filter_params_states(self, pytree, filter_dct):
for key, inds in filter_dct.items():
if key in pytree:
pytree[key] = pytree[key][inds]
return pytree

@only_allow_module
def step(
self,
Expand Down Expand Up @@ -1787,7 +1822,6 @@ def step(
Returns:
The updated state of the module.
"""

# Extract the voltages
voltages = u["v"]

Expand Down Expand Up @@ -1925,6 +1959,8 @@ def _step_channels_state(

for channel in channels:
# States updates.
states = self._filter_params_states(states, channel._jax_inds)
params = self._filter_params_states(params, channel._jax_inds)
channel_states_updated = channel.update_states(
states, delta_t, voltages[channel.indices], params
)
Expand All @@ -1935,17 +1971,6 @@ def _step_channels_state(

return states

def _filter_global_params_states(self, dct, mech):
mech_state_params = list(mech.params) + list(mech.states)
is_global = lambda key: f"{mech._name}_" not in key and key in dct
global_params_states = [key for key in mech_state_params if is_global(key)]
for key in global_params_states:
param_inds = self._inds_of_state_param[key]
param_where_channel = jnp.searchsorted(param_inds, mech.indices)
dct[key] = dct[key][param_where_channel]

return dct

def _channel_currents(
self,
states: Dict[str, jnp.ndarray],
Expand All @@ -1959,7 +1984,7 @@ def _channel_currents(
This is also updates `state` because the `state` also contains the current.
"""
voltages = states["v"]
morph_params = ["radius", "length", "axial_resistivity"]
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
morph_params = {pkey: params[pkey] for pkey in morph_params}

# Compute current through channels.
Expand All @@ -1974,8 +1999,8 @@ def _channel_currents(
v_channel = voltages[channel_inds]
v_and_perturbed = jnp.array([v_channel, v_channel + diff])

params = self._filter_global_params_states(params, channel)
states = self._filter_global_params_states(states, channel)
states = self._filter_params_states(states, channel._jax_inds)
params = self._filter_params_states(params, channel._jax_inds)

membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))(
states, v_and_perturbed, params
Expand Down

0 comments on commit 1e76daf

Please sign in to comment.