diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c3e6a34c..716f139f 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -14,6 +14,7 @@ from matplotlib.axes import Axes from jaxley.channels import Channel +from jaxley.pumps import Pump from jaxley.solver_voltage import ( step_voltage_explicit, step_voltage_implicit_with_jax_spsolve, @@ -45,6 +46,9 @@ class Module(ABC): This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks). + + Note that the `__init__()` method is not abstract. This is because each module + type has a different initialization procedure. """ def __init__(self): @@ -91,6 +95,9 @@ def __init__(self): self.channels: List[Channel] = [] self.membrane_current_names: List[str] = [] + # List of all states (exluding voltage) that are being diffused. + self.diffusion_states: List[str] = [] + # For trainable parameters. self.indices_set_by_trainables: List[jnp.ndarray] = [] self.trainable_params: List[Dict[str, jnp.ndarray]] = [] @@ -287,8 +294,12 @@ def _init_morph_jaxley_spsolve(self): raise NotImplementedError def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]): - """Given radius, length, r_a, compute the axial coupling conductances.""" - return compute_axial_conductances(self._comp_edges, params) + """Given radius, length, r_a, compute the axial coupling conductances. + + If ion diffusion was activated by the user (with `cell.diffuse()`) then this + function also compute the axial conductances for every ion. + """ + return compute_axial_conductances(self._comp_edges, params, self.diffusion_states) def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"): """Adds channel nodes from constituents to `self.channel_nodes`.""" @@ -313,6 +324,26 @@ def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"): for key in channel.channel_states: self.nodes.loc[view.index.values, key] = channel.channel_states[key] + def _append_pump_to_nodes(self, view: pd.DataFrame, pump: "jx.Pump"): + """Adds pump nodes from constituents to `self.pump_nodes`.""" + name = pump._name + + # Pump does not yet exist in the `jx.Module` at all. + if name not in [c._name for c in self.pumps]: + self.pumps.append(pump) + self.nodes[name] = False # Previous columns do not have the new pump. + + # Add a binary column that indicates if the pump is present. + self.nodes.loc[view.index.values, name] = True + + # Loop over all new parameters. + for key in pump.pump_params: + self.nodes.loc[view.index.values, key] = pump.pump_params[key] + + # Loop over all new states. + for key in pump.pump_states: + self.nodes.loc[view.index.values, key] = pump.pump_states[key] + def set(self, key: str, val: Union[float, jnp.ndarray]): """Set parameter of module (or its view) to a new value. @@ -395,6 +426,26 @@ def _data_set( raise KeyError("Key not recognized.") return param_state + def diffuse(self, state: str): + """Diffuse a particular state across compartments with Fickian diffusion. + + Args: + state: Name of the state that should be diffused. + """ + self._diffuse(state, self.nodes, self.nodes) + + def _diffuse(self, state: str, table_to_update: pd.DataFrame, view: pd.DataFrame): + self.diffusion_states.append(state) + table_to_update.loc[view.index.values, f"axial_resistivity_{state}"] = 1.0 + + # The diffused state might not exist in all compartments that across which + # we are diffusing (e.g. there are active calcium mechanisms only in the soma, + # but calcium should still diffuse into the dendrites). Here, we ensure that + # the state is not `NaN` in every compartment across which we are diffusing. + state_is_nan = pd.isna(view[state]) + average_state_value = view[state].mean() + table_to_update.loc[state_is_nan, state] = average_state_value + def make_trainable( self, key: str, @@ -548,6 +599,11 @@ def get_all_parameters( for key in ["radius", "length", "axial_resistivity", "capacitance"]: params[key] = self.jaxnodes[key] + for key in self.diffusion_states: + params[f"axial_resistivity_{key}"] = self.jaxnodes[ + f"axial_resistivity_{key}" + ] + for channel in self.channels: for channel_params in channel.channel_params: params[channel_params] = self.jaxnodes[channel_params] @@ -883,6 +939,16 @@ def insert(self, channel: Channel): def _insert(self, channel, view): self._append_channel_to_nodes(view, channel) + def pump(self, pump: Pump): + """Insert a pump into the module. + + Args: + pump: The pump to insert.""" + self._pump(pump, self.nodes) + + def _pump(self, pump, view): + self._append_pump_to_nodes(view, pump) + def init_syns(self): self.initialized_syns = True @@ -900,7 +966,7 @@ def step( This function is called inside of `integrate` and increments the state of the module by one time step. Calls `_step_channels` and `_step_synapse` to update - the states of the channels and synapses using fwd_euler. + the states of the channels and synapses. Args: u: The state of the module. voltages = u["v"] @@ -934,6 +1000,11 @@ def step( u, delta_t, self.channels, self.nodes, params ) + # # Step of the Pumps. + # u, (v_terms, const_terms) = self._step_pumps( + # u, delta_t, self.pumps, self.nodes, params + # ) + # Step of the synapse. u, (syn_v_terms, syn_const_terms) = self._step_synapse( u, @@ -952,25 +1023,29 @@ def step( cm = params["capacitance"] # Abbreviation. # Arguments used by all solvers. - solver_kwargs = { - "voltages": voltages, - "voltage_terms": (v_terms + syn_v_terms) / cm, - "constant_terms": (const_terms + i_ext + syn_const_terms) / cm, + num_diffused_states = len(self.diffusion_states) + diffused_state_zeros = [jnp.zeros_like(v_terms)] * num_diffused_states + state_vals = { + "voltages": jnp.stack([voltages] + [u[d] for d in self.diffusion_states]), + "voltage_terms": jnp.stack( + [(v_terms + syn_v_terms) / cm] + diffused_state_zeros + ), + "constant_terms": jnp.stack( + [(const_terms + i_ext + syn_const_terms) / cm] + diffused_state_zeros + ), "axial_conductances": params["axial_conductances"], - "internal_node_inds": self._internal_node_inds, } # Add solver specific arguments. if voltage_solver == "jax.sparse": - solver_kwargs.update( - { - "sinks": np.asarray(self._comp_edges["sink"].to_list()), - "data_inds": self._data_inds, - "indices": self._indices_jax_spsolve, - "indptr": self._indptr_jax_spsolve, - "n_nodes": self._n_nodes, - } - ) + solver_kwargs = { + "data_inds": self._data_inds, + "indices": self._indices_jax_spsolve, + "indptr": self._indptr_jax_spsolve, + "sinks": np.asarray(self._comp_edges["sink"].to_list()), + "n_nodes": self._n_nodes, + "internal_node_inds": self._internal_node_inds, + } # Only for `bwd_euler` and `cranck-nicolson`. step_voltage_implicit = step_voltage_implicit_with_jax_spsolve else: @@ -980,42 +1055,51 @@ def step( # Currently, the forward Euler solver also uses this format. However, # this is only for historical reasons and we are planning to change this in # the future. - solver_kwargs.update( - { - "sinks": np.asarray(self._comp_edges["sink"].to_list()), - "sources": np.asarray(self._comp_edges["source"].to_list()), - "types": np.asarray(self._comp_edges["type"].to_list()), - "masked_node_inds": self._remapped_node_indices, - "nseg_per_branch": self.nseg_per_branch, - "nseg": self.nseg, - "par_inds": self.par_inds, - "child_inds": self.child_inds, - "nbranches": self.total_nbranches, - "solver": voltage_solver, - "children_in_level": self.children_in_level, - "parents_in_level": self.parents_in_level, - "root_inds": self.root_inds, - "branchpoint_group_inds": self.branchpoint_group_inds, - "debug_states": self.debug_states, - } - ) + solver_kwargs = { + "internal_node_inds": self._internal_node_inds, + "sinks": np.asarray(self._comp_edges["sink"].to_list()), + "sources": np.asarray(self._comp_edges["source"].to_list()), + "types": np.asarray(self._comp_edges["type"].to_list()), + "masked_node_inds": self._remapped_node_indices, + "nseg_per_branch": self.nseg_per_branch, + "nseg": self.nseg, + "par_inds": self.par_inds, + "child_inds": self.child_inds, + "nbranches": self.total_nbranches, + "solver": voltage_solver, + "children_in_level": self.children_in_level, + "parents_in_level": self.parents_in_level, + "root_inds": self.root_inds, + "branchpoint_group_inds": self.branchpoint_group_inds, + "debug_states": self.debug_states, + } # Only for `bwd_euler` and `cranck-nicolson`. step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve if solver == "bwd_euler": - u["v"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t) + nones = [None] * len(solver_kwargs) + vmapped = vmap(step_voltage_implicit, in_axes=(0, 0, 0, 0, *nones, None)) + updated_states = vmapped( + *state_vals.values(), *solver_kwargs.values(), delta_t + ) + u["v"] = updated_states[0] + for i, diffusion_state in enumerate(self.diffusion_states): + # +1 because voltage is the zero-eth element. + u[diffusion_state] = updated_states[i + 1] elif solver == "crank_nicolson": # Crank-Nicolson advances by half a step of backward and half a step of # forward Euler. half_step_delta_t = delta_t / 2 half_step_voltages = step_voltage_implicit( - **solver_kwargs, delta_t=half_step_delta_t + **state_vals, **solver_kwargs, delta_t=half_step_delta_t ) # The forward Euler step in Crank-Nicolson can be performed easily as # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4. u["v"] = 2 * half_step_voltages - voltages elif solver == "fwd_euler": - u["v"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t) + u["v"] = step_voltage_explicit( + **state_vals, **solver_kwargs, delta_t=delta_t + ) else: raise ValueError( f"You specified `solver={solver}`. The only allowed solvers are " @@ -1687,14 +1771,13 @@ def data_set( """Set parameter of module (or its view) to a new value within `jit`.""" return self.pointer._data_set(key, val, self.view, param_state) - def make_trainable( - self, - key: str, - init_val: Optional[Union[float, list]] = None, - verbose: bool = True, - ): - """Make a parameter trainable.""" - self.pointer._make_trainable(self.view, key, init_val, verbose=verbose) + def diffuse(self, state: str): + """Diffuse a particular state across compartments with Fickian diffusion. + + Args: + state: Name of the state that should be diffused. + """ + self._diffuse(state, self.pointer.nodes, self.view) def add_to_group(self, group_name: str): self.pointer._add_to_group(group_name, self.view) diff --git a/jaxley/pumps/__init__.py b/jaxley/pumps/__init__.py new file mode 100644 index 00000000..ab0d7593 --- /dev/null +++ b/jaxley/pumps/__init__.py @@ -0,0 +1,2 @@ +from jaxley.pumps.pump import Pump +from jaxley.pumps.ca_pump import CaPump diff --git a/jaxley/pumps/ca_pump.py b/jaxley/pumps/ca_pump.py new file mode 100644 index 00000000..46a876f2 --- /dev/null +++ b/jaxley/pumps/ca_pump.py @@ -0,0 +1,44 @@ +from typing import Optional + +from jaxley.pumps.pump import Pump + + +class CaPump(Pump): + """Calcium dynamics tracking inside calcium concentration + + Modeled after Destexhe et al. 1994. + """ + + def __init__(self, name: Optional[str] = None): + super().__init__(name) + self.pump_params = { + f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered). + f"{self._name}_decay": 80, # Buffering time constant in ms. + f"{self._name}_depth": 0.1, # Depth of shell in um. + f"{self._name}_minCai": 1e-4, # Minimum intracell. ca concentration in mM. + } + self.pump_states = {} + self.ion_name = "CaCon_i" + self.META = { + "reference": "Modified from Destexhe et al., 1994", + "mechanism": "Calcium dynamics", + } + + def update_states(self, u, dt, voltages, params): + """Update states if necessary (but this pump has no states to update).""" + return {"CaCon_i": u["CaCon_i"]} + + def compute_current(self, u, dt, voltages, params): + """Return change of calcium concentration based on calcium current and decay.""" + prefix = self._name + ica = u["i_Ca"] / 1_000.0 + gamma = params[f"{prefix}_gamma"] + decay = params[f"{prefix}_decay"] + depth = params[f"{prefix}_depth"] + minCai = params[f"{prefix}_minCai"] + + FARADAY = 96485 # Coulombs per mole. + + # Calculate the contribution of calcium currents to cai change. + drive_channel = -10_000.0 * ica * gamma / (2 * FARADAY * depth) + return drive_channel - (u["CaCon_i"] + minCai) / decay diff --git a/jaxley/pumps/pump.py b/jaxley/pumps/pump.py new file mode 100644 index 00000000..72ebb5eb --- /dev/null +++ b/jaxley/pumps/pump.py @@ -0,0 +1,74 @@ +# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is +# licensed under the Apache License Version 2.0, see + +from abc import ABC, abstractmethod +from typing import Dict, Optional, Tuple + +import jax.numpy as jnp + + +class Pump: + """Pump base class. All pumps inherit from this class. + + A pump in Jaxley is everything that modifies the intracellular ion concentrations. + """ + + _name = None + pump_params = None + pump_states = None + current_name = None + + def __init__(self, name: Optional[str] = None): + self._name = name if name else self.__class__.__name__ + + @property + def name(self) -> Optional[str]: + """The name of the channel (by default, this is the class name).""" + return self._name + + def change_name(self, new_name: str): + """Change the pump name. + + Args: + new_name: The new name of the pump. + + Returns: + Renamed pump, such that this function is chainable. + """ + old_prefix = self._name + "_" + new_prefix = new_name + "_" + + self._name = new_name + self.pump_params = { + ( + new_prefix + key[len(old_prefix) :] + if key.startswith(old_prefix) + else key + ): value + for key, value in self.pump_params.items() + } + + self.pump_states = { + ( + new_prefix + key[len(old_prefix) :] + if key.startswith(old_prefix) + else key + ): value + for key, value in self.pump_states.items() + } + return self + + def compute_current( + self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] + ): + """Given channel states and voltage, return the change in ion concentration. + + Args: + states: All states of the compartment. + v: Voltage of the compartment in mV. + params: Parameters of the channel (conductances in `S/cm2`). + + Returns: + Ion concentration change in `mM`. + """ + raise NotImplementedError diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 80c1538a..1cc184f3 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -80,12 +80,12 @@ def step_voltage_implicit_with_jaxley_spsolve( child_inds: jnp.ndarray, nbranches: int, solver: str, - delta_t: float, children_in_level: List[jnp.ndarray], parents_in_level: List[jnp.ndarray], root_inds: jnp.ndarray, branchpoint_group_inds: jnp.ndarray, debug_states, + delta_t: float, ): """Solve one timestep of branched nerve equations with implicit (backward) Euler.""" # Build diagonals. @@ -147,9 +147,10 @@ def step_voltage_implicit_with_jaxley_spsolve( ) # Find unique group identifiers num_branchpoints = len(branchpoint_conds_parents) - branchpoint_diags = -group_and_sum( - all_branchpoint_vals, branchpoint_group_inds, num_branchpoints - ) + branchpoint_diags = ( + -group_and_sum(all_branchpoint_vals, branchpoint_group_inds, num_branchpoints) + + 1e-14 + ) # For numerical stability if axial_conductances == 0.0 branchpoint_solves = jnp.zeros((num_branchpoints,)) branchpoint_conds_children = -delta_t * branchpoint_conds_children @@ -246,9 +247,9 @@ def step_voltage_implicit_with_jax_spsolve( indices, indptr, sinks, - delta_t, n_nodes, internal_node_inds, + delta_t, ): axial_conductances = delta_t * axial_conductances diff --git a/jaxley/utils/cell_utils.py b/jaxley/utils/cell_utils.py index ffa8c0e3..8b77118d 100644 --- a/jaxley/utils/cell_utils.py +++ b/jaxley/utils/cell_utils.py @@ -410,7 +410,7 @@ def query_channel_states_and_params(d, keys, idcs): def compute_axial_conductances( - comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray] + comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray], diffusion_states: List[str] ) -> jnp.ndarray: """Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances. @@ -422,20 +422,23 @@ def compute_axial_conductances( source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list()) sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list()) + resistivities = jnp.stack([params["axial_resistivity"]] + [params[f"axial_resistivity_{d}"] for d in diffusion_states]) + print("resistivities", resistivities.shape) + if len(sink_comp_inds) > 0: conds_c2c = ( - vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))( + vmap(vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0)), in_axes=(None, None, 0, 0, None, None))( params["radius"][sink_comp_inds], params["radius"][source_comp_inds], - params["axial_resistivity"][sink_comp_inds], - params["axial_resistivity"][source_comp_inds], + resistivities[:, sink_comp_inds], + resistivities[:, source_comp_inds], params["length"][sink_comp_inds], params["length"][source_comp_inds], ) / params["capacitance"][sink_comp_inds] ) else: - conds_c2c = jnp.asarray([]) + conds_c2c = jnp.asarray([[]] * (len(diffusion_states) + 1)) # `branchpoint-to-compartment` (bp2c) axial coupling conductances. condition = comp_edges["type"].isin([1, 2]) @@ -443,34 +446,34 @@ def compute_axial_conductances( if len(sink_comp_inds) > 0: conds_bp2c = ( - vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))( + vmap(vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0)), in_axes=(None, 0, None))( params["radius"][sink_comp_inds], - params["axial_resistivity"][sink_comp_inds], + resistivities[:, sink_comp_inds], params["length"][sink_comp_inds], ) - / params["capacitance"][sink_comp_inds] + / params["capacitance"][sink_comp_inds] # TODO only v should divide by capacitance. ) else: - conds_bp2c = jnp.asarray([]) + conds_bp2c = jnp.asarray([[]] * (len(diffusion_states) + 1)) # `compartment-to-branchpoint` (c2bp) axial coupling conductances. condition = comp_edges["type"].isin([3, 4]) source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list()) if len(source_comp_inds) > 0: - conds_c2bp = vmap(compute_impact_on_node, in_axes=(0, 0, 0))( + conds_c2bp = vmap(vmap(compute_impact_on_node, in_axes=(0, 0, 0)), in_axes=(0, None, 0))( params["radius"][source_comp_inds], - params["axial_resistivity"][source_comp_inds], + resistivities[:, source_comp_inds], params["length"][source_comp_inds], ) # For numerical stability. These values are very small, but their scale # does not matter. conds_c2bp *= 1_000 else: - conds_c2bp = jnp.asarray([]) + conds_c2bp = jnp.asarray([[]] * (len(diffusion_states) + 1)) # All axial coupling conductances. - return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp]) + return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp], axis=1) def compute_children_and_parents(