Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make all gates staticmethods of the channel classes #231

Merged
merged 5 commits into from
Feb 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 31 additions & 19 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def update_states(
"""Return updated HH channel state."""
prefix = self._name
ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"]
new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages))
new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages))
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages))
new_m = solve_gate_exponential(ms, dt, *self.m_gate(voltages))
new_h = solve_gate_exponential(hs, dt, *self.h_gate(voltages))
new_n = solve_gate_exponential(ns, dt, *self.n_gate(voltages))
return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n}

def compute_current(
Expand All @@ -55,23 +55,35 @@ def compute_current(
+ leak_conds * (voltages - params[f"{prefix}_eLeak"])
)

def init_state(self, voltages, params):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_m, beta_m = self.m_gate(voltages)
alpha_h, beta_h = self.h_gate(voltages)
alpha_n, beta_n = self.n_gate(voltages)
return {
f"{prefix}_m": alpha_m / (alpha_m + beta_m),
f"{prefix}_h": alpha_h / (alpha_h + beta_h),
f"{prefix}_n": alpha_n / (alpha_n + beta_n),
}

def _m_gate(v):
alpha = 0.1 * _vtrap(-(v + 40), 10)
beta = 4.0 * jnp.exp(-(v + 65) / 18)
return alpha, beta


def _h_gate(v):
alpha = 0.07 * jnp.exp(-(v + 65) / 20)
beta = 1.0 / (jnp.exp(-(v + 35) / 10) + 1)
return alpha, beta


def _n_gate(v):
alpha = 0.01 * _vtrap(-(v + 55), 10)
beta = 0.125 * jnp.exp(-(v + 65) / 80)
return alpha, beta
@staticmethod
def m_gate(v):
alpha = 0.1 * _vtrap(-(v + 40), 10)
beta = 4.0 * jnp.exp(-(v + 65) / 18)
return alpha, beta

@staticmethod
def h_gate(v):
alpha = 0.07 * jnp.exp(-(v + 65) / 20)
beta = 1.0 / (jnp.exp(-(v + 35) / 10) + 1)
return alpha, beta

@staticmethod
def n_gate(v):
alpha = 0.01 * _vtrap(-(v + 55), 10)
beta = 0.125 * jnp.exp(-(v + 65) / 80)
return alpha, beta


def _vtrap(x, y):
Expand Down
188 changes: 91 additions & 97 deletions jaxley/channels/pospischil.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def compute_current(
leak_conds = params[f"{prefix}_gl"] * 1000 # mS/cm^2
return leak_conds * (voltages - params[f"{prefix}_el"])

def init_state(self, voltages, params):
return {}


class Na(Channel):
"""Sodium channel"""
Expand All @@ -71,8 +74,8 @@ def update_states(
"""Update state."""
prefix = self._name
ms, hs = u[f"{prefix}_m"], u[f"{prefix}_h"]
new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages, params["vt"]))
new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages, params["vt"]))
new_m = solve_gate_exponential(ms, dt, *self.m_gate(voltages, params["vt"]))
new_h = solve_gate_exponential(hs, dt, *self.h_gate(voltages, params["vt"]))
return {f"{prefix}_m": new_m, f"{prefix}_h": new_h}

def compute_current(
Expand All @@ -88,23 +91,33 @@ def compute_current(
current = na_conds * (voltages - params[f"{prefix}_eNa"])
return current

def init_state(self, voltages, params):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_m, beta_m = self.m_gate(voltages, params["vt"])
alpha_h, beta_h = self.h_gate(voltages, params["vt"])
return {
f"{prefix}_m": alpha_m / (alpha_m + beta_m),
f"{prefix}_h": alpha_h / (alpha_h + beta_h),
}

def _m_gate(v, vt):
v_alpha = v - vt - 13.0
alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25

v_beta = v - vt - 40.0
beta = 0.28 * efun(0.2 * v_beta) / 0.2
return alpha, beta
@staticmethod
def m_gate(v, vt):
v_alpha = v - vt - 13.0
alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25

v_beta = v - vt - 40.0
beta = 0.28 * efun(0.2 * v_beta) / 0.2
return alpha, beta

def _h_gate(v, vt):
v_alpha = v - vt - 17.0
alpha = 0.128 * jnp.exp(-v_alpha / 18.0)
@staticmethod
def h_gate(v, vt):
v_alpha = v - vt - 17.0
alpha = 0.128 * jnp.exp(-v_alpha / 18.0)

v_beta = v - vt - 40.0
beta = 4.0 / (jnp.exp(-v_beta / 5.0) + 1.0)
return alpha, beta
v_beta = v - vt - 40.0
beta = 4.0 / (jnp.exp(-v_beta / 5.0) + 1.0)
return alpha, beta


class K(Channel):
Expand All @@ -126,7 +139,7 @@ def update_states(
"""Update state."""
prefix = self._name
ns = u[f"{prefix}_n"]
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages, params["vt"]))
new_n = solve_gate_exponential(ns, dt, *self.n_gate(voltages, params["vt"]))
return {f"{prefix}_n": new_n}

def compute_current(
Expand All @@ -141,14 +154,20 @@ def compute_current(

return k_conds * (voltages - params[f"{prefix}_eK"])

def init_state(self, voltages, params):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_n, beta_n = self.n_gate(voltages, params["vt"])
return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)}

def _n_gate(v, vt):
v_alpha = v - vt - 15.0
alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2
@staticmethod
def n_gate(v, vt):
v_alpha = v - vt - 15.0
alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2

v_beta = v - vt - 10.0
beta = 0.5 * jnp.exp(-v_beta / 40.0)
return alpha, beta
v_beta = v - vt - 10.0
beta = 0.5 * jnp.exp(-v_beta / 40.0)
return alpha, beta


class Km(Channel):
Expand All @@ -171,7 +190,7 @@ def update_states(
prefix = self._name
ps = u[f"{prefix}_p"]
new_p = solve_inf_gate_exponential(
ps, dt, *_p_gate(voltages, params[f"{prefix}_taumax"])
ps, dt, *self.p_gate(voltages, params[f"{prefix}_taumax"])
)
return {f"{prefix}_p": new_p}

Expand All @@ -186,61 +205,20 @@ def compute_current(
m_conds = params[f"{prefix}_gM"] * ps * 1000 # mS/cm^2
return m_conds * (voltages - params["eM"])


def _p_gate(v, taumax):
v_p = v + 35.0
p_inf = 1.0 / (1.0 + jnp.exp(-0.1 * v_p))

tau_p = taumax / (3.3 * jnp.exp(0.05 * v_p) + jnp.exp(-0.05 * v_p))

return p_inf, tau_p


class NaK(Channel):
"""Sodium and Potassium channel"""

def __init__(self, name: Optional[str] = None):
super().__init__(name)
prefix = self._name
self.channel_params = {
f"{prefix}_gNa": 0.05,
f"{prefix}_eNa": 50.0,
f"{prefix}_gK": 0.005,
f"{prefix}_eK": -90.0,
"vt": -60, # Global parameter, not prefixed with `NaK`.
}
self.channel_states = {
f"{prefix}_m": 0.2,
f"{prefix}_h": 0.2,
f"{prefix}_n": 0.2,
}

def update_states(
self, u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Update state."""

def init_state(self, voltages, params):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"]
new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages, params["vt"]))
new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages, params["vt"]))
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages, params["vt"]))
return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n}
alpha_p, beta_p = self.p_gate(voltages, params[f"{prefix}_taumax"])
return {f"{prefix}_p": alpha_p / (alpha_p + beta_p)}

def compute_current(
self, u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current."""
prefix = self._name
ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"]
@staticmethod
def p_gate(v, taumax):
v_p = v + 35.0
p_inf = 1.0 / (1.0 + jnp.exp(-0.1 * v_p))

# Multiply with 1000 to convert Siemens to milli Siemens.
na_conds = params[f"{prefix}_gNa"] * (ms**3) * hs * 1000 # mS/cm^2
k_conds = params[f"{prefix}_gK"] * (ns**4) * 1000 # mS/cm^2
tau_p = taumax / (3.3 * jnp.exp(0.05 * v_p) + jnp.exp(-0.05 * v_p))

return na_conds * (voltages - params[f"{prefix}_eNa"]) + k_conds * (
voltages - params[f"{prefix}_eK"]
)
return p_inf, tau_p


class CaL(Channel):
Expand All @@ -261,8 +239,8 @@ def update_states(
"""Update state."""
prefix = self._name
qs, rs = u[f"{prefix}_q"], u[f"{prefix}_r"]
new_q = solve_gate_exponential(qs, dt, *_q_gate(voltages))
new_r = solve_gate_exponential(rs, dt, *_r_gate(voltages))
new_q = solve_gate_exponential(qs, dt, *self.q_gate(voltages))
new_r = solve_gate_exponential(rs, dt, *self.r_gate(voltages))
return {f"{prefix}_q": new_q, f"{prefix}_r": new_r}

def compute_current(
Expand All @@ -277,23 +255,33 @@ def compute_current(

return ca_conds * (voltages - params["eCa"])

def init_state(self, voltages, params):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_q, beta_q = self.q_gate(voltages)
alpha_r, beta_r = self.r_gate(voltages)
return {
f"{prefix}_q": alpha_q / (alpha_q + beta_q),
f"{prefix}_r": alpha_r / (alpha_r + beta_r),
}

def _q_gate(v):
v_alpha = -v - 27.0
alpha = 0.055 * efun(v_alpha / 3.8) * 3.8

v_beta = -v - 75.0
beta = 0.94 * jnp.exp(v_beta / 17.0)
return alpha, beta
@staticmethod
def q_gate(v):
v_alpha = -v - 27.0
alpha = 0.055 * efun(v_alpha / 3.8) * 3.8

v_beta = -v - 75.0
beta = 0.94 * jnp.exp(v_beta / 17.0)
return alpha, beta

def _r_gate(v):
v_alpha = -v - 13.0
alpha = 0.000457 * jnp.exp(v_alpha / 50)
@staticmethod
def r_gate(v):
v_alpha = -v - 13.0
alpha = 0.000457 * jnp.exp(v_alpha / 50)

v_beta = -v - 15.0
beta = 0.0065 / (jnp.exp(v_beta / 28.0) + 1)
return alpha, beta
v_beta = -v - 15.0
beta = 0.0065 / (jnp.exp(v_beta / 28.0) + 1)
return alpha, beta


class CaT(Channel):
Expand All @@ -316,7 +304,7 @@ def update_states(
prefix = self._name
us = u[f"{prefix}_u"]
new_u = solve_inf_gate_exponential(
us, dt, *_u_gate(voltages, params[f"{prefix}_vx"])
us, dt, *self.u_gate(voltages, params[f"{prefix}_vx"])
)
return {f"{prefix}_u": new_u}

Expand All @@ -333,13 +321,19 @@ def compute_current(

return ca_conds * (voltages - params["eCa"])

def init_state(self, voltages, params):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_u, beta_u = self.u_gate(voltages, params[f"{prefix}_vx"])
return {f"{prefix}_u": alpha_u / (alpha_u + beta_u)}

def _u_gate(v, vx):
v_u1 = v + vx + 81.0
u_inf = 1.0 / (1.0 + jnp.exp(v_u1 / 4))
@staticmethod
def u_gate(v, vx):
v_u1 = v + vx + 81.0
u_inf = 1.0 / (1.0 + jnp.exp(v_u1 / 4))

tau_u = (30.8 + (211.4 + jnp.exp((v + vx + 113.2) / 5.0))) / (
3.7 * (1 + jnp.exp((v + vx + 84.0) / 3.2))
)
tau_u = (30.8 + (211.4 + jnp.exp((v + vx + 113.2) / 5.0))) / (
3.7 * (1 + jnp.exp((v + vx + 84.0) / 3.2))
)

return u_inf, tau_u
return u_inf, tau_u
22 changes: 22 additions & 0 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,28 @@ def initialize(self):
self.init_morph()
return self

def init_states(self) -> None:
"""Initialize all mechanisms in their steady state.

This considers the voltages and parameters of each compartment."""
# Update states of the channels.
channel_nodes = self.nodes

for channel in self.channels:
name = channel._name
indices = channel_nodes.loc[channel_nodes[name]]["comp_index"].to_numpy()
voltages = channel_nodes.loc[indices, "voltages"].to_numpy()

channel_param_names = list(channel.channel_params.keys())
channel_params = {}
for p in channel_param_names:
channel_params[p] = channel_nodes[p][indices].to_numpy()

init_state = channel.init_state(voltages, channel_params)

for key, val in init_state.items():
self.nodes.loc[indices, key] = val

def record(self, state: str = "voltages"):
"""Insert a recording into the compartment."""
view = deepcopy(self.nodes)
Expand Down
Loading
Loading