From 1e76daf8cc21a25f158c3037d97f6e2d91ca6f60 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Mon, 16 Dec 2024 17:47:56 +0100 Subject: [PATCH] fix: all tests finally passing --- jaxley/modules/base.py | 149 ++++++++++++++++++++++++----------------- 1 file changed, 87 insertions(+), 62 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 1936bc29..e793fc21 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from warnings import warn +import jax import jax.numpy as jnp import numpy as np import pandas as pd @@ -743,11 +744,10 @@ def to_jax(self): nodes = self.nodes.to_dict(orient="list") edges = self.edges.to_dict(orient="list") - for key, inds in self._inds_of_state_param.items(): + for key, inds in self._iter_states_params(states=True, params=True): data = nodes if key in self.nodes.columns else edges jax_arrays = jaxnodes if key in self.nodes.columns else jaxedges - inds = self._inds_of_state_param[key] values = jnp.asarray(data[key])[inds] jax_arrays.update({key: values}) @@ -1136,13 +1136,13 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): ) # Loop only over the keys in `pstate` to avoid unnecessary computation. - for parameter in pstate: - key = parameter["key"] - mech_inds = self._inds_of_state_param[key] + for p in pstate: + key, inds = p["key"], p["indices"] + inds = np.array(inds.reshape(-1)) data = ( self.base.nodes if key in self.base.nodes.columns else self.base.edges ) - data.loc[mech_inds, key] = all_params_states[key] + data.loc[inds, key] = all_params_states[key][inds] def distance(self, endpoint: "View") -> float: """Return the direct distance between two compartments. @@ -1214,49 +1214,75 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: return self.trainable_params def _iter_states_params( - self, params=False, states=False + self, params=False, states=False, currents=False ) -> Tuple[str, jnp.ndarray]: - # TODO FROM #447: MAKE THIS WORK FOR VIEW? - # assert that either params or states is True - assert params or states, "Either params or states must be True." + assert params or states or currents, "Select either params / states / currents." + global_params = ["radius", "length", "axial_resistivity", "capacitance"] global_states = ["v"] - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + + current_names = self.membrane_current_names + self.synapse_current_names + channel_currents = [c.current_name for c in self.channels] all_mechs = self.channels + self.synapses all_states = sum([list(m.states) for m in all_mechs], []) + global_states - all_params = sum([list(m.params) for m in all_mechs], []) + morph_params - all_states_params = all_states if states else [] - all_states_params += all_params if params else [] + all_params = sum([list(m.params) for m in all_mechs], []) + global_params - # Join node and edge states into a single state dictionary. - for key in all_states_params: - jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges - yield key, jax_arrays[key], self._inds_of_state_param[key] + if params: + for key in all_states: + yield key, self._inds_of_state_param(key) + + if states: + for key in all_params: + yield key, self._inds_of_state_param(key) + + if currents: + for key in current_names + channel_currents: + yield key, self._inds_of_state_param(key) def _prepare_for_jax(self): + # prepare lookup of indices of states, parameters and mechanisms global_params = ["radius", "length", "axial_resistivity", "capacitance"] global_states = ["v"] - def inds_of_key(key): - data = self.nodes if key in self.nodes.columns else self.edges - return data.index[~data[key].isna()].to_numpy() + current_names = self.membrane_current_names + self.synapse_current_names + global_states_params = global_states + global_params + current_names + node_attrs = self.nodes.columns.to_list() + current_names + channel_names - self._inds_of_state_param = { - k: inds_of_key(k) for k in global_states + global_params - } + channel_names = [c._name for c in self.channels] + syn_names = [s._name for s in self.synapses] + + def inds_of_key(key: str) -> np.ndarray: + """Return the indices for params, states, mechanisms and currents.""" + data = self.nodes if key in node_attrs else pd.DataFrame() + data = self.edges if key in self.edges.columns or key in syn_names else data + + if key in channel_names + syn_names: + where = data["type"] == key if key in syn_names else data[key] + elif key in data.columns: + where = ~data[key].isna() + elif key in global_states_params: + where = pd.Index([True] * len(data)) + else: + raise ValueError(f"Key '{key}' not found in nodes or edges") + return data.index[where].to_numpy() + + # expose the lookup function to the class with precomputed attrs in scope + self._inds_of_state_param = inds_of_key + # add index attrs to mechansisms (i.e. where was it inserted) and also keep track + # of states / parameters that are also shared by other mechanisms. for mech in self.channels + self.synapses: - is_channel = isinstance(mech, Channel) - data = self.nodes if is_channel else self.edges - cond = data[mech._name] if is_channel else data["type"] == mech._name - inds = data.index[cond].to_numpy() - mech.indices = jnp.asarray(inds) + mech.indices = self._inds_of_state_param(mech._name) + mech._jax_inds = {} + currents = {mech.current_name: None} if isinstance(mech, Channel) else {} - for key in list(mech.params) + list(mech.states): - is_global = mech._name not in key - param_state_inds = inds_of_key(key) if is_global else inds - self._inds_of_state_param[key] = jnp.asarray(param_state_inds) + for param_state in {**mech.params, **mech.states, **currents}: + is_global = not param_state.startswith(f"{mech._name}_") + if is_global: + global_inds = self._inds_of_state_param(param_state) + local_inds = np.where(np.isin(global_inds, mech.indices))[0] + mech._jax_inds[param_state] = local_inds def _get_all_states_params( self, @@ -1268,20 +1294,22 @@ def _get_all_states_params( states=False, ) -> Dict[str, jnp.ndarray]: states_params = {} - for key, jax_arrays, _ in self._iter_states_params(params, states): - states_params[key] = jax_arrays + pkeys = {} + for i, p in enumerate(pstate): + pkeys[p["key"]] = pkeys[p["key"]] + [i] if p["key"] in pkeys else [i] - # Override with those parameters set by `.make_trainable()`. - for p in pstate: - key, inds, set_param = p["key"], p["indices"], p["val"] - - if key in states_params: + for key, param_state_inds in self._iter_states_params(params, states): + jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges + states_params[key] = jax_arrays[key] + # Override with those parameters set by `.make_trainable()`. + for i in pkeys.get(key, []): + p = pstate[i] + key, inds, set_param = p["key"], p["indices"], p["val"] # `inds` is of shape `(num_params, num_comps_per_param)`. # `set_param` is of shape `(num_params,)` # We need to unsqueeze `set_param` to make it `(num_params, 1)` # for the `.set()` to work. This is done with `[:, None]`. - mech_inds = self._inds_of_state_param[key] - inds = jnp.searchsorted(mech_inds, inds) + inds = jnp.searchsorted(param_state_inds, inds) states_params[key] = states_params[key].at[inds].set(set_param[:, None]) if params: @@ -1384,8 +1412,9 @@ def init_states(self, delta_t: float = 0.025): # Update states of the channels. self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. states = {} - for key, jax_arrays, _ in self._iter_states_params(states=True): - states[key] = jax_arrays + for key, _ in self._iter_states_params(states=True): + jax_arrays = self.jaxnodes if key in self.nodes.columns else self.jaxedges + states[key] = jax_arrays[key] # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. @@ -1395,8 +1424,8 @@ def init_states(self, delta_t: float = 0.025): voltages = self.nodes["v"].to_numpy() for channel in self.channels: - params = self._filter_global_params_states(params, channel) - states = self._filter_global_params_states(states, channel) + states = self._filter_params_states(states, channel._jax_inds) + params = self._filter_params_states(params, channel._jax_inds) init_state = channel.init_state( states, voltages[channel.indices], params, delta_t @@ -1743,7 +1772,7 @@ def delete_channel(self, channel: Channel): channel_names = [c._name for c in self.channels] all_channel_names = [c._name for c in self.base.channels] if name in channel_names: - channel_cols = list(channel.params) + list(channel.states) + channel_cols = list({**channel.params, **channel.states}.keys()) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") self.base.nodes.loc[self._nodes_in_view, name] = False @@ -1755,6 +1784,12 @@ def delete_channel(self, channel: Channel): else: raise ValueError(f"Channel {name} not found in the module.") + def _filter_params_states(self, pytree, filter_dct): + for key, inds in filter_dct.items(): + if key in pytree: + pytree[key] = pytree[key][inds] + return pytree + @only_allow_module def step( self, @@ -1787,7 +1822,6 @@ def step( Returns: The updated state of the module. """ - # Extract the voltages voltages = u["v"] @@ -1925,6 +1959,8 @@ def _step_channels_state( for channel in channels: # States updates. + states = self._filter_params_states(states, channel._jax_inds) + params = self._filter_params_states(params, channel._jax_inds) channel_states_updated = channel.update_states( states, delta_t, voltages[channel.indices], params ) @@ -1935,17 +1971,6 @@ def _step_channels_state( return states - def _filter_global_params_states(self, dct, mech): - mech_state_params = list(mech.params) + list(mech.states) - is_global = lambda key: f"{mech._name}_" not in key and key in dct - global_params_states = [key for key in mech_state_params if is_global(key)] - for key in global_params_states: - param_inds = self._inds_of_state_param[key] - param_where_channel = jnp.searchsorted(param_inds, mech.indices) - dct[key] = dct[key][param_where_channel] - - return dct - def _channel_currents( self, states: Dict[str, jnp.ndarray], @@ -1959,7 +1984,7 @@ def _channel_currents( This is also updates `state` because the `state` also contains the current. """ voltages = states["v"] - morph_params = ["radius", "length", "axial_resistivity"] + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] morph_params = {pkey: params[pkey] for pkey in morph_params} # Compute current through channels. @@ -1974,8 +1999,8 @@ def _channel_currents( v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) - params = self._filter_global_params_states(params, channel) - states = self._filter_global_params_states(states, channel) + states = self._filter_params_states(states, channel._jax_inds) + params = self._filter_params_states(params, channel._jax_inds) membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))( states, v_and_perturbed, params