diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c01982c8..af503454 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -32,7 +32,6 @@ compute_levels, interpolate_xyzr, params_to_pstate, - query_states_and_params, v_interp, ) from jaxley.utils.debug_solver import compute_morphology_indices @@ -733,35 +732,30 @@ def to_jax(self): they can be processed on GPU/TPU and such that the simulation can be differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`. """ - jaxnodes, jaxedges = {}, {} - - for jax_arrays, data, mechs in zip( - [jaxnodes, jaxedges], - [self.base.nodes, self.base.edges], - [self.base.channels, self.base.synapses], - ): - for mech in mechs: - if isinstance(mech, Channel): - inds = data.index[data[mech._name]] - else: - inds = data.index[data["type"] == mech._name] - states_params = list(mech.params) + list(mech.states) - params = data[states_params].loc[inds] - # jax_arrays.update({mech._name: inds}) - jax_arrays.update(params.to_dict(orient="list")) - - morph_params = ["radius", "length", "axial_resistivity", "capacitance"] - jaxnodes.update(self.nodes[["v"] + morph_params].to_dict(orient="list")) - self.base.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} - self.base.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} - # the parameters and states in the jaxnodes are stored on a per-mechanism basis, # i.e. if only compartment #2 has a HH channels, then the jaxnodes will be # {HH_gNa: [0.1], ...} vs. {gbar_HH: [NaN, NaN, 0.1, ...], ...}. This means # a seperate lookup table is needed to figure out which parameters and states # belong to which mechanism and at which global indices the mechanism lives. + self._update_mech_lookup_table() + syn_param_states = sum( + [list(s.params) + list(s.states) for s in self.synapses], [] + ) - self.base._update_mech_lookup_table() + jaxnodes, jaxedges = {}, {} + + for state_param in self._mech_lookup: + mech_inds = self._get_state_param_inds(state_param) + data = self.edges if state_param in syn_param_states else self.nodes + jax_arrays = jaxedges if state_param in syn_param_states else jaxnodes + + values = data.loc[mech_inds, state_param].to_numpy() + jax_arrays.update({state_param: values}) + + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + jaxnodes.update(self.nodes[["v"] + morph_params].to_dict(orient="list")) + self.jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()} + self.jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()} def show( self, @@ -1136,12 +1130,12 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): # However, I think it allows us to reuse as much code as possible and it avoids # any kind of issues with indexing or parameter sharing (as this is fully # taken care of by `get_all_parameters()`). - self.base.to_jax() + self.to_jax() # The value for `delta_t` does not matter here because it is only used to # compute the initial current. However, the initial current cannot be made # trainable and so its value never gets used below. - pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables) - all_params_states = self.base._get_all_states_params( + pstate = params_to_pstate(trainable_params, self.indices_set_by_trainables) + all_params_states = self._get_all_states_params( pstate, delta_t=0.025, voltage_solver="jaxley.stone", @@ -1152,10 +1146,9 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): # Loop only over the keys in `pstate` to avoid unnecessary computation. for parameter in pstate: key = parameter["key"] - mech, mech_inds = self.base._get_mech_inds_of_param_state(key) - data = ( - self.base.nodes if key in self.base.nodes.columns else self.base.edges - ) + mech_inds = self._get_state_param_inds(key) + mech_inds = np.concatenate(list(mech_inds.values())) + data = self.nodes if key in self.nodes.columns else self.edges data.loc[mech_inds, key] = all_params_states[key] def distance(self, endpoint: "View") -> float: @@ -1257,33 +1250,64 @@ def _iter_states_params( yield key, jax_arrays[key] def _update_mech_lookup_table(self): - chan_items = [(list(c.params) + list(c.states), c._name) for c in self.channels] - syn_items = [(list(s.params) + list(s.states), s._name) for s in self.synapses] - - chan_inds = [ - {c._name + "_index": self.nodes.index[self.nodes[c._name]].to_numpy()} - for c in self.channels - ] - syn_inds = [ - { - s._name - + "_index": self.edges.index[self.edges["type"] == s._name].to_numpy() + state_param_lookup = {} + mech2inds = {} + + for mech in self.channels + self.synapses: + is_channel = isinstance(mech, Channel) + data = self.nodes if is_channel else self.edges + cond = data[mech._name] if is_channel else data["type"] == mech._name + + chan_inds = {mech._name: data.index[cond].to_numpy()} + mech2inds.update(chan_inds) + for item in list(mech.params) + list(mech.states): + inds = chan_inds[mech._name] + mech_info = {mech._name: {"global_index": inds}} + if item not in state_param_lookup: + state_param_lookup[item] = mech_info + else: + state_param_lookup[item].update(mech_info) + + for state_param, info in state_param_lookup.items(): + global_inds = [v["global_index"] for v in info.values()] + comb_inds = np.unique(np.concatenate(global_inds)) + local_inds = { + k: jnp.searchsorted(comb_inds, v["global_index"]) + for k, v in info.items() } - for s in self.synapses - ] - - mech_items = [dict(zip(k, [v] * len(k))) for (k, v) in chan_items + syn_items] - mech_items += chan_inds + syn_inds + new_info = {} + for v1, (k, v2) in zip(global_inds, local_inds.items()): + new_info.update({k: {"global_index": v1, "local_index": v2}}) + state_param_lookup[state_param].update(new_info) + + self._mech_lookup = state_param_lookup + self._mech_inds = mech2inds + + def _get_state_param_inds(self, key: str) -> Tuple[str, jnp.ndarray]: + inds = None + if key in self._mech_lookup: + mechs = self._mech_lookup[key] + inds = np.unique(np.concatenate([self._mech_inds[m] for m in mechs])) + elif key in self.nodes.columns: + inds = {key: self._nodes_in_view} + elif key in self.edges.columns: + inds = {key: self._edges_in_view} + else: + raise KeyError(f"Key '{key}' not found in nodes or edges") - self._mech_lookup_table = {k: v for d in mech_items for k, v in d.items()} + return inds - def _get_mech_inds_of_param_state(self, key: str) -> Tuple[str, jnp.ndarray]: - if key in self._mech_lookup_table: - mech = self._mech_lookup_table[key] - inds = self._mech_lookup_table[mech + "_index"] - return mech, inds + def _filter_states_params( + self, states_params: Dict[str, jnp.ndarray], mech: Union[Channel, Synapse] + ) -> Dict[str, jnp.ndarray]: + pkeys = list(mech.params) + mech_states_params = pkeys if pkeys[0] in states_params else list(mech.states) - return None, self._nodes_in_view + filtered_states_params = {} + for key in mech_states_params: + mech_inds = self._mech_lookup[key][mech._name]["local_index"] + filtered_states_params[key] = states_params[key][mech_inds] + return filtered_states_params @only_allow_module def _get_all_states_params( @@ -1296,39 +1320,38 @@ def _get_all_states_params( states=False, ) -> Dict[str, jnp.ndarray]: states_params = {} - for key, jax_arrays in self.base._iter_states_params(params, states): + for key, jax_arrays in self._iter_states_params(params, states): states_params[key] = jax_arrays # Override with those parameters set by `.make_trainable()`. - for parameter in pstate: - key = parameter["key"] - inds = parameter["indices"] - set_param = parameter["val"] + for p in pstate: + key, inds, set_param = p["key"], p["indices"], p["val"] if key in states_params: - mech, mech_inds = self.base._get_mech_inds_of_param_state(key) # `inds` is of shape `(num_params, num_comps_per_param)`. # `set_param` is of shape `(num_params,)` # We need to unsqueeze `set_param` to make it `(num_params, 1)` # for the `.set()` to work. This is done with `[:, None]`. - inds = np.searchsorted(mech_inds, inds) + mech_inds = self._get_state_param_inds(key) + inds = jnp.searchsorted(mech_inds, inds) states_params[key] = states_params[key].at[inds].set(set_param[:, None]) if params: # Compute conductance params and add them to the params dictionary. - states_params["axial_conductances"] = self.base._compute_axial_conductances( + states_params["axial_conductances"] = self._compute_axial_conductances( params=states_params ) + if states: all_params = states_params if all_params is None and params else all_params # Add to the states the initial current through every channel. - states, _ = self.base._channel_currents( - states_params, delta_t, self.base.channels, self.base.nodes, all_params + states, _ = self._channel_currents( + states_params, delta_t, self.channels, self.nodes, all_params ) # Add to the states the initial current through every synapse. - states, _ = self.base._synapse_currents( - states_params, self.base.synapses, all_params, delta_t, self.base.edges + states, _ = self._synapse_currents( + states_params, self.synapses, all_params, delta_t, self.edges ) return states_params @@ -1414,30 +1437,25 @@ def init_states(self, delta_t: float = 0.025): delta_t: Passed on to `channel.init_state()`. """ # Update states of the channels. - self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. - channel_nodes = self.base.nodes + self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. states = {} - for key, jax_arrays in self.base._iter_states_params(states=True): + for key, jax_arrays in self._iter_states_params(states=True): states[key] = jax_arrays # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. # `voltage_solver` could also be `jax.sparse` here, because both of them # build the channel parameters in the same way. - params = self.base.get_all_parameters([], voltage_solver="jaxley.thomas") + params = self.get_all_parameters([], voltage_solver="jaxley.thomas") + voltages = self.nodes["v"].to_numpy() - for channel in self.base.channels: - name = channel._name - channel_indices = self._mech_lookup_table[name + "_index"] - voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() - - channel_param_names = list(channel.params.keys()) - channel_state_names = list(channel.states.keys()) - channel_states = query_states_and_params(states, channel_state_names) - channel_params = query_states_and_params(params, channel_param_names) + for channel in self.channels: + channel_states = self._filter_states_params(states, channel) + channel_params = self._filter_states_params(params, channel) + channel_inds = self._mech_inds[channel._name] init_state = channel.init_state( - channel_states, voltages, channel_params, delta_t + channel_states, voltages[channel_inds], channel_params, delta_t ) # `init_state` might not return all channel states. Only the ones that are @@ -1446,7 +1464,7 @@ def init_states(self, delta_t: float = 0.025): # Note that we are overriding `self.nodes` here, but `self.nodes` is # not used above to actually compute the current states (so there are # no issues with overriding states). - self.nodes.loc[channel_indices, key] = val + self.nodes.loc[channel_inds, key] = val def _init_morph_for_debugging(self): """Instandiates row and column inds which can be used to solve the voltage eqs. @@ -1781,8 +1799,7 @@ def delete_channel(self, channel: Channel): channel_names = [c._name for c in self.channels] all_channel_names = [c._name for c in self.base.channels] if name in channel_names: - channel_cols = list(channel.params.keys()) - channel_cols += list(channel.states.keys()) + channel_cols = list(channel.params) + list(channel.states) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") self.base.nodes.loc[self._nodes_in_view, name] = False @@ -1962,22 +1979,26 @@ def _step_channels_state( """One integration step of the channels.""" voltages = states["v"] morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + morph_params = {pkey: params[pkey] for pkey in morph_params} + current_states = {name: states[name] for name in self.membrane_current_names} for channel in channels: - channel_param_names = list(channel.params) + morph_params - channel_params = query_states_and_params(params, channel_param_names) - channel_state_names = list(channel.states) + self.membrane_current_names - channel_states = query_states_and_params(states, channel_state_names) + channel_params = self._filter_states_params(params, channel) + channel_states = self._filter_states_params(states, channel) + + channel_params.update(morph_params) + channel_states.update(current_states) # States updates. - channel_inds = self._mech_lookup_table[channel._name + "_index"] - states_updated = channel.update_states( + channel_inds = self._mech_inds[channel._name] + channel_states_updated = channel.update_states( channel_states, delta_t, voltages[channel_inds], channel_params ) # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. - for key, val in states_updated.items(): - states[key] = states[key].at[:].set(val) + for key, val in channel_states_updated.items(): + channel_inds = self._mech_lookup[key][channel._name]["local_index"] + states[key] = states[key].at[channel_inds].set(val) return states @@ -1995,21 +2016,22 @@ def _channel_currents( """ voltages = states["v"] morph_params = ["radius", "length", "axial_resistivity"] + morph_params = {pkey: params[pkey] for pkey in morph_params} # Compute current through channels. zeros = jnp.zeros_like(voltages) - voltage_terms, constant_terms = zeros, zeros + voltage_terms, const_terms = zeros, zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 - current_states = {name: zeros for name in self.membrane_current_names} for channel in channels: - channel_param_names = list(channel.params) + morph_params - channel_params = query_states_and_params(params, channel_param_names) - channel_states = query_states_and_params(states, channel.states) + channel_params = self._filter_states_params(params, channel) + channel_states = self._filter_states_params(states, channel) + + channel_params.update(morph_params) - channel_inds = self._mech_lookup_table[channel._name + "_index"] + channel_inds = self._mech_inds[channel._name] v_channel = voltages[channel_inds] v_and_perturbed = jnp.array([v_channel, v_channel + diff]) @@ -2019,13 +2041,11 @@ def _channel_currents( # Split into voltage and constant terms. voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff - constant_term = membrane_currents[0] - voltage_term * v_channel + const_term = membrane_currents[0] - voltage_term * v_channel # * 1000 to convert from mA/cm^2 to uA/cm^2. voltage_terms = voltage_terms.at[channel_inds].add(voltage_term * 1000.0) - constant_terms = constant_terms.at[channel_inds].add( - -constant_term * 1000.0 - ) + const_terms = const_terms.at[channel_inds].add(-const_term * 1000.0) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. @@ -2039,7 +2059,7 @@ def _channel_currents( # recorded and used by `Channel.update_states()`. for name in self.membrane_current_names: states[name] = current_states[name] - return states, (voltage_terms, constant_terms) + return states, (voltage_terms, const_terms) def _step_synapse( self, @@ -2539,35 +2559,32 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]): jaxnodes = {} if self.base.jaxnodes else None jaxedges = {} if self.base.jaxedges else None - # None check is needed for View -> see `View._jax_arrays_in_view` - chan_mechs = [m._name for m in self.channels] - syn_mechs = [m._name for m in self.synapses if m is not None] - mechs = chan_mechs + syn_mechs - if self.base._mech_lookup_table: - self._mech_lookup_table = {} - for k, v in self.base._mech_lookup_table.items(): - if "index" in k: - mech = k.replace("_index", "") - if mech in chan_mechs: - v = v[a_intersects_b_at(v, self._nodes_in_view)] - elif mech in syn_mechs: - v = v[a_intersects_b_at(v, self._edges_in_view)] - self._mech_lookup_table[k] = v - elif v in mechs + ["v"] + morph_params: - self._mech_lookup_table[k] = v - - for jax_arrays, base_jax_arrays, viewed_inds in zip( - [jaxnodes, jaxedges], - [self.base.jaxnodes, self.base.jaxedges], - [self._nodes_in_view, self._edges_in_view], - ): - if base_jax_arrays is not None and len(viewed_inds) > 0: - for key, values in base_jax_arrays.items(): - mech, mech_inds = self.base._get_mech_inds_of_param_state(key) - if mech is None or mech in mechs: - jax_arrays[key] = values.at[ - a_intersects_b_at(mech_inds, viewed_inds) - ] + # # None check is needed for View -> see `View._jax_arrays_in_view` + # chan_mechs = [m._name for m in self.channels] + # syn_mechs = [m._name for m in self.synapses if m is not None] + # mechs = chan_mechs + syn_mechs + # if self.base._mech_lookup: + # self._mech_lookup = {} + # self._mech_inds = {} + # for k, v in self.base._mech_inds.items(): + # inds = self._nodes_in_view if k in chan_mechs else self._edges_in_view + # self._mech_inds[k] = v[a_intersects_b_at(v, inds)] + # for k, v in self.base._mech_lookup.items(): + # self._mech_lookup[k] = {k:v_i for k,v_i in v.items() if k in mechs} + + # for jax_arrays, base_jax_arrays, viewed_inds in zip( + # [jaxnodes, jaxedges], + # [self.base.jaxnodes, self.base.jaxedges], + # [self._nodes_in_view, self._edges_in_view], + # ): + # if base_jax_arrays is not None and len(viewed_inds) > 0: + # for key, values in base_jax_arrays.items(): + # mech_inds = self.base._get_state_param_inds(key) + # for mech, inds in mech_inds.items(): + # if mech is None or mech in mechs: + # jax_arrays[key] = values.at[ + # a_intersects_b_at(inds, viewed_inds) + # ] return jaxnodes, jaxedges diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 3f1f7339..a86cf8a0 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -22,7 +22,6 @@ dtype_aware_concat, loc_of_index, merge_cells, - query_states_and_params, ) from jaxley.utils.misc_utils import concat_and_ignore_empty, cumsum_leading_zero from jaxley.utils.solver_utils import ( @@ -267,16 +266,16 @@ def _step_synapse_state( ) -> Dict: voltages = states["v"] - for i, group in edges.groupby("type_ind"): - synapse = syn_channels[i] - pre_inds = group["pre_global_comp_index"].to_numpy() - post_inds = group["post_global_comp_index"].to_numpy() + for synapse in syn_channels: + inds = self._mech_inds[synapse._name] + pre_inds = edges.loc[inds, "pre_global_comp_index"].to_numpy() + post_inds = edges.loc[inds, "post_global_comp_index"].to_numpy() - synapse_params = query_states_and_params(params, synapse.params) - synapse_states = query_states_and_params(states, synapse.states) + synapse_params = self._filter_states_params(params, synapse) + synapse_states = self._filter_states_params(states, synapse) # State updates. - states_updated = synapse.update_states( + synapse_states_updated = synapse.update_states( synapse_states, delta_t, voltages[pre_inds], @@ -286,8 +285,9 @@ def _step_synapse_state( # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. - for key, val in states_updated.items(): - states[key] = states[key].at[:].set(val) + for key, val in synapse_states_updated.items(): + synapse_inds = self._mech_lookup[key][synapse._name]["local_inds"] + states[key] = states[key].at[synapse_inds].set(val) return states @@ -303,7 +303,7 @@ def _synapse_currents( # Compute current through synapses. zeros = jnp.zeros_like(voltages) - syn_voltage_terms, syn_constant_terms = zeros, zeros + syn_voltage_terms, syn_const_terms = zeros, zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 @@ -315,8 +315,8 @@ def _synapse_currents( pre_inds = group["pre_global_comp_index"].to_numpy() post_inds = group["post_global_comp_index"].to_numpy() - synapse_params = query_states_and_params(params, synapse.params) - synapse_states = query_states_and_params(states, synapse.states) + synapse_params = self._filter_states_params(params, synapse) + synapse_states = self._filter_states_params(states, synapse) v_pre, v_post = voltages[pre_inds], voltages[post_inds] pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) @@ -346,7 +346,7 @@ def _synapse_currents( gathered_syn_currents = gather_synapes(num_comp, post_inds, *syn_voltages) syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0]) - syn_constant_terms = syn_constant_terms.at[:].add(-gathered_syn_currents[1]) + syn_const_terms = syn_const_terms.at[:].add(-gathered_syn_currents[1]) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. synapse_current_states[f"i_{synapse._name}"] = ( @@ -359,7 +359,7 @@ def _synapse_currents( # recorded and used by `Channel.update_states()`. for name in [s._name for s in self.synapses]: states[f"i_{name}"] = synapse_current_states[f"i_{name}"] - return states, (syn_voltage_terms, syn_constant_terms) + return states, (syn_voltage_terms, syn_const_terms) def arrange_in_layers( self, diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index a1e35c81..27e337ab 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -686,21 +686,6 @@ def group_and_sum( return group_sums -def query_states_and_params(d, keys, idcs=None): - """Get dict with subset of keys and values from d. - - This is used to restrict a dict where every item contains __all__ states to only - the ones that are relevant for the channel. E.g. - - ```states = {'eCa': Array([ 0., 0., nan]}``` - - will be - ```states = {'eCa': Array([ 0., 0.]}``` - - Only loops over necessary keys, as opposed to looping over `d.items()`.""" - return dict(zip(keys, (v if idcs is None else v[idcs] for v in map(d.get, keys)))) - - def compute_axial_conductances( comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray] ) -> jnp.ndarray: