Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix all param and state names with channel name #200

Merged
merged 10 commits into from
Dec 18, 2023
11 changes: 2 additions & 9 deletions jaxley/channels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
from jaxley.channels.channel import Channel
from jaxley.channels.hh import HHChannel
from jaxley.channels.pospischil import (
CaLChannelPospi,
CaTChannelPospi,
KChannelPospi,
KmChannelPospi,
Leak,
NaChannelPospi,
)
from jaxley.channels.hh import HH
from jaxley.channels.pospischil import CaL, CaT, K, Km, Leak, Na
1 change: 1 addition & 0 deletions jaxley/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Channel:
channel_states = None

def __init__(self):
self.vmaped_update_states = vmap(self.update_states, in_axes=(0, None, 0, 0))
self.vmapped_compute_current = vmap(
self.compute_current, in_axes=(None, 0, None)
)
Expand Down
34 changes: 17 additions & 17 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,46 +6,46 @@
from jaxley.solver_gate import solve_gate_exponential


class HHChannel(Channel):
class HH(Channel):
"""Hodgkin-Huxley channel."""

channel_params = {
"gNa": 0.12,
"gK": 0.036,
"gLeak": 0.0003,
"eNa": 50.0,
"eK": -77.0,
"eLeak": -54.3,
"HH_gNa": 0.12,
"HH_gK": 0.036,
"HH_gLeak": 0.0003,
"HH_eNa": 50.0,
"HH_eK": -77.0,
"HH_eLeak": -54.3,
}
channel_states = {"m": 0.2, "h": 0.2, "n": 0.2}
channel_states = {"HH_m": 0.2, "HH_h": 0.2, "HH_n": 0.2}

@staticmethod
def update_states(
u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Return updated HH channel state."""
ms, hs, ns = u["m"], u["h"], u["n"]
ms, hs, ns = u["HH_m"], u["HH_h"], u["HH_n"]
new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages))
new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages))
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages))
return {"m": new_m, "h": new_h, "n": new_n}
return {"HH_m": new_m, "HH_h": new_h, "HH_n": new_n}

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current through HH channels."""
ms, hs, ns = u["m"], u["h"], u["n"]
ms, hs, ns = u["HH_m"], u["HH_h"], u["HH_n"]

# Multiply with 1000 to convert Siemens to milli Siemens.
na_conds = params["gNa"] * (ms**3) * hs * 1000 # mS/cm^2
kd_conds = params["gK"] * ns**4 * 1000 # mS/cm^2
leak_conds = params["gLeak"] * 1000 # mS/cm^2
na_conds = params["HH_gNa"] * (ms**3) * hs * 1000 # mS/cm^2
kd_conds = params["HH_gK"] * ns**4 * 1000 # mS/cm^2
leak_conds = params["HH_gLeak"] * 1000 # mS/cm^2

return (
na_conds * (voltages - params["eNa"])
+ kd_conds * (voltages - params["eK"])
+ leak_conds * (voltages - params["eLeak"])
na_conds * (voltages - params["HH_eNa"])
+ kd_conds * (voltages - params["HH_eK"])
+ leak_conds * (voltages - params["HH_eLeak"])
)


Expand Down
151 changes: 82 additions & 69 deletions jaxley/channels/pospischil.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class Leak(Channel):
"""Leak current"""

channel_params = {
"gl": 1e-4,
"el": -70.0,
"Leak_gl": 1e-4,
"Leak_el": -70.0,
}
channel_states = {}

Expand All @@ -46,37 +46,41 @@ def compute_current(
):
"""Return current."""
# Multiply with 1000 to convert Siemens to milli Siemens.
leak_conds = params["gl"] * 1000 # mS/cm^2
return leak_conds * (voltages - params["el"])
leak_conds = params["Leak_gl"] * 1000 # mS/cm^2
return leak_conds * (voltages - params["Leak_el"])


class NaChannelPospi(Channel):
class Na(Channel):
"""Sodium channel"""

channel_params = {"gNa": 50e-3, "eNa": 50.0, "vt": -60.0}
channel_states = {"m": 0.2, "h": 0.2}
channel_params = {
"Na_gNa": 50e-3,
"Na_eNa": 50.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
}
channel_states = {"Na_m": 0.2, "Na_h": 0.2}

@staticmethod
def update_states(
u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Update state."""
ms, hs = u["m"], u["h"]
ms, hs = u["Na_m"], u["Na_h"]
new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages, params["vt"]))
new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages, params["vt"]))
return {"m": new_m, "h": new_h}
return {"Na_m": new_m, "Na_h": new_h}

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current."""
ms, hs = u["m"], u["h"]
ms, hs = u["Na_m"], u["Na_h"]

# Multiply with 1000 to convert Siemens to milli Siemens.
na_conds = params["gNa"] * (ms**3) * hs * 1000 # mS/cm^2
na_conds = params["Na_gNa"] * (ms**3) * hs * 1000 # mS/cm^2

current = na_conds * (voltages - params["eNa"])
current = na_conds * (voltages - params["Na_eNa"])
return current


Expand All @@ -98,34 +102,36 @@ def _h_gate(v, vt):
return alpha, beta


class KChannelPospi(Channel):
class K(Channel):
"""Potassium channel"""

# KChannelPospi.vt_ should be set to the same value as NaChannelPospi.vt
# if the Na channel is also present
channel_params = {"gK": 5e-3, "eK": -90.0, "vt_": -60.0}
channel_states = {"n": 0.2}
channel_params = {
"K_gK": 5e-3,
"K_eK": -90.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
}
channel_states = {"K_n": 0.2}

@staticmethod
def update_states(
u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Update state."""
ns = u["n"]
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages, params["vt_"]))
return {"n": new_n}
ns = u["K_n"]
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages, params["vt"]))
return {"K_n": new_n}

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current."""
ns = u["n"]
ns = u["K_n"]

# Multiply with 1000 to convert Siemens to milli Siemens.
k_conds = params["gK"] * (ns**4) * 1000 # mS/cm^2
k_conds = params["K_gK"] * (ns**4) * 1000 # mS/cm^2

return k_conds * (voltages - params["eK"])
return k_conds * (voltages - params["K_eK"])


def _n_gate(v, vt):
Expand All @@ -137,33 +143,37 @@ def _n_gate(v, vt):
return alpha, beta


class KmChannelPospi(Channel):
class Km(Channel):
"""Slow M Potassium channel"""

# `eM` is the reversal potential of K, should be set to eK if another K channel is
# present
channel_params = {"gM": 0.004e-3, "taumax": 4000.0, "eM": -90.0} # ms
channel_states = {"p": 0.2}
channel_params = {
"Km_gM": 0.004e-3,
"Km_taumax": 4000.0,
"eM": -90.0, # Global parameter, not prefixed with `Km`.
}
channel_states = {"Km_p": 0.2}

@staticmethod
def update_states(
u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Update state."""
ps = u["p"]
new_p = solve_inf_gate_exponential(ps, dt, *_p_gate(voltages, params["taumax"]))
return {"p": new_p}
ps = u["Km_p"]
new_p = solve_inf_gate_exponential(
ps, dt, *_p_gate(voltages, params["Km_taumax"])
)
return {"Km_p": new_p}

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current."""
ps = u["p"]
ps = u["Km_p"]

# Multiply with 1000 to convert Siemens to milli Siemens.
m_conds = params["gM"] * ps * 1000 # mS/cm^2
return m_conds * (voltages - params["eM"])
m_conds = params["Km_gM"] * ps * 1000 # mS/cm^2
return m_conds * (voltages - params["Km_eM"])


def _p_gate(v, taumax):
Expand All @@ -175,73 +185,74 @@ def _p_gate(v, taumax):
return p_inf, tau_p


class NaKChannelsPospi(Channel):
class NaK(Channel):
"""Sodium and Potassium channel"""

channel_params = {
"gNa": 0.05,
"eNa": 50.0,
"gK": 0.005,
"eK": -90.0,
"vt": -60,
"NaK_gNa": 0.05,
"NaK_eNa": 50.0,
"NaK_gK": 0.005,
"NaK_eK": -90.0,
"vt": -60, # Global parameter, not prefixed with `NaK`.
}

channel_states = {"m": 0.2, "h": 0.2, "n": 0.2}
channel_states = {"NaK_m": 0.2, "NaK_h": 0.2, "NaK_n": 0.2}

@staticmethod
def update_states(
u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Update state."""
ms, hs, ns = u["m"], u["h"], u["n"]
ms, hs, ns = u["NaK_m"], u["NaK_h"], u["NaK_n"]
new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages, params["vt"]))
new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages, params["vt"]))
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages, params["vt"]))
return {"m": new_m, "h": new_h, "n": new_n}
return {"NaK_m": new_m, "NaK_h": new_h, "NaK_n": new_n}

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current."""
ms, hs, ns = u["m"], u["h"], u["n"]
ms, hs, ns = u["NaK_m"], u["NaK_h"], u["NaK_n"]

# Multiply with 1000 to convert Siemens to milli Siemens.
na_conds = params["gNa"] * (ms**3) * hs * 1000 # mS/cm^2
k_conds = params["gK"] * (ns**4) * 1000 # mS/cm^2
na_conds = params["NaK_gNa"] * (ms**3) * hs * 1000 # mS/cm^2
k_conds = params["NaK_gK"] * (ns**4) * 1000 # mS/cm^2

return na_conds * (voltages - params["eNa"]) + k_conds * (
voltages - params["eK"]
return na_conds * (voltages - params["NaK_eNa"]) + k_conds * (
voltages - params["NaK_eK"]
)


class CaLChannelPospi(Channel):
class CaL(Channel):
"""L-type Calcium channel"""

# `eCa` is the reversal potential of Ca, should be set to eCa if another Ca channel
# is present
channel_params = {"gCaL": 0.1e-3, "eCa": 120.0} # S/cm^2
channel_states = {"q": 0.2, "r": 0.2}
channel_params = {
"CaL_gCaL": 0.1e-3,
"eCa": 120.0, # Global parameter, not prefixed with `CaL`.
}
channel_states = {"CaL_q": 0.2, "CaL_r": 0.2}

@staticmethod
def update_states(
u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Update state."""
qs, rs = u["q"], u["r"]
qs, rs = u["CaL_q"], u["CaL_r"]
new_q = solve_gate_exponential(qs, dt, *_q_gate(voltages))
new_r = solve_gate_exponential(rs, dt, *_r_gate(voltages))
return {"q": new_q, "r": new_r}
return {"CaL_q": new_q, "CaL_r": new_r}

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current."""
qs, rs = u["q"], u["r"]
qs, rs = u["CaL_q"], u["CaL_r"]

# Multiply with 1000 to convert Siemens to milli Siemens.
ca_conds = params["gCaL"] * (qs**2) * rs * 1000 # mS/cm^2
ca_conds = params["CaL_gCaL"] * (qs**2) * rs * 1000 # mS/cm^2

return ca_conds * (voltages - params["eCa"])

Expand All @@ -264,35 +275,37 @@ def _r_gate(v):
return alpha, beta


class CaTChannelPospi(Channel):
class CaT(Channel):
"""T-type Calcium channel"""

channel_params = {"gCaT": 0.4e-4, "eCa_": 120.0, "vx": 2.0} # S/cm^2
# eCa_ is the reversal potential of Ca, should be set to eCa if another Ca channel is present

channel_states = {"u": 0.2}
channel_params = {
"CaT_gCaT": 0.4e-4,
"CaT_vx": 2.0,
"eCa": 120.0, # Global parameter, not prefixed with `CaT`.
}
channel_states = {"CaT_u": 0.2}

@staticmethod
def update_states(
u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Update state."""
us = u["u"]
new_u = solve_inf_gate_exponential(us, dt, *_u_gate(voltages, params["vx"]))
return {"u": new_u}
us = u["CaT_u"]
new_u = solve_inf_gate_exponential(us, dt, *_u_gate(voltages, params["CaT_vx"]))
return {"CaT_u": new_u}

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current."""
us = u["u"]
s_inf = 1.0 / (1.0 + jnp.exp(-(voltages + params["vx"] + 57.0) / 6.2))
us = u["CaT_u"]
s_inf = 1.0 / (1.0 + jnp.exp(-(voltages + params["CaT_vx"] + 57.0) / 6.2))

# Multiply with 1000 to convert Siemens to milli Siemens.
ca_conds = params["gCaT"] * (s_inf**2) * us * 1000 # mS/cm^2
ca_conds = params["CaT_gCaT"] * (s_inf**2) * us * 1000 # mS/cm^2

return ca_conds * (voltages - params["eCa_"])
return ca_conds * (voltages - params["eCa"])


def _u_gate(v, vx):
Expand Down
Loading