diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index b47b9723..e3081d4b 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -32,9 +32,9 @@ def update_states( """Return updated HH channel state.""" prefix = self._name ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"] - new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages)) - new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages)) - new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages)) + new_m = solve_gate_exponential(ms, dt, *self.m_gate(voltages)) + new_h = solve_gate_exponential(hs, dt, *self.h_gate(voltages)) + new_n = solve_gate_exponential(ns, dt, *self.n_gate(voltages)) return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n} def compute_current( @@ -55,23 +55,35 @@ def compute_current( + leak_conds * (voltages - params[f"{prefix}_eLeak"]) ) + def init_state(self, voltages, params): + """Initialize the state such at fixed point of gate dynamics.""" + prefix = self._name + alpha_m, beta_m = self.m_gate(voltages) + alpha_h, beta_h = self.h_gate(voltages) + alpha_n, beta_n = self.n_gate(voltages) + return { + f"{prefix}_m": alpha_m / (alpha_m + beta_m), + f"{prefix}_h": alpha_h / (alpha_h + beta_h), + f"{prefix}_n": alpha_n / (alpha_n + beta_n), + } -def _m_gate(v): - alpha = 0.1 * _vtrap(-(v + 40), 10) - beta = 4.0 * jnp.exp(-(v + 65) / 18) - return alpha, beta - - -def _h_gate(v): - alpha = 0.07 * jnp.exp(-(v + 65) / 20) - beta = 1.0 / (jnp.exp(-(v + 35) / 10) + 1) - return alpha, beta - - -def _n_gate(v): - alpha = 0.01 * _vtrap(-(v + 55), 10) - beta = 0.125 * jnp.exp(-(v + 65) / 80) - return alpha, beta + @staticmethod + def m_gate(v): + alpha = 0.1 * _vtrap(-(v + 40), 10) + beta = 4.0 * jnp.exp(-(v + 65) / 18) + return alpha, beta + + @staticmethod + def h_gate(v): + alpha = 0.07 * jnp.exp(-(v + 65) / 20) + beta = 1.0 / (jnp.exp(-(v + 35) / 10) + 1) + return alpha, beta + + @staticmethod + def n_gate(v): + alpha = 0.01 * _vtrap(-(v + 55), 10) + beta = 0.125 * jnp.exp(-(v + 65) / 80) + return alpha, beta def _vtrap(x, y): diff --git a/jaxley/channels/pospischil.py b/jaxley/channels/pospischil.py index a35cdb32..67e54cbe 100644 --- a/jaxley/channels/pospischil.py +++ b/jaxley/channels/pospischil.py @@ -51,6 +51,9 @@ def compute_current( leak_conds = params[f"{prefix}_gl"] * 1000 # mS/cm^2 return leak_conds * (voltages - params[f"{prefix}_el"]) + def init_state(self, voltages, params): + return {} + class Na(Channel): """Sodium channel""" @@ -71,8 +74,8 @@ def update_states( """Update state.""" prefix = self._name ms, hs = u[f"{prefix}_m"], u[f"{prefix}_h"] - new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages, params["vt"])) - new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages, params["vt"])) + new_m = solve_gate_exponential(ms, dt, *self.m_gate(voltages, params["vt"])) + new_h = solve_gate_exponential(hs, dt, *self.h_gate(voltages, params["vt"])) return {f"{prefix}_m": new_m, f"{prefix}_h": new_h} def compute_current( @@ -88,23 +91,33 @@ def compute_current( current = na_conds * (voltages - params[f"{prefix}_eNa"]) return current + def init_state(self, voltages, params): + """Initialize the state such at fixed point of gate dynamics.""" + prefix = self._name + alpha_m, beta_m = self.m_gate(voltages, params["vt"]) + alpha_h, beta_h = self.h_gate(voltages, params["vt"]) + return { + f"{prefix}_m": alpha_m / (alpha_m + beta_m), + f"{prefix}_h": alpha_h / (alpha_h + beta_h), + } -def _m_gate(v, vt): - v_alpha = v - vt - 13.0 - alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25 - - v_beta = v - vt - 40.0 - beta = 0.28 * efun(0.2 * v_beta) / 0.2 - return alpha, beta + @staticmethod + def m_gate(v, vt): + v_alpha = v - vt - 13.0 + alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25 + v_beta = v - vt - 40.0 + beta = 0.28 * efun(0.2 * v_beta) / 0.2 + return alpha, beta -def _h_gate(v, vt): - v_alpha = v - vt - 17.0 - alpha = 0.128 * jnp.exp(-v_alpha / 18.0) + @staticmethod + def h_gate(v, vt): + v_alpha = v - vt - 17.0 + alpha = 0.128 * jnp.exp(-v_alpha / 18.0) - v_beta = v - vt - 40.0 - beta = 4.0 / (jnp.exp(-v_beta / 5.0) + 1.0) - return alpha, beta + v_beta = v - vt - 40.0 + beta = 4.0 / (jnp.exp(-v_beta / 5.0) + 1.0) + return alpha, beta class K(Channel): @@ -126,7 +139,7 @@ def update_states( """Update state.""" prefix = self._name ns = u[f"{prefix}_n"] - new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages, params["vt"])) + new_n = solve_gate_exponential(ns, dt, *self.n_gate(voltages, params["vt"])) return {f"{prefix}_n": new_n} def compute_current( @@ -141,14 +154,20 @@ def compute_current( return k_conds * (voltages - params[f"{prefix}_eK"]) + def init_state(self, voltages, params): + """Initialize the state such at fixed point of gate dynamics.""" + prefix = self._name + alpha_n, beta_n = self.n_gate(voltages, params["vt"]) + return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)} -def _n_gate(v, vt): - v_alpha = v - vt - 15.0 - alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2 + @staticmethod + def n_gate(v, vt): + v_alpha = v - vt - 15.0 + alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2 - v_beta = v - vt - 10.0 - beta = 0.5 * jnp.exp(-v_beta / 40.0) - return alpha, beta + v_beta = v - vt - 10.0 + beta = 0.5 * jnp.exp(-v_beta / 40.0) + return alpha, beta class Km(Channel): @@ -171,7 +190,7 @@ def update_states( prefix = self._name ps = u[f"{prefix}_p"] new_p = solve_inf_gate_exponential( - ps, dt, *_p_gate(voltages, params[f"{prefix}_taumax"]) + ps, dt, *self.p_gate(voltages, params[f"{prefix}_taumax"]) ) return {f"{prefix}_p": new_p} @@ -186,61 +205,20 @@ def compute_current( m_conds = params[f"{prefix}_gM"] * ps * 1000 # mS/cm^2 return m_conds * (voltages - params["eM"]) - -def _p_gate(v, taumax): - v_p = v + 35.0 - p_inf = 1.0 / (1.0 + jnp.exp(-0.1 * v_p)) - - tau_p = taumax / (3.3 * jnp.exp(0.05 * v_p) + jnp.exp(-0.05 * v_p)) - - return p_inf, tau_p - - -class NaK(Channel): - """Sodium and Potassium channel""" - - def __init__(self, name: Optional[str] = None): - super().__init__(name) - prefix = self._name - self.channel_params = { - f"{prefix}_gNa": 0.05, - f"{prefix}_eNa": 50.0, - f"{prefix}_gK": 0.005, - f"{prefix}_eK": -90.0, - "vt": -60, # Global parameter, not prefixed with `NaK`. - } - self.channel_states = { - f"{prefix}_m": 0.2, - f"{prefix}_h": 0.2, - f"{prefix}_n": 0.2, - } - - def update_states( - self, u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray] - ): - """Update state.""" - + def init_state(self, voltages, params): + """Initialize the state such at fixed point of gate dynamics.""" prefix = self._name - ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"] - new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages, params["vt"])) - new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages, params["vt"])) - new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages, params["vt"])) - return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n} + alpha_p, beta_p = self.p_gate(voltages, params[f"{prefix}_taumax"]) + return {f"{prefix}_p": alpha_p / (alpha_p + beta_p)} - def compute_current( - self, u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray] - ): - """Return current.""" - prefix = self._name - ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"] + @staticmethod + def p_gate(v, taumax): + v_p = v + 35.0 + p_inf = 1.0 / (1.0 + jnp.exp(-0.1 * v_p)) - # Multiply with 1000 to convert Siemens to milli Siemens. - na_conds = params[f"{prefix}_gNa"] * (ms**3) * hs * 1000 # mS/cm^2 - k_conds = params[f"{prefix}_gK"] * (ns**4) * 1000 # mS/cm^2 + tau_p = taumax / (3.3 * jnp.exp(0.05 * v_p) + jnp.exp(-0.05 * v_p)) - return na_conds * (voltages - params[f"{prefix}_eNa"]) + k_conds * ( - voltages - params[f"{prefix}_eK"] - ) + return p_inf, tau_p class CaL(Channel): @@ -261,8 +239,8 @@ def update_states( """Update state.""" prefix = self._name qs, rs = u[f"{prefix}_q"], u[f"{prefix}_r"] - new_q = solve_gate_exponential(qs, dt, *_q_gate(voltages)) - new_r = solve_gate_exponential(rs, dt, *_r_gate(voltages)) + new_q = solve_gate_exponential(qs, dt, *self.q_gate(voltages)) + new_r = solve_gate_exponential(rs, dt, *self.r_gate(voltages)) return {f"{prefix}_q": new_q, f"{prefix}_r": new_r} def compute_current( @@ -277,23 +255,33 @@ def compute_current( return ca_conds * (voltages - params["eCa"]) + def init_state(self, voltages, params): + """Initialize the state such at fixed point of gate dynamics.""" + prefix = self._name + alpha_q, beta_q = self.q_gate(voltages) + alpha_r, beta_r = self.r_gate(voltages) + return { + f"{prefix}_q": alpha_q / (alpha_q + beta_q), + f"{prefix}_r": alpha_r / (alpha_r + beta_r), + } -def _q_gate(v): - v_alpha = -v - 27.0 - alpha = 0.055 * efun(v_alpha / 3.8) * 3.8 - - v_beta = -v - 75.0 - beta = 0.94 * jnp.exp(v_beta / 17.0) - return alpha, beta + @staticmethod + def q_gate(v): + v_alpha = -v - 27.0 + alpha = 0.055 * efun(v_alpha / 3.8) * 3.8 + v_beta = -v - 75.0 + beta = 0.94 * jnp.exp(v_beta / 17.0) + return alpha, beta -def _r_gate(v): - v_alpha = -v - 13.0 - alpha = 0.000457 * jnp.exp(v_alpha / 50) + @staticmethod + def r_gate(v): + v_alpha = -v - 13.0 + alpha = 0.000457 * jnp.exp(v_alpha / 50) - v_beta = -v - 15.0 - beta = 0.0065 / (jnp.exp(v_beta / 28.0) + 1) - return alpha, beta + v_beta = -v - 15.0 + beta = 0.0065 / (jnp.exp(v_beta / 28.0) + 1) + return alpha, beta class CaT(Channel): @@ -316,7 +304,7 @@ def update_states( prefix = self._name us = u[f"{prefix}_u"] new_u = solve_inf_gate_exponential( - us, dt, *_u_gate(voltages, params[f"{prefix}_vx"]) + us, dt, *self.u_gate(voltages, params[f"{prefix}_vx"]) ) return {f"{prefix}_u": new_u} @@ -333,13 +321,19 @@ def compute_current( return ca_conds * (voltages - params["eCa"]) + def init_state(self, voltages, params): + """Initialize the state such at fixed point of gate dynamics.""" + prefix = self._name + alpha_u, beta_u = self.u_gate(voltages, params[f"{prefix}_vx"]) + return {f"{prefix}_u": alpha_u / (alpha_u + beta_u)} -def _u_gate(v, vx): - v_u1 = v + vx + 81.0 - u_inf = 1.0 / (1.0 + jnp.exp(v_u1 / 4)) + @staticmethod + def u_gate(v, vx): + v_u1 = v + vx + 81.0 + u_inf = 1.0 / (1.0 + jnp.exp(v_u1 / 4)) - tau_u = (30.8 + (211.4 + jnp.exp((v + vx + 113.2) / 5.0))) / ( - 3.7 * (1 + jnp.exp((v + vx + 84.0) / 3.2)) - ) + tau_u = (30.8 + (211.4 + jnp.exp((v + vx + 113.2) / 5.0))) / ( + 3.7 * (1 + jnp.exp((v + vx + 84.0) / 3.2)) + ) - return u_inf, tau_u + return u_inf, tau_u diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 406f6657..8d0518b5 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -356,6 +356,28 @@ def initialize(self): self.init_morph() return self + def init_states(self) -> None: + """Initialize all mechanisms in their steady state. + + This considers the voltages and parameters of each compartment.""" + # Update states of the channels. + channel_nodes = self.nodes + + for channel in self.channels: + name = channel._name + indices = channel_nodes.loc[channel_nodes[name]]["comp_index"].to_numpy() + voltages = channel_nodes.loc[indices, "voltages"].to_numpy() + + channel_param_names = list(channel.channel_params.keys()) + channel_params = {} + for p in channel_param_names: + channel_params[p] = channel_nodes[p][indices].to_numpy() + + init_state = channel.init_state(voltages, channel_params) + + for key, val in init_state.items(): + self.nodes.loc[indices, key] = val + def record(self, state: str = "voltages"): """Insert a recording into the compartment.""" view = deepcopy(self.nodes) diff --git a/tests/test_channels.py b/tests/test_channels.py index 587ac617..2d79fe38 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -7,7 +7,7 @@ import pytest import jaxley as jx -from jaxley.channels import HH, K, Na +from jaxley.channels import HH, CaL, CaT, K, Km, Leak, Na def test_channel_set_name(): @@ -64,3 +64,35 @@ def test_integration_with_renamed_channels(): # Test if voltage is `NaN` which happens when channels get mixed up. assert np.invert(np.any(np.isnan(v))) + + +def test_init_states(): + """Functional test for `init_states()`. + + Checks whether, if everything is initialized in its steady state, the voltage + after 10ms is almost exactly the same as after 0ms. + """ + comp = jx.Compartment() + branch = jx.Branch(comp, 4) + cell = jx.Cell(branch, [-1, 0]) + cell.branch(0).comp(0.0).record() + + cell.branch(0).insert(Na()) + cell.branch(1).insert(K()) + cell.branch(1).comp(0.0).insert(Km()) + cell.branch(0).comp(1.0).insert(CaT()) + cell.insert(CaL()) + cell.insert(Leak()) + + cell.insert(HH()) + + cell.set("voltages", -62.0) # At -70.0 there is a rebound spike. + cell.init_states() + v = jx.integrate(cell, t_max=20.0) + + last_voltage = v[0, -1] + cell.set("voltages", last_voltage) + cell.init_states() + + v = jx.integrate(cell, t_max=10.0) + assert np.abs(v[0, 0] - v[0, -1]) < 0.02 diff --git a/tutorials/02_setting_parameters.ipynb b/tutorials/02_setting_parameters.ipynb index 9ce45392..926f3847 100644 --- a/tutorials/02_setting_parameters.ipynb +++ b/tutorials/02_setting_parameters.ipynb @@ -613,6 +613,26 @@ "network.show()" ] }, + { + "cell_type": "markdown", + "id": "a06abf2a-66a3-44bd-821d-832acff17d94", + "metadata": {}, + "source": [ + "### Initializing the simulation\n", + "\n", + "Optionally, you can initialize all channels in their steady state (given voltage and parameters of each compartment):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be2d2072-dfd3-4a0a-9bf7-edb4a0120c02", + "metadata": {}, + "outputs": [], + "source": [ + "network.init_states()" + ] + }, { "cell_type": "markdown", "id": "0d8e8c15", @@ -682,7 +702,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.7" } }, "nbformat": 4,