diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b4fa8ed..dc56b02f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ net.vis() - changelog added to CI (#537, #558, @jnsbck) +- Refactor of channel and synapse stepping internals and how the model is transferred to jax for more efficient and readable code (#487, @jnsbck). + # 0.5.0 ### API changes diff --git a/jaxley/channels/channel.py b/jaxley/channels/channel.py index 678b1e1e..b8a1dc41 100644 --- a/jaxley/channels/channel.py +++ b/jaxley/channels/channel.py @@ -59,22 +59,22 @@ def change_name(self, new_name: str): new_prefix = new_name + "_" self._name = new_name - self.channel_params = { + self.params = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.channel_params.items() + for key, value in self.params.items() } - self.channel_states = { + self.states = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.channel_states.items() + for key, value in self.states.items() } return self diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index c19bf002..70fc72b5 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -17,7 +17,7 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gNa": 0.12, f"{prefix}_gK": 0.036, f"{prefix}_gLeak": 0.0003, @@ -25,7 +25,7 @@ def __init__(self, name: Optional[str] = None): f"{prefix}_eK": -77.0, f"{prefix}_eLeak": -54.3, } - self.channel_states = { + self.states = { f"{prefix}_m": 0.2, f"{prefix}_h": 0.2, f"{prefix}_n": 0.2, diff --git a/jaxley/channels/pospischil.py b/jaxley/channels/pospischil.py index 5884deac..8602a72c 100644 --- a/jaxley/channels/pospischil.py +++ b/jaxley/channels/pospischil.py @@ -40,11 +40,11 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gLeak": 1e-4, f"{prefix}_eLeak": -70.0, } - self.channel_states = {} + self.states = {} self.current_name = f"i_{prefix}" def update_states( @@ -77,12 +77,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gNa": 50e-3, "eNa": 50.0, "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.channel_states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} + self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} self.current_name = f"i_Na" def update_states( @@ -148,12 +148,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gK": 5e-3, "eK": -90.0, "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.channel_states = {f"{prefix}_n": 0.2} + self.states = {f"{prefix}_n": 0.2} self.current_name = f"i_K" def update_states( @@ -204,12 +204,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gKm": 0.004e-3, f"{prefix}_taumax": 4000.0, f"eK": -90.0, } - self.channel_states = {f"{prefix}_p": 0.2} + self.states = {f"{prefix}_p": 0.2} self.current_name = f"i_K" def update_states( @@ -261,11 +261,11 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gCaL": 0.1e-3, "eCa": 120.0, } - self.channel_states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} + self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} self.current_name = f"i_Ca" def update_states( @@ -329,12 +329,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gCaT": 0.4e-4, f"{prefix}_vx": 2.0, "eCa": 120.0, # Global parameter, not prefixed with `CaT`. } - self.channel_states = {f"{prefix}_u": 0.2} + self.states = {f"{prefix}_u": 0.2} self.current_name = f"i_Ca" def update_states( diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 2893f983..e3648010 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from warnings import warn +import jax import jax.numpy as jnp import numpy as np import pandas as pd @@ -28,12 +29,12 @@ _compute_num_children, build_radiuses_from_xyzr, compute_axial_conductances, + compute_current_density, compute_levels, - convert_point_process_to_distributed, + index_of_a_in_b, interpolate_xyzr, - loc_of_index, + iterate_leaves, params_to_pstate, - query_channel_states_and_params, v_interp, ) from jaxley.utils.debug_solver import compute_morphology_indices @@ -145,9 +146,6 @@ def __init__(self): # List of all types of `jx.Synapse`s. self.synapses: List = [] - self.synapse_param_names = [] - self.synapse_state_names = [] - self.synapse_names = [] self.synapse_current_names: List[str] = [] # List of types of all `jx.Channel`s. @@ -189,7 +187,9 @@ def __str__(self): def __dir__(self): base_dir = object.__dir__(self) - return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys())) + synapses = [s.name for s in self.synapses] + groups = [] if len(self.groups) == 0 else list(self.groups.keys()) + return sorted(base_dir + synapses + groups) def __getattr__(self, key): # Ensure that hidden methods such as `__deepcopy__` still work. @@ -207,22 +207,24 @@ def __getattr__(self, key): return view # intercepts calls to channels - if key in [c._name for c in self.base.channels]: - channel_names = [c._name for c in self.channels] + if key in [c.name for c in self.base.channels]: + channel_names = [c.name for c in self.channels] inds = self.nodes.index[self.nodes[key]].to_numpy() view = self.select(inds) if key in channel_names else self.select(None) view._set_controlled_by_param(key) return view # intercepts calls to synapse types - if key in self.base.synapse_names: + base_syn_names = [s.name for s in self.base.synapses] + syn_names = [s.name for s in self.synapses] + if key in base_syn_names: syn_inds = self.edges[self.edges["type"] == key][ "global_edge_index" ].to_numpy() orig_scope = self._scope view = ( self.scope("global").edge(syn_inds).scope(orig_scope) - if key in self.synapse_names + if key in syn_names else self.select(None) ) view._set_controlled_by_param(key) # overwrites param set by edge @@ -714,42 +716,90 @@ def _gather_channels_from_constituents(self, constituents: List): """ for module in constituents: for channel in module.channels: - if channel._name not in [c._name for c in self.channels]: + if channel.name not in [c.name for c in self.channels]: self.base.channels.append(channel) if channel.current_name not in self.membrane_current_names: self.base.membrane_current_names.append(channel.current_name) # Setting columns of channel names to `False` instead of `NaN`. for channel in self.base.channels: - name = channel._name + name = channel.name self.base.nodes.loc[self.nodes[name].isna(), name] = False - @only_allow_module + def _inds_of_state_param(self, key: str) -> jnp.ndarray: + """lookup the indices for params or states. + + Returns indices that have non-NaN values for the given key in `nodes` or `edges`. + + Args: + key: The name of the param or state to get the indices for. + + Returns: + The indices of the param or state. + """ + if key in self.nodes.columns: + data = self.nodes[key] + return jnp.asarray(data.index[data.notna()]) + elif key in self.edges.columns: + data = self.edges[key] + return jnp.asarray(data.index[data.notna()]) + else: + raise ValueError(f"Key '{key}' not found in nodes or edges") + def to_jax(self): - # TODO FROM #447: Make this work for View? - """Move `.nodes` to `.jaxnodes`. + """Move `jx.Module` to `jax`. Before the actual simulation is run (via `jx.integrate`), all parameters of - the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for + the `jx.Module` are stored in `.nodes`/`.edges` (`pd.DataFrame`). However, for simulation, these parameters have to be moved to be `jnp.ndarrays` such that they can be processed on GPU/TPU and such that the simulation can be - differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`. + differentiated. `.to_jax()` copies `.nodes` to `.jax["nodes"]` and `.edges` + to `.jax["edges"]`. In addition, jax["global"] keeps track of parameters and + states that are shared by multiple mechanisms. """ - self.base.jaxnodes = {} - for key, value in self.base.nodes.to_dict(orient="list").items(): - inds = jnp.arange(len(value)) - self.base.jaxnodes[key] = jnp.asarray(value)[inds] - - # `jaxedges` contains only parameters (no indices). - # `jaxedges` contains only non-Nan elements. This is unlike the channels where - # we allow parameter sharing. - self.base.jaxedges = {} - edges = self.base.edges.to_dict(orient="list") - for i, synapse in enumerate(self.base.synapses): - condition = np.asarray(edges["type_ind"]) == i - for key in synapse.synapse_params: - self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) - for key in synapse.synapse_states: - self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) + # the parameters and states in the jax["nodes"] are stored on a per-mechanism basis, + # i.e. if only compartment #2 has a HH channels, then the jax["nodes"] will be + # {HH_gNa: [0.1], ...} vs. {gbar_HH: [NaN, NaN, 0.1, ...], ...}. This means + # a seperate lookup table is needed to figure out which parameters and states + # belong to which mechanism and at which global indices the mechanism lives. + + keys = ["nodes", "edges", "global"] + jax = dict(zip(keys, [{"states": {}, "params": {}}] * len(keys))) + + module_param_states = { + "states": ["v"], + "params": ["radius", "length", "axial_resistivity", "capacitance"], + } + + for label, keys in module_param_states.items(): + for key in keys: + jax["global"][label][key] = jnp.asarray(self.nodes[key]) + + for mech in self.channels + self.synapses: + is_channel = isinstance(mech, Channel) + jax_arrays = jax["nodes"] if is_channel else jax["edges"] + data = self.nodes if is_channel else self.edges + + where_mech = data[mech.name] if is_channel else data["type"] == mech.name + mech.indices = jnp.asarray(data.index[where_mech].to_list()) + if isinstance(mech, Synapse): + pre_inds = data["pre_global_comp_index"] + post_inds = data["post_global_comp_index"] + mech.pre_indices = jnp.asarray(pre_inds[where_mech].to_list()) + mech.post_indices = jnp.asarray(post_inds[where_mech].to_list()) + + params = data.loc[where_mech, mech.params.keys()].to_dict(orient="list") + states = data.loc[where_mech, mech.states.keys()].to_dict(orient="list") + + is_global = lambda x: not x.startswith(f"{mech.name}_") + for label, params_or_states in zip(["params", "states"], [params, states]): + for k in params_or_states: + jax_data = jnp.asarray(data[k][data[k].notna()].to_list()) + if not is_global(k): + jax["global"][label][k] = jax_data + else: + jax_arrays[label][k] = jax_data[mech.indices] + + self.jax = jax def show( self, @@ -781,13 +831,9 @@ def show( scopes = ["local", "global"] inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else [] cols += inds - cols += [ch._name for ch in self.channels] if channel_names else [] - cols += ( - sum([list(ch.channel_params) for ch in self.channels], []) if params else [] - ) - cols += ( - sum([list(ch.channel_states) for ch in self.channels], []) if states else [] - ) + cols += [ch.name for ch in self.channels] if channel_names else [] + cols += sum([list(ch.params) for ch in self.channels], []) if params else [] + cols += sum([list(ch.states) for ch in self.channels], []) if states else [] if not param_names is None: cols = ( @@ -915,13 +961,9 @@ def set_ncomp( all_nodes = self.base.nodes start_idx = self.nodes["global_comp_index"].to_numpy()[0] ncomp_per_branch = self.base.ncomp_per_branch - channel_names = [c._name for c in self.base.channels] - channel_param_names = list( - chain(*[c.channel_params for c in self.base.channels]) - ) - channel_state_names = list( - chain(*[c.channel_states for c in self.base.channels]) - ) + channel_names = [c.name for c in self.base.channels] + channel_param_names = list(chain(*[c.params for c in self.base.channels])) + channel_state_names = list(chain(*[c.states for c in self.base.channels])) radius_generating_fns = self.base._radius_generating_fns within_branch_radiuses = view["radius"].to_numpy() @@ -1130,10 +1172,6 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): Args: trainable_params: The trainable parameters returned by `get_parameters()`. """ - # We do not support views. Why? `jaxedges` does not have any NaN - # elements, whereas edges does. Because of this, we already need special - # treatment to make this function work, and it would be an even bigger hassle - # if we wanted to support this. assert self.__class__.__name__ in [ "Compartment", "Branch", @@ -1145,31 +1183,27 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): # However, I think it allows us to reuse as much code as possible and it avoids # any kind of issues with indexing or parameter sharing (as this is fully # taken care of by `get_all_parameters()`). - self.base.to_jax() - pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables) - all_params = self.base.get_all_parameters(pstate, voltage_solver="jaxley.stone") - + self.to_jax() # The value for `delta_t` does not matter here because it is only used to # compute the initial current. However, the initial current cannot be made # trainable and so its value never gets used below. - all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025) + pstate = params_to_pstate(trainable_params, self.indices_set_by_trainables) + all_params_states = self._get_all_states_params( + pstate, + delta_t=0.025, + voltage_solver="jaxley.stone", + params=True, + states=True, + ) # Loop only over the keys in `pstate` to avoid unnecessary computation. - for parameter in pstate: - key = parameter["key"] - if key in self.base.nodes.columns: - vals_to_set = all_params if key in all_params.keys() else all_states - self.base.nodes[key] = vals_to_set[key] - - # `jaxedges` contains only non-Nan elements. This is unlike the channels where - # we allow parameter sharing. - edges = self.base.edges.to_dict(orient="list") - for i, synapse in enumerate(self.base.synapses): - condition = np.asarray(edges["type_ind"]) == i - for key in list(synapse.synapse_params.keys()): - self.base.edges.loc[condition, key] = all_params[key] - for key in list(synapse.synapse_states.keys()): - self.base.edges.loc[condition, key] = all_states[key] + for p in pstate: + key, inds = p["key"], p["indices"] + inds = np.array(inds.reshape(-1)) + data = ( + self.base.nodes if key in self.base.nodes.columns else self.base.edges + ) + data.loc[inds, key] = all_params_states[key][inds] def distance(self, endpoint: "View") -> float: """Return the direct distance between two compartments. @@ -1221,10 +1255,8 @@ def _get_state_names(self) -> Tuple[List, List]: """Collect all recordable / clampable states in the membrane and synapses. Returns states seperated by comps and edges.""" - channel_states = [name for c in self.channels for name in c.channel_states] - synapse_states = [ - name for s in self.synapses if s is not None for name in s.synapse_states - ] + channel_states = [name for c in self.channels for name in c.states] + synapse_states = [name for s in self.synapses for name in s.states] membrane_states = ["v", "i"] + self.membrane_current_names return ( channel_states + membrane_states, @@ -1242,7 +1274,65 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """ return self.trainable_params - @only_allow_module + def _get_all_states_params( + self, + pstate: List[Dict], + voltage_solver: str = None, + delta_t: float = None, + all_params: Dict[str, jnp.ndarray] = None, + params: bool = False, + states: bool = False, + ) -> Dict[str, jnp.ndarray]: + """Get all parameters and/or states of the module. + + Common backbone of both `get_all_parameters()` and `get_all_states()`. + + Args: + pstate: The state of the trainable parameters. + voltage_solver: The voltage solver that is used. + delta_t: The stepsize. + all_params: All parameters of the module. + params: Whether to get the parameters. + states: Whether to get the states. + + Returns: + A dictionary of all parameters and/or states of the module. + """ + states_params = {} + + for key, values, path in iterate_leaves(self.jax): + states_params[key] = values + + # Override with those parameters set by `.make_trainable()`. + for p in pstate: + key, inds, set_param = p["key"], p["indices"], p["val"] + # `inds` is of shape `(num_params, num_comps_per_param)`. + # `set_param` is of shape `(num_params,)` + # We need to unsqueeze `set_param` to make it `(num_params, 1)` + # for the `.set()` to work. This is done with `[:, None]`. + param_state_inds = self._inds_of_state_param(key) + inds = index_of_a_in_b(inds, param_state_inds) + states_params[key] = states_params[key].at[inds].set(set_param[:, None]) + + if params: + # Compute conductance params and add them to the params dictionary. + states_params["axial_conductances"] = self._compute_axial_conductances( + params=states_params + ) + + if states: + all_params = states_params if all_params is None and params else all_params + # Add to the states the initial current through every channel. + states, _ = self._channel_currents( + states_params, delta_t, self.channels, self.nodes, all_params + ) + + # Add to the states the initial current through every synapse. + states, _ = self._synapse_currents( + states_params, self.synapses, all_params, delta_t, self.edges + ) + return states_params + def get_all_parameters( self, pstate: List[Dict], voltage_solver: str ) -> Dict[str, jnp.ndarray]: @@ -1263,7 +1353,7 @@ def get_all_parameters( params = module.get_parameters() # i.e. [0, 1, 2] pstate = params_to_pstate(params, module.indices_set_by_trainables) - module.to_jax() # needed for call to module.jaxnodes + module.to_jax() # needed for call to module.jax Args: pstate: The state of the trainable parameters. pstate takes the form @@ -1278,67 +1368,16 @@ def get_all_parameters( Returns: A dictionary of all module parameters. """ - params = {} - for key in ["radius", "length", "axial_resistivity", "capacitance"]: - params[key] = self.base.jaxnodes[key] - - for channel in self.base.channels: - for channel_params in channel.channel_params: - params[channel_params] = self.base.jaxnodes[channel_params] - - for synapse_params in self.base.synapse_param_names: - params[synapse_params] = self.base.jaxedges[synapse_params] - - # Override with those parameters set by `.make_trainable()`. - for parameter in pstate: - key = parameter["key"] - inds = parameter["indices"] - set_param = parameter["val"] - - # This is needed since SynapseViews worked differently before. - # This mimics the old behaviour and tranformes the new indices - # to the old indices. - # TODO FROM #447: Longterm this should be gotten rid of. - # Instead edges should work similar to nodes (would also allow for - # param sharing). - synapse_inds = self.base.edges.groupby("type").rank()["global_edge_index"] - synapse_inds = (synapse_inds.astype(int) - 1).to_numpy() - if key in self.base.synapse_param_names: - inds = synapse_inds[inds] - - if key in params: # Only parameters, not initial states. - # `inds` is of shape `(num_params, num_comps_per_param)`. - # `set_param` is of shape `(num_params,)` - # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the - # `.set()` to work. This is done with `[:, None]`. - params[key] = params[key].at[inds].set(set_param[:, None]) - - # Compute conductance params and add them to the params dictionary. - params["axial_conductances"] = self.base._compute_axial_conductances( - params=params + params = self._get_all_states_params( + pstate, params=True, voltage_solver=voltage_solver ) return params - @only_allow_module - def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: - # TODO FROM #447: MAKE THIS WORK FOR VIEW? - """Return states as they are set in the `.nodes` and `.edges` tables.""" - self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. - states = {"v": self.base.jaxnodes["v"]} - # Join node and edge states into a single state dictionary. - for channel in self.base.channels: - for channel_states in channel.channel_states: - states[channel_states] = self.base.jaxnodes[channel_states] - for synapse_states in self.base.synapse_state_names: - states[synapse_states] = self.base.jaxedges[synapse_states] - return states - - @only_allow_module def get_all_states( self, pstate: List[Dict], all_params, delta_t: float ) -> Dict[str, jnp.ndarray]: # TODO FROM #447: MAKE THIS WORK FOR VIEW? - """Get the full initial state of the module from jaxnodes and trainables. + """Get the full initial state of the module from `.jax` and `.trainables`. Args: pstate: The state of the trainable parameters. @@ -1348,28 +1387,8 @@ def get_all_states( Returns: A dictionary of all states of the module. """ - states = self.base._get_states_from_nodes_and_edges() - - # Override with the initial states set by `.make_trainable()`. - for parameter in pstate: - key = parameter["key"] - inds = parameter["indices"] - set_param = parameter["val"] - if key in states: # Only initial states, not parameters. - # `inds` is of shape `(num_params, num_comps_per_param)`. - # `set_param` is of shape `(num_params,)` - # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the - # `.set()` to work. This is done with `[:, None]`. - states[key] = states[key].at[inds].set(set_param[:, None]) - - # Add to the states the initial current through every channel. - states, _ = self.base._channel_currents( - states, delta_t, self.channels, self.nodes, all_params - ) - - # Add to the states the initial current through every synapse. - states, _ = self.base._synapse_currents( - states, self.synapses, all_params, delta_t, self.edges + states = self._get_all_states_params( + pstate, states=True, all_params=all_params, delta_t=delta_t ) return states @@ -1383,7 +1402,35 @@ def _initialize(self): self._init_morph() return self - @only_allow_module + def _filter_by_mech( + self, param_states: Dict, mech: Union[Channel, Synapse] + ) -> Dict: + """Filter params/states to include only those relevant to the active mech. + + Args: + param_states: The param_states dictionary to filter. + mech: The active mechanism. + + Returns: + The filtered dictionary. + """ + is_channel = isinstance(mech, Channel) + i_mech = mech.current_name if is_channel else f"i_{mech.name}" + + filtered_param_states = param_states.copy() + if i_mech in param_states: + filtered_param_states[i_mech] = param_states[i_mech][mech.indices] + + params_and_or_states = ["states"] if "v" in param_states else [] + params_and_or_states += ["params"] if "radius" in param_states else [] + for param_state_key in params_and_or_states: + for key, _, _ in iterate_leaves(self.jax["global"][param_state_key]): + if key in param_states: + param_state_inds = self._inds_of_state_param(key) + filtered_inds = index_of_a_in_b(mech.indices, param_state_inds) + filtered_param_states[key] = param_states[key][filtered_inds] + return filtered_param_states + def init_states(self, delta_t: float = 0.025): # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Initialize all mechanisms in their steady state. @@ -1394,33 +1441,23 @@ def init_states(self, delta_t: float = 0.025): delta_t: Passed on to `channel.init_state()`. """ # Update states of the channels. - channel_nodes = self.base.nodes - states = self.base._get_states_from_nodes_and_edges() + self.to_jax() # Create `.jax` from `.nodes` and `.edges`. # We do not use any `pstate` for initializing. In principle, we could change # that by allowing an input `params` and `pstate` to this function. # `voltage_solver` could also be `jax.sparse` here, because both of them # build the channel parameters in the same way. - params = self.base.get_all_parameters([], voltage_solver="jaxley.thomas") - - for channel in self.base.channels: - name = channel._name - channel_indices = channel_nodes.loc[channel_nodes[name]][ - "global_comp_index" - ].to_numpy() - voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() + param_states = self._get_all_states_params([], voltage_solver="jaxley.thomas") + voltages = param_states["v"] - channel_param_names = list(channel.channel_params.keys()) - channel_state_names = list(channel.channel_states.keys()) - channel_states = query_channel_states_and_params( - states, channel_state_names, channel_indices - ) - channel_params = query_channel_states_and_params( - params, channel_param_names, channel_indices - ) + for channel in self.channels: + channel_param_states = self._filter_by_mech(param_states, channel) init_state = channel.init_state( - channel_states, voltages, channel_params, delta_t + channel_param_states, + voltages[channel.indices], + channel_param_states, + delta_t, ) # `init_state` might not return all channel states. Only the ones that are @@ -1429,7 +1466,7 @@ def init_states(self, delta_t: float = 0.025): # Note that we are overriding `self.nodes` here, but `self.nodes` is # not used above to actually compute the current states (so there are # no issues with overriding states). - self.nodes.loc[channel_indices, key] = val + self.base.nodes.loc[channel.indices, key] = val def _init_morph_for_debugging(self): """Instandiates row and column inds which can be used to solve the voltage eqs. @@ -1732,10 +1769,10 @@ def insert(self, channel: Channel): Args: channel: The channel to insert.""" - name = channel._name + name = channel.name # Channel does not yet exist in the `jx.Module` at all. - if name not in [c._name for c in self.base.channels]: + if name not in [c.name for c in self.base.channels]: self.base.channels.append(channel) self.base.nodes[name] = ( False # Previous columns do not have the new channel. @@ -1748,24 +1785,23 @@ def insert(self, channel: Channel): self.base.nodes.loc[self._nodes_in_view, name] = True # Loop over all new parameters, e.g. gNa, eNa. - for key in channel.channel_params: - self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key] + for key in channel.params: + self.base.nodes.loc[self._nodes_in_view, key] = channel.params[key] # Loop over all new parameters, e.g. gNa, eNa. - for key in channel.channel_states: - self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key] + for key in channel.states: + self.base.nodes.loc[self._nodes_in_view, key] = channel.states[key] def delete_channel(self, channel: Channel): """Remove a channel from the module. Args: channel: The channel to remove.""" - name = channel._name - channel_names = [c._name for c in self.channels] - all_channel_names = [c._name for c in self.base.channels] + name = channel.name + channel_names = [c.name for c in self.channels] + all_channel_names = [c.name for c in self.base.channels] if name in channel_names: - channel_cols = list(channel.channel_params.keys()) - channel_cols += list(channel.channel_states.keys()) + channel_cols = list({**channel.params, **channel.states}.keys()) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") self.base.nodes.loc[self._nodes_in_view, name] = False @@ -1809,7 +1845,6 @@ def step( Returns: The updated state of the module. """ - # Extract the voltages voltages = u["v"] @@ -1934,48 +1969,50 @@ def _step_channels( ) return states, current_terms - def _step_channels_state( + def _step_mech_state( self, - states, - delta_t, - channels: List[Channel], - channel_nodes: pd.DataFrame, + states: Dict[str, jnp.ndarray], + delta_t: float, + mechs: List, + mech_data: pd.DataFrame, params: Dict[str, jnp.ndarray], ) -> Dict[str, jnp.ndarray]: - """One integration step of the channels.""" voltages = states["v"] - # Update states of the channels. - indices = channel_nodes["global_comp_index"].to_numpy() - for channel in channels: - channel_param_names = list(channel.channel_params) - channel_param_names += [ - "radius", - "length", - "axial_resistivity", - "capacitance", - ] - channel_state_names = list(channel.channel_states) - channel_state_names += self.membrane_current_names - channel_indices = indices[channel_nodes[channel._name].astype(bool)] - - channel_params = query_channel_states_and_params( - params, channel_param_names, channel_indices - ) - channel_states = query_channel_states_and_params( - states, channel_state_names, channel_indices + for mech in mechs: + # States updates. + mech_states = self._filter_by_mech(states, mech) + mech_params = self._filter_by_mech(params, mech) + v_mech = ( + (voltages[mech.indices],) + if isinstance(mech, Channel) + else (voltages[mech.pre_indices], voltages[mech.post_indices]) ) - states_updated = channel.update_states( - channel_states, delta_t, voltages[channel_indices], channel_params + mech_states_updated = mech.update_states( + mech_states, delta_t, *v_mech, mech_params ) + # Rebuild state. This has to be done within the loop over channels to allow # multiple channels which modify the same state. - for key, val in states_updated.items(): - states[key] = states[key].at[channel_indices].set(val) + for key, val in mech_states_updated.items(): + param_state_inds = self._inds_of_state_param(key) + inds = index_of_a_in_b(mech.indices, param_state_inds) + states[key] = states[key].at[inds].set(val) return states + def _step_channels_state( + self, + states, + delta_t, + channels: List[Channel], + channel_nodes: pd.DataFrame, + params: Dict[str, jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """One integration step of the channels.""" + return self._step_mech_state(states, delta_t, channels, channel_nodes, params) + def _channel_currents( self, states: Dict[str, jnp.ndarray], @@ -1989,53 +2026,41 @@ def _channel_currents( This is also updates `state` because the `state` also contains the current. """ voltages = states["v"] + morph_params = ["radius", "length", "axial_resistivity", "capacitance"] + morph_params = {pkey: params[pkey] for pkey in morph_params} # Compute current through channels. - voltage_terms = jnp.zeros_like(voltages) - constant_terms = jnp.zeros_like(voltages) + zeros = jnp.zeros_like(voltages) + voltage_terms, const_terms = zeros, zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 - - current_states = {} - for name in self.membrane_current_names: - current_states[name] = jnp.zeros_like(voltages) - + current_states = {name: zeros for name in self.membrane_current_names} for channel in channels: - name = channel._name - channel_param_names = list(channel.channel_params.keys()) - channel_state_names = list(channel.channel_states.keys()) - indices = channel_nodes.loc[channel_nodes[name]][ - "global_comp_index" - ].to_numpy() - - channel_params = {} - for p in channel_param_names: - channel_params[p] = params[p][indices] - channel_params["radius"] = params["radius"][indices] - channel_params["length"] = params["length"][indices] - channel_params["axial_resistivity"] = params["axial_resistivity"][indices] + channel_inds = channel.indices + v_channel = voltages[channel_inds] + v_and_perturbed = jnp.array([v_channel, v_channel + diff]) - channel_states = {} - for s in channel_state_names: - channel_states[s] = states[s][indices] + channel_states = self._filter_by_mech(states, channel) + channel_params = self._filter_by_mech(params, channel) - v_and_perturbed = jnp.stack([voltages[indices], voltages[indices] + diff]) membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))( channel_states, v_and_perturbed, channel_params ) + + # Split into voltage and constant terms. voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff - constant_term = membrane_currents[0] - voltage_term * voltages[indices] + const_term = membrane_currents[0] - voltage_term * v_channel # * 1000 to convert from mA/cm^2 to uA/cm^2. - voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0) - constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0) + voltage_terms = voltage_terms.at[channel_inds].add(voltage_term * 1000.0) + const_terms = const_terms.at[channel_inds].add(-const_term * 1000.0) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. current_states[channel.current_name] = ( current_states[channel.current_name] - .at[indices] + .at[channel_inds] .add(membrane_currents[0]) ) @@ -2043,8 +2068,7 @@ def _channel_currents( # recorded and used by `Channel.update_states()`. for name in self.membrane_current_names: states[name] = current_states[name] - - return states, (voltage_terms, constant_terms) + return states, (voltage_terms, const_terms) def _step_synapse( self, @@ -2085,7 +2109,7 @@ def _get_external_input( length_single_compartment: um. """ zero_vec = jnp.zeros_like(voltages) - current = convert_point_process_to_distributed( + current = compute_current_density( i_stim, radius[i_inds], length_single_compartment[i_inds] ) @@ -2443,10 +2467,12 @@ def __init__( self.ncomp = pointer.ncomp self.nodes = pointer.nodes.loc[self._nodes_in_view] - ptr_edges = pointer.edges - self.edges = ( - ptr_edges if ptr_edges.empty else ptr_edges.loc[self._edges_in_view] - ) + self.edges = pointer.edges + if not self.edges.empty: + self.edges = pointer.edges.loc[self._edges_in_view] + + # re-enumerate type_inds + self.edges["type_ind"] = self.edges["type"].astype("category").cat.codes self.xyzr = self._xyzr_in_view() self.ncomp = 1 if len(self.nodes) == 1 else pointer.ncomp @@ -2458,8 +2484,8 @@ def __init__( self.ncomp_per_branch = self.base.ncomp_per_branch[self._branches_in_view] self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch) - self.synapse_names = np.unique(self.edges["type"]).tolist() - self._set_synapses_in_view(pointer) + self.synapses = self._synapses_in_view(pointer) + self.channels = self._channels_in_view(pointer) ptr_recs = pointer.recordings self.recordings = ( @@ -2468,7 +2494,6 @@ def __init__( else ptr_recs.loc[ptr_recs["rec_index"].isin(self._comps_in_view)] ) - self.channels = self._channels_in_view(pointer) self.membrane_current_names = [c.current_name for c in self.channels] self.synapse_current_names = pointer.synapse_current_names self._set_trainables_in_view() # run after synapses and channels @@ -2485,9 +2510,8 @@ def __init__( k: np.intersect1d(v, self._nodes_in_view) for k, v in pointer.groups.items() } - self.jaxnodes, self.jaxedges = self._jax_arrays_in_view( - pointer - ) # run after trainables + if pointer.jax: + self.to_jax() self._current_view = "view" # if not instantiated via `comp`, `cell` etc. self._update_local_indices() @@ -2536,31 +2560,6 @@ def _set_inds_in_view( self._nodes_in_view = nodes self._edges_in_view = edges - def _jax_arrays_in_view(self, pointer: Union[Module, View]): - """Update jaxnodes/jaxedges to show only those currently in view.""" - a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1] - jaxnodes = {} if pointer.jaxnodes is not None else None - if self.jaxnodes is not None: - comp_inds = pointer.jaxnodes["global_comp_index"] - common_inds = a_intersects_b_at(comp_inds, self._nodes_in_view) - jaxnodes = { - k: v[common_inds] - for k, v in pointer.jaxnodes.items() - if len(common_inds) > 0 - } - - jaxedges = {} if pointer.jaxedges is not None else None - if pointer.jaxedges is not None: - for key, values in self.base.jaxedges.items(): - if (syn_name := key.split("_")[0]) in self.synapse_names: - syn_edges = self.base.edges[self.base.edges["type"] == syn_name] - inds = np.intersect1d( - self._edges_in_view, syn_edges.index, return_indices=True - )[2] - if len(inds) > 0: - jaxedges[key] = values[inds] - return jaxnodes, jaxedges - def _set_externals_in_view(self): """Update external inputs to show only those currently in view.""" self.externals = {} @@ -2599,13 +2598,9 @@ def _filter_trainables( ): pkey, pval = next(iter(params.items())) trainable_inds_in_view = None - if pkey in sum( - [list(c.channel_params.keys()) for c in self.base.channels], [] - ): + if pkey in sum([list(c.params.keys()) for c in self.base.channels], []): trainable_inds_in_view = np.intersect1d(inds, self._nodes_in_view) - elif pkey in sum( - [list(s.synapse_params.keys()) for s in self.base.synapses], [] - ): + elif pkey in sum([list(s.params.keys()) for s in self.base.synapses], []): trainable_inds_in_view = np.intersect1d(inds, self._edges_in_view) in_view = is_viewed == np.isin(inds, trainable_inds_in_view) @@ -2651,28 +2646,15 @@ def _set_trainables_in_view(self): def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: """Set channels to show only those in view.""" - names = [name._name for name in pointer.channels] + names = [c.name for c in pointer.channels] channel_in_view = self.nodes[names].any(axis=0) channel_in_view = channel_in_view[channel_in_view].index - return [c for c in pointer.channels if c._name in channel_in_view] + return [c for c in self.base.channels if c.name in channel_in_view] - def _set_synapses_in_view(self, pointer: Union[Module, View]): + def _synapses_in_view(self, pointer: Union[Module, View]): """Set synapses to show only those in view.""" - viewed_synapses = [] - viewed_params = [] - viewed_states = [] - if pointer.synapses is not None: - for syn in pointer.synapses: - if syn is not None: # needed for recurive viewing - in_view = syn._name in self.synapse_names - viewed_synapses += ( - [syn] if in_view else [None] - ) # padded with None to keep indices consistent - viewed_params += list(syn.synapse_params.keys()) if in_view else [] - viewed_states += list(syn.synapse_states.keys()) if in_view else [] - self.synapses = viewed_synapses - self.synapse_param_names = viewed_params - self.synapse_state_names = viewed_states + names = self.edges["type"].unique() + return [deepcopy(syn) for syn in pointer.synapses if syn.name in names] def _nbranches_per_cell_in_view(self) -> np.ndarray: cell_nodes = self.nodes.groupby("global_cell_index") diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index 74ca31a4..f237c1b1 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -10,7 +10,7 @@ from jaxley.modules.base import Module from jaxley.modules.compartment import Compartment -from jaxley.utils.cell_utils import compute_children_and_parents +from jaxley.utils.cell_utils import compute_children_and_parents, dtype_aware_concat from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs from jaxley.utils.solver_utils import JaxleySolveIndexer, comp_edges_to_indices @@ -73,7 +73,7 @@ def __init__( self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch) # Indexing. - self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True) + self.nodes = dtype_aware_concat([c.nodes for c in compartment_list]) self._append_params_and_states(self.branch_params, self.branch_states) self.nodes["global_comp_index"] = np.arange(self.ncomp).tolist() self.nodes["global_branch_index"] = [0] * self.ncomp diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 3d6b39da..c440384b 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -18,6 +18,7 @@ compute_levels, compute_morphology_indices_in_levels, compute_parents_in_level, + dtype_aware_concat, ) from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs from jaxley.utils.solver_utils import ( @@ -102,7 +103,7 @@ def __init__( self._internal_node_inds = np.arange(self.cumsum_ncomp[-1]) # Build nodes. Has to be changed when `.set_ncomp()` is run. - self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True) + self.nodes = dtype_aware_concat([c.nodes for c in branch_list]) self.nodes["global_comp_index"] = np.arange(self.cumsum_ncomp[-1]) self.nodes["global_branch_index"] = np.repeat( np.arange(self.total_nbranches), self.ncomp_per_branch diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 15183bd6..6f86a92e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -18,7 +18,9 @@ from jaxley.utils.cell_utils import ( build_branchpoint_group_inds, compute_children_and_parents, - convert_point_process_to_distributed, + compute_current_density, + dtype_aware_concat, + index_of_a_in_b, loc_of_index, merge_cells, ) @@ -66,7 +68,7 @@ def __init__( self.total_nbranches = sum(self.nbranches_per_cell) self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell) - self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True) + self.nodes = dtype_aware_concat([c.nodes for c in cells]) self.nodes["global_comp_index"] = np.arange(self.cumsum_ncomp[-1]) self.nodes["global_branch_index"] = np.repeat( np.arange(self.total_nbranches), self.ncomp_per_branch @@ -262,44 +264,7 @@ def _step_synapse_state( delta_t: float, edges: pd.DataFrame, ) -> Dict: - voltages = states["v"] - - grouped_syns = edges.groupby("type", sort=False, group_keys=False) - pre_syn_inds = grouped_syns["pre_global_comp_index"].apply(list) - post_syn_inds = grouped_syns["post_global_comp_index"].apply(list) - synapse_names = list(grouped_syns.indices.keys()) - - for i, synapse_type in enumerate(syn_channels): - assert ( - synapse_names[i] == synapse_type._name - ), "Mixup in the ordering of synapses. Please create an issue on Github." - synapse_param_names = list(synapse_type.synapse_params.keys()) - synapse_state_names = list(synapse_type.synapse_states.keys()) - - synapse_params = {} - for p in synapse_param_names: - synapse_params[p] = params[p] - synapse_states = {} - for s in synapse_state_names: - synapse_states[s] = states[s] - - pre_inds = np.asarray(pre_syn_inds[synapse_names[i]]) - post_inds = np.asarray(post_syn_inds[synapse_names[i]]) - - # State updates. - states_updated = synapse_type.update_states( - synapse_states, - delta_t, - voltages[pre_inds], - voltages[post_inds], - synapse_params, - ) - - # Rebuild state. - for key, val in states_updated.items(): - states[key] = val - - return states + return self._step_mech_state(states, delta_t, syn_channels, edges, params) def _synapse_currents( self, @@ -311,50 +276,35 @@ def _synapse_currents( ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]: voltages = states["v"] - grouped_syns = edges.groupby("type", sort=False, group_keys=False) - pre_syn_inds = grouped_syns["pre_global_comp_index"].apply(list) - post_syn_inds = grouped_syns["post_global_comp_index"].apply(list) - synapse_names = list(grouped_syns.indices.keys()) - - syn_voltage_terms = jnp.zeros_like(voltages) - syn_constant_terms = jnp.zeros_like(voltages) + # Compute current through synapses. + zeros = jnp.zeros_like(voltages) + syn_voltage_terms, syn_const_terms = zeros, zeros # Run with two different voltages that are `diff` apart to infer the slope and # offset. diff = 1e-3 - for i, synapse_type in enumerate(syn_channels): - assert ( - synapse_names[i] == synapse_type._name - ), "Mixup in the ordering of synapses. Please create an issue on Github." - synapse_param_names = list(synapse_type.synapse_params.keys()) - synapse_state_names = list(synapse_type.synapse_states.keys()) - - synapse_params = {} - for p in synapse_param_names: - synapse_params[p] = params[p] - synapse_states = {} - for s in synapse_state_names: - synapse_states[s] = states[s] - - # Get pre and post indexes of the current synapse type. - pre_inds = np.asarray(pre_syn_inds[synapse_names[i]]) - post_inds = np.asarray(post_syn_inds[synapse_names[i]]) - - # Compute slope and offset of the current through every synapse. - pre_v_and_perturbed = jnp.stack( - [voltages[pre_inds], voltages[pre_inds] + diff] - ) - post_v_and_perturbed = jnp.stack( - [voltages[post_inds], voltages[post_inds] + diff] - ) + + num_comp = len(voltages) + synapse_current_states = {f"i_{s.name}": zeros for s in syn_channels} + for synapse in syn_channels: + pre_inds = synapse.pre_indices + post_inds = synapse.post_indices + + v_pre, v_post = voltages[pre_inds], voltages[post_inds] + pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff]) + post_v_and_perturbed = jnp.array([v_post, v_post + diff]) + + syn_states = self._filter_by_mech(states, synapse) + syn_params = self._filter_by_mech(params, synapse) + synapse_currents = vmap( - synapse_type.compute_current, in_axes=(None, 0, 0, None) + synapse.compute_current, in_axes=(None, 0, 0, None) )( - synapse_states, + syn_states, pre_v_and_perturbed, post_v_and_perturbed, - synapse_params, + syn_params, ) - synapse_currents_dist = convert_point_process_to_distributed( + synapse_currents_dist = compute_current_density( synapse_currents, params["radius"][post_inds], params["length"][post_inds], @@ -362,27 +312,28 @@ def _synapse_currents( # Split into voltage and constant terms. voltage_term = (synapse_currents_dist[1] - synapse_currents_dist[0]) / diff - constant_term = ( - synapse_currents_dist[0] - voltage_term * voltages[post_inds] - ) + constant_term = synapse_currents_dist[0] - voltage_term * v_post + syn_voltages = voltage_term, constant_term # Gather slope and offset for every postsynaptic compartment. - gathered_syn_currents = gather_synapes( - len(voltages), - post_inds, - voltage_term, - constant_term, + # import jax; jax.debug.print("{}", synapse_params) + gathered_syn_currents = gather_synapes(num_comp, post_inds, *syn_voltages) + + syn_voltage_terms = syn_voltage_terms.at[:].add(gathered_syn_currents[0]) + syn_const_terms = syn_const_terms.at[:].add(-gathered_syn_currents[1]) + # Save the current (for the unperturbed voltage) as a state that will + # also be passed to the state update. + synapse_current_states[f"i_{synapse.name}"] = ( + synapse_current_states[f"i_{synapse.name}"] + .at[post_inds] + .add(synapse_currents_dist[0]) ) - syn_voltage_terms += gathered_syn_currents[0] - syn_constant_terms -= gathered_syn_currents[1] - - # Add the synaptic currents through every compartment as state. - # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are - # compartments in the network. - # `[0]` because we only use the non-perturbed voltage. - states[f"i_{synapse_type._name}"] = synapse_currents[0] - return states, (syn_voltage_terms, syn_constant_terms) + # Copy the currents into the `state` dictionary such that they can be + # recorded and used by `Channel.update_states()`. + for name in [s.name for s in self.synapses]: + states[f"i_{name}"] = synapse_current_states[f"i_{name}"] + return states, (syn_voltage_terms, syn_const_terms) def arrange_in_layers( self, @@ -506,25 +457,18 @@ def vis( return ax def _infer_synapse_type_ind(self, synapse_name): - syn_names = self.base.synapse_names + syn_names = [s.name for s in self.base.synapses] is_new_type = False if synapse_name in syn_names else True type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name) return type_ind, is_new_type - def _update_synapse_state_names(self, synapse_type): - # (Potentially) update variables that track meta information about synapses. - self.base.synapse_names.append(synapse_type._name) - self.base.synapse_param_names += list(synapse_type.synapse_params.keys()) - self.base.synapse_state_names += list(synapse_type.synapse_states.keys()) - self.base.synapses.append(synapse_type) - def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): # Add synapse types to the module and infer their unique identifier. - synapse_name = synapse_type._name + synapse_name = synapse_type.name synapse_current_name = f"i_{synapse_name}" type_ind, is_new = self._infer_synapse_type_ind(synapse_name) if is_new: # synapse is not known - self._update_synapse_state_names(synapse_type) + self.base.synapses.append(synapse_type) self.base.synapse_current_names.append(synapse_current_name) index = len(self.base.edges) @@ -567,9 +511,9 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): def _add_params_to_edges(self, synapse_type, indices): # Add parameters and states to the `.edges` table. - for key, param_val in synapse_type.synapse_params.items(): + for key, param_val in synapse_type.params.items(): self.base.edges.loc[indices, key] = param_val # Update synaptic state array. - for key, state_val in synapse_type.synapse_states.items(): + for key, state_val in synapse_type.states.items(): self.base.edges.loc[indices, key] = state_val diff --git a/jaxley/synapses/ionotropic.py b/jaxley/synapses/ionotropic.py index da89113f..101dd95b 100644 --- a/jaxley/synapses/ionotropic.py +++ b/jaxley/synapses/ionotropic.py @@ -32,12 +32,12 @@ class IonotropicSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = { + self.params = { f"{prefix}_gS": 1e-4, f"{prefix}_e_syn": 0.0, f"{prefix}_k_minus": 0.025, } - self.synapse_states = {f"{prefix}_s": 0.2} + self.states = {f"{prefix}_s": 0.2} def update_states( self, diff --git a/jaxley/synapses/synapse.py b/jaxley/synapses/synapse.py index a3b4752f..38cd7d3f 100644 --- a/jaxley/synapses/synapse.py +++ b/jaxley/synapses/synapse.py @@ -38,22 +38,22 @@ def change_name(self, new_name: str): new_prefix = new_name + "_" self._name = new_name - self.synapse_params = { + self.params = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.synapse_params.items() + for key, value in self.params.items() } - self.synapse_states = { + self.states = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.synapse_states.items() + for key, value in self.states.items() } return self diff --git a/jaxley/synapses/tanh_rate.py b/jaxley/synapses/tanh_rate.py index e006a278..6bbd49cc 100644 --- a/jaxley/synapses/tanh_rate.py +++ b/jaxley/synapses/tanh_rate.py @@ -16,12 +16,12 @@ class TanhRateSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = { + self.params = { f"{prefix}_gS": 1e-4, f"{prefix}_x_offset": -70.0, f"{prefix}_slope": 1.0, } - self.synapse_states = {} + self.states = {} def update_states( self, diff --git a/jaxley/synapses/test.py b/jaxley/synapses/test.py index 49a7311e..84cb5d4d 100644 --- a/jaxley/synapses/test.py +++ b/jaxley/synapses/test.py @@ -19,8 +19,8 @@ class TestSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = {f"{prefix}_gC": 1e-4} - self.synapse_states = {f"{prefix}_c": 0.2} + self.params = {f"{prefix}_gC": 1e-4} + self.states = {f"{prefix}_c": 0.2} def update_states( self, diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index d71014c2..ae408dbc 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -610,7 +610,7 @@ def params_to_pstate( ] -def convert_point_process_to_distributed( +def compute_current_density( current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray ) -> jnp.ndarray: """Convert current point process (nA) to distributed current (uA/cm2). @@ -686,21 +686,6 @@ def group_and_sum( return group_sums -def query_channel_states_and_params(d, keys, idcs): - """Get dict with subset of keys and values from d. - - This is used to restrict a dict where every item contains __all__ states to only - the ones that are relevant for the channel. E.g. - - ```states = {'eCa': Array([ 0., 0., nan]}``` - - will be - ```states = {'eCa': Array([ 0., 0.]}``` - - Only loops over necessary keys, as opposed to looping over `d.items()`.""" - return dict(zip(keys, (v[idcs] for v in map(d.get, keys)))) - - def compute_axial_conductances( comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray] ) -> jnp.ndarray: @@ -774,3 +759,68 @@ def compute_children_and_parents( child_belongs_to_branchpoint = remap_to_consecutive(par_inds) par_inds = np.unique(par_inds) return par_inds, child_inds, child_belongs_to_branchpoint + + +def dtype_aware_concat(dfs): + concat_df = pd.concat(dfs, ignore_index=True) + # replace nans with Nones + # this correctly casts float(None) -> NaN, bool(None) -> NaN, etc. + concat_df[concat_df.isna()] = None + for col in concat_df.columns[concat_df.dtypes == "object"]: + for df in dfs: + if col in df.columns: + concat_df[col] = concat_df[col].astype(df[col].dtype) + break # first match is sufficient + return concat_df + + +def index_of_a_in_b(A: jnp.ndarray, B: jnp.ndarray) -> jnp.ndarray: + """Replace values in A with the indices of the corresponding values in B. + + Mainly used to determine the indices of parameters in `jax` based on the global + indices of the parameters in the cell. All values in A that are not in B are + replaced with -1. + + Example: + - indices_of_gNa = [5,6,7,8,9] + - indices_to_change = [6,7] + - index_of_a_in_b(indices_to_change, indices_of_gNa) -> [1,2] + + Args: + A: Array of shape (N, M). + B: Array of shape (N, K). + + Returns: + Array of shape of A with the indices of the values of A in B.""" + A_is_flat = A.ndim == 1 + A = A.reshape(1, -1) if A_is_flat else A + matches = A[:, :, None] == B + exists_in_B = matches.any(axis=-1) # mask for vals also in B + indices = jnp.where( + matches, jnp.arange(len(B))[None, None, :], 0 + ) # inds of matches + result = jnp.sum(indices, axis=-1) # Sum along last axis to get the indices + inds = jnp.where(exists_in_B, result, -1) # Replace values not in B with -1 + return inds.flatten() if A_is_flat else inds + + +def iterate_leaves(tree, path=[]): + """Iterate over all leafs (arrays) in a pytree while keeping track of their paths. + + Args: + tree: The pytree to iterate over + path: Current path in the tree (used recursively) + + Yields: + tuple: (final_key, array_value, full_path) + """ + if isinstance(tree, dict): + for key, value in tree.items(): + yield from iterate_leaves(value, path + [key]) + elif isinstance(tree, (list, tuple)): + for i, value in enumerate(tree): + yield from iterate_leaves(value, path + [str(i)]) + else: + # Assuming any non-dict/list/tuple is a leaf node (Array in this case) + if path: # Only yield if we have a path + yield path[-1], tree, path diff --git a/tests/test_channels.py b/tests/test_channels.py index 4063fd3e..62f069c9 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -25,13 +25,13 @@ def __init__( ): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) f"{self._name}_decay": 80, # Rate of removal of calcium in ms f"{self._name}_depth": 0.1, # Depth of shell in um f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM } - self.channel_states = { + self.states = { f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM } self.current_name = f"i_Ca" @@ -84,8 +84,8 @@ def __init__( "T": 279.45, # Kelvin (temperature) "R": 8.314, # J/(mol K) (gas constant) } - self.channel_params = {} - self.channel_states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} + self.params = {} + self.states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} self.current_name = f"i_Ca" def update_states(self, u, dt, voltages, params): @@ -117,21 +117,21 @@ def test_channel_set_name(): # channel name can be set in the constructor na = Na(name="NaPospischil") assert na.name == "NaPospischil" - assert "NaPospischil_gNa" in na.channel_params.keys() - assert "eNa" in na.channel_params.keys() - assert "NaPospischil_h" in na.channel_states.keys() - assert "NaPospischil_m" in na.channel_states.keys() - assert "NaPospischil_vt" not in na.channel_params.keys() - assert "vt" in na.channel_params.keys() + assert "NaPospischil_gNa" in na.params.keys() + assert "eNa" in na.params.keys() + assert "NaPospischil_h" in na.states.keys() + assert "NaPospischil_m" in na.states.keys() + assert "NaPospischil_vt" not in na.params.keys() + assert "vt" in na.params.keys() # channel name can not be changed directly k = K() with pytest.raises(AttributeError): k.name = "KPospischil" - assert "KPospischil_gNa" not in k.channel_params.keys() - assert "eNa" not in k.channel_params.keys() - assert "KPospischil_h" not in k.channel_states.keys() - assert "KPospischil_m" not in k.channel_states.keys() + assert "KPospischil_gNa" not in k.params.keys() + assert "eNa" not in k.params.keys() + assert "KPospischil_h" not in k.states.keys() + assert "KPospischil_m" not in k.states.keys() def test_channel_change_name(): @@ -139,12 +139,12 @@ def test_channel_change_name(): # (and only this way after initialization) na = Na().change_name("NaPospischil") assert na.name == "NaPospischil" - assert "NaPospischil_gNa" in na.channel_params.keys() - assert "eNa" in na.channel_params.keys() - assert "NaPospischil_h" in na.channel_states.keys() - assert "NaPospischil_m" in na.channel_states.keys() - assert "NaPospischil_vt" not in na.channel_params.keys() - assert "vt" in na.channel_params.keys() + assert "NaPospischil_gNa" in na.params.keys() + assert "eNa" in na.params.keys() + assert "NaPospischil_h" in na.states.keys() + assert "NaPospischil_m" in na.states.keys() + assert "NaPospischil_vt" not in na.params.keys() + assert "vt" in na.params.keys() def test_integration_with_renamed_channels(): @@ -200,12 +200,12 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_q10_ch": 3, f"{prefix}_q10_ch0": 22, "celsius": 22, } - self.channel_states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} + self.states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} self.current_name = f"i_K" def update_states( @@ -291,8 +291,8 @@ class User(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"cumulative": 0.0} + self.params = {} + self.states = {"cumulative": 0.0} self.current_name = f"i_User" def update_states(self, states, dt, v, params): @@ -307,8 +307,8 @@ class Dummy1(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {} + self.params = {} + self.states = {} self.current_name = f"i_Dummy" def update_states(self, states, dt, v, params): @@ -321,8 +321,8 @@ class Dummy2(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {} + self.params = {} + self.states = {} self.current_name = f"i_Dummy" def update_states(self, states, dt, v, params): @@ -365,9 +365,7 @@ def test_delete_channel(SimpleBranch): branch3.delete_channel(K()) def channel_present(view, channel, partial=False): - states_and_params = list(channel.channel_states.keys()) + list( - channel.channel_params.keys() - ) + states_and_params = list(channel.states.keys()) + list(channel.params.keys()) # none of the states or params should be in nodes cols = view.nodes.columns.to_list() channel_cols = [ diff --git a/tests/test_shared_state.py b/tests/test_shared_state.py index 0de88bb5..83541acd 100644 --- a/tests/test_shared_state.py +++ b/tests/test_shared_state.py @@ -22,8 +22,8 @@ class Dummy1(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"Dummy_s": 0.0} + self.params = {} + self.states = {"Dummy_s": 0.0} self.current_name = f"i_Dummy1" @staticmethod @@ -45,8 +45,8 @@ class Dummy2(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"Dummy_s": 0.0} + self.params = {} + self.states = {"Dummy_s": 0.0} self.current_name = f"i_Dummy2" @staticmethod @@ -68,10 +68,10 @@ class CaHVA(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gCaHVA": 0.00001, # S/cm^2 } - self.channel_states = { + self.states = { f"{self._name}_m": 0.1, # Initial value for m gating variable f"{self._name}_h": 0.1, # Initial value for h gating variable "eCa": 0.0, # mV, assuming eca for demonstration @@ -140,13 +140,13 @@ def __init__( ): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) f"{self._name}_decay": 80, # Rate of removal of calcium in ms f"{self._name}_depth": 0.1, # Depth of shell in um f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM } - self.channel_states = { + self.states = { f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM } self.current_name = f"i_Ca" diff --git a/tests/test_syn.py b/tests/test_syn.py index 3159e036..840fb341 100644 --- a/tests/test_syn.py +++ b/tests/test_syn.py @@ -27,7 +27,7 @@ def test_set_and_querying_params_one_type(SimpleNet): connect(pre, post, IonotropicSynapse()) # Get the synapse parameters to test setting - syn_params = list(IonotropicSynapse().synapse_params.keys()) + syn_params = list(IonotropicSynapse().params.keys()) for p in syn_params: net.set(p, 0.15) assert np.all(net.edges[p].to_numpy() == 0.15) diff --git a/tests/test_synapse_indexing.py b/tests/test_synapse_indexing.py index 150a5d83..d61934c4 100644 --- a/tests/test_synapse_indexing.py +++ b/tests/test_synapse_indexing.py @@ -68,7 +68,7 @@ def test_set_and_querying_params_one_type(synapse_type, SimpleNet): connect(pre, post, synapse_type) # Get the synapse parameters to test setting - syn_params = list(synapse_type.synapse_params.keys()) + syn_params = list(synapse_type.params.keys()) for p in syn_params: net.set(p, 0.15) assert np.all(net.edges[p].to_numpy() == 0.15) @@ -105,8 +105,8 @@ def test_set_and_querying_params_two_types(synapse_type, SimpleNet): post = net.cell(post_ind).branch(0).loc(0.0) connect(pre, post, synapse) - type1_params = list(IonotropicSynapse().synapse_params.keys()) - synapse_type_params = list(synapse_type.synapse_params.keys()) + type1_params = list(IonotropicSynapse().params.keys()) + synapse_type_params = list(synapse_type.params.keys()) default_synapse_type = net.edges[synapse_type_params[0]].to_numpy()[[1, 3]]