Skip to content

Commit

Permalink
Test for channel init
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Feb 3, 2024
1 parent 59952f5 commit 381ce4a
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 21 deletions.
50 changes: 31 additions & 19 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def update_states(
"""Return updated HH channel state."""
prefix = self._name
ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"]
new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages))
new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages))
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages))
new_m = solve_gate_exponential(ms, dt, *self.m_gate(voltages))
new_h = solve_gate_exponential(hs, dt, *self.h_gate(voltages))
new_n = solve_gate_exponential(ns, dt, *self.n_gate(voltages))
return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n}

def compute_current(
Expand All @@ -55,23 +55,35 @@ def compute_current(
+ leak_conds * (voltages - params[f"{prefix}_eLeak"])
)

def init_state(self, voltages, params):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_m, beta_m = self.m_gate(voltages)
alpha_h, beta_h = self.h_gate(voltages)
alpha_n, beta_n = self.n_gate(voltages)
return {
f"{prefix}_m": alpha_m / (alpha_m + beta_m),
f"{prefix}_h": alpha_h / (alpha_h + beta_h),
f"{prefix}_n": alpha_n / (alpha_n + beta_n),
}

def _m_gate(v):
alpha = 0.1 * _vtrap(-(v + 40), 10)
beta = 4.0 * jnp.exp(-(v + 65) / 18)
return alpha, beta


def _h_gate(v):
alpha = 0.07 * jnp.exp(-(v + 65) / 20)
beta = 1.0 / (jnp.exp(-(v + 35) / 10) + 1)
return alpha, beta


def _n_gate(v):
alpha = 0.01 * _vtrap(-(v + 55), 10)
beta = 0.125 * jnp.exp(-(v + 65) / 80)
return alpha, beta
@staticmethod
def m_gate(v):
alpha = 0.1 * _vtrap(-(v + 40), 10)
beta = 4.0 * jnp.exp(-(v + 65) / 18)
return alpha, beta

@staticmethod
def h_gate(v):
alpha = 0.07 * jnp.exp(-(v + 65) / 20)
beta = 1.0 / (jnp.exp(-(v + 35) / 10) + 1)
return alpha, beta

@staticmethod
def n_gate(v):
alpha = 0.01 * _vtrap(-(v + 55), 10)
beta = 0.125 * jnp.exp(-(v + 65) / 80)
return alpha, beta


def _vtrap(x, y):
Expand Down
2 changes: 1 addition & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def initialize(self):

def init_states(self) -> None:
"""Initialize all mechanisms in their steady state.
This considers the voltages and parameters of each compartment."""
# Update states of the channels.
channel_nodes = self.nodes
Expand Down
34 changes: 33 additions & 1 deletion tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import jaxley as jx
from jaxley.channels import HH, K, Na
from jaxley.channels import HH, CaL, CaT, K, Km, Leak, Na


def test_channel_set_name():
Expand Down Expand Up @@ -64,3 +64,35 @@ def test_integration_with_renamed_channels():

# Test if voltage is `NaN` which happens when channels get mixed up.
assert np.invert(np.any(np.isnan(v)))


def test_init_states():
"""Functional test for `init_states()`.
Checks whether, if everything is initialized in its steady state, the voltage
after 10ms is almost exactly the same as after 0ms.
"""
comp = jx.Compartment()
branch = jx.Branch(comp, 4)
cell = jx.Cell(branch, [-1, 0])
cell.branch(0).comp(0.0).record()

cell.branch(0).insert(Na())
cell.branch(1).insert(K())
cell.branch(1).comp(0.0).insert(Km())
cell.branch(0).comp(1.0).insert(CaT())
cell.insert(CaL())
cell.insert(Leak())

cell.insert(HH())

cell.set("voltages", -62.0) # At -70.0 there is a rebound spike.
cell.init_states()
v = jx.integrate(cell, t_max=20.0)

last_voltage = v[0, -1]
cell.set("voltages", last_voltage)
cell.init_states()

v = jx.integrate(cell, t_max=10.0)
assert np.abs(v[0, 0] - v[0, -1]) < 0.02

0 comments on commit 381ce4a

Please sign in to comment.