From 5ad4d699d6507efcb69822c4f5e89d7699e2c362 Mon Sep 17 00:00:00 2001 From: michaeldeistler Date: Thu, 3 Oct 2024 15:39:00 +0200 Subject: [PATCH] Allow ion diffusion --- jaxley/modules/base.py | 100 ++++++++++++++++++++++++--------------- jaxley/solver_voltage.py | 4 +- 2 files changed, 65 insertions(+), 39 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index c3e6a34c..a5edcac8 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -91,6 +91,9 @@ def __init__(self): self.channels: List[Channel] = [] self.membrane_current_names: List[str] = [] + # List of all states (exluding voltage) that are being diffused. + self.diffusion_states: List[str] = [] + # For trainable parameters. self.indices_set_by_trainables: List[jnp.ndarray] = [] self.trainable_params: List[Dict[str, jnp.ndarray]] = [] @@ -395,6 +398,10 @@ def _data_set( raise KeyError("Key not recognized.") return param_state + def diffuse(self, state: str): + self.diffusion_states.append(state) + self.nodes[f"axial_resistivity_{state}"] = 1.0 + def make_trainable( self, key: str, @@ -548,6 +555,11 @@ def get_all_parameters( for key in ["radius", "length", "axial_resistivity", "capacitance"]: params[key] = self.jaxnodes[key] + for key in self.diffusion_states: + params[f"axial_resistivity_{key}"] = self.jaxnodes[ + f"axial_resistivity_{key}" + ] + for channel in self.channels: for channel_params in channel.channel_params: params[channel_params] = self.jaxnodes[channel_params] @@ -952,25 +964,32 @@ def step( cm = params["capacitance"] # Abbreviation. # Arguments used by all solvers. - solver_kwargs = { - "voltages": voltages, - "voltage_terms": (v_terms + syn_v_terms) / cm, - "constant_terms": (const_terms + i_ext + syn_const_terms) / cm, - "axial_conductances": params["axial_conductances"], - "internal_node_inds": self._internal_node_inds, + state_vals = { + "voltages": jnp.stack([voltages, u["CaCon_i"]]), + "voltage_terms": jnp.stack( + [(v_terms + syn_v_terms) / cm, jnp.zeros_like(v_terms)] + ), + "constant_terms": jnp.stack( + [ + (const_terms + i_ext + syn_const_terms) / cm, + jnp.zeros_like(const_terms), + ] + ), + "axial_conductances": jnp.stack( + [params["axial_conductances"], params["axial_conductances"]] + ), } # Add solver specific arguments. if voltage_solver == "jax.sparse": - solver_kwargs.update( - { - "sinks": np.asarray(self._comp_edges["sink"].to_list()), - "data_inds": self._data_inds, - "indices": self._indices_jax_spsolve, - "indptr": self._indptr_jax_spsolve, - "n_nodes": self._n_nodes, - } - ) + solver_kwargs = { + "data_inds": self._data_inds, + "indices": self._indices_jax_spsolve, + "indptr": self._indptr_jax_spsolve, + "sinks": np.asarray(self._comp_edges["sink"].to_list()), + "n_nodes": self._n_nodes, + "internal_node_inds": self._internal_node_inds, + } # Only for `bwd_euler` and `cranck-nicolson`. step_voltage_implicit = step_voltage_implicit_with_jax_spsolve else: @@ -980,42 +999,49 @@ def step( # Currently, the forward Euler solver also uses this format. However, # this is only for historical reasons and we are planning to change this in # the future. - solver_kwargs.update( - { - "sinks": np.asarray(self._comp_edges["sink"].to_list()), - "sources": np.asarray(self._comp_edges["source"].to_list()), - "types": np.asarray(self._comp_edges["type"].to_list()), - "masked_node_inds": self._remapped_node_indices, - "nseg_per_branch": self.nseg_per_branch, - "nseg": self.nseg, - "par_inds": self.par_inds, - "child_inds": self.child_inds, - "nbranches": self.total_nbranches, - "solver": voltage_solver, - "children_in_level": self.children_in_level, - "parents_in_level": self.parents_in_level, - "root_inds": self.root_inds, - "branchpoint_group_inds": self.branchpoint_group_inds, - "debug_states": self.debug_states, - } - ) + solver_kwargs = { + "internal_node_inds": self._internal_node_inds, + "sinks": np.asarray(self._comp_edges["sink"].to_list()), + "sources": np.asarray(self._comp_edges["source"].to_list()), + "types": np.asarray(self._comp_edges["type"].to_list()), + "masked_node_inds": self._remapped_node_indices, + "nseg_per_branch": self.nseg_per_branch, + "nseg": self.nseg, + "par_inds": self.par_inds, + "child_inds": self.child_inds, + "nbranches": self.total_nbranches, + "solver": voltage_solver, + "children_in_level": self.children_in_level, + "parents_in_level": self.parents_in_level, + "root_inds": self.root_inds, + "branchpoint_group_inds": self.branchpoint_group_inds, + "debug_states": self.debug_states, + } # Only for `bwd_euler` and `cranck-nicolson`. step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve if solver == "bwd_euler": - u["v"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t) + nones = [None] * len(solver_kwargs) + vmapped = vmap(step_voltage_implicit, in_axes=(0, 0, 0, 0, *nones, None)) + updated_states = vmapped( + *state_vals.values(), *solver_kwargs.values(), delta_t + ) + u["v"] = updated_states[0] + u["CaCon_i"] = updated_states[1] elif solver == "crank_nicolson": # Crank-Nicolson advances by half a step of backward and half a step of # forward Euler. half_step_delta_t = delta_t / 2 half_step_voltages = step_voltage_implicit( - **solver_kwargs, delta_t=half_step_delta_t + **state_vals, **solver_kwargs, delta_t=half_step_delta_t ) # The forward Euler step in Crank-Nicolson can be performed easily as # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4. u["v"] = 2 * half_step_voltages - voltages elif solver == "fwd_euler": - u["v"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t) + u["v"] = step_voltage_explicit( + **state_vals, **solver_kwargs, delta_t=delta_t + ) else: raise ValueError( f"You specified `solver={solver}`. The only allowed solvers are " diff --git a/jaxley/solver_voltage.py b/jaxley/solver_voltage.py index 80c1538a..fb014d0a 100644 --- a/jaxley/solver_voltage.py +++ b/jaxley/solver_voltage.py @@ -80,12 +80,12 @@ def step_voltage_implicit_with_jaxley_spsolve( child_inds: jnp.ndarray, nbranches: int, solver: str, - delta_t: float, children_in_level: List[jnp.ndarray], parents_in_level: List[jnp.ndarray], root_inds: jnp.ndarray, branchpoint_group_inds: jnp.ndarray, debug_states, + delta_t: float, ): """Solve one timestep of branched nerve equations with implicit (backward) Euler.""" # Build diagonals. @@ -246,9 +246,9 @@ def step_voltage_implicit_with_jax_spsolve( indices, indptr, sinks, - delta_t, n_nodes, internal_node_inds, + delta_t, ): axial_conductances = delta_t * axial_conductances