Skip to content

Commit

Permalink
Merge pull request #222 from huangziwei/channel_params_prefix
Browse files Browse the repository at this point in the history
Channel name prefix to params and states
  • Loading branch information
huangziwei authored Jan 18, 2024
2 parents 5e7d1e6 + 562a65b commit a17e686
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 133 deletions.
37 changes: 32 additions & 5 deletions jaxley/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,50 @@


class Channel:
_name = None
channel_params = None
channel_states = None

def __init__(self):
def __init__(self, name: Optional[str] = None):
self._name = name if name else self.__class__.__name__
self.vmaped_update_states = vmap(self.update_states, in_axes=(0, None, 0, 0))
self.vmapped_compute_current = vmap(
self.compute_current, in_axes=(None, 0, None)
)

@staticmethod
@property
def name(self) -> Optional[str]:
return self._name

def change_name(self, new_name: str):
old_prefix = self._name + "_"
new_prefix = new_name + "_"

self._name = new_name
self.channel_params = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.channel_params.items()
}

self.channel_states = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.channel_states.items()
}

def update_states(
u, dt, voltages, params
self, u, dt, voltages, params
) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
pass

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
self, u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
pass
53 changes: 30 additions & 23 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,50 @@
class HH(Channel):
"""Hodgkin-Huxley channel."""

channel_params = {
"HH_gNa": 0.12,
"HH_gK": 0.036,
"HH_gLeak": 0.0003,
"HH_eNa": 50.0,
"HH_eK": -77.0,
"HH_eLeak": -54.3,
}
channel_states = {"HH_m": 0.2, "HH_h": 0.2, "HH_n": 0.2}

@staticmethod
def __init__(self, name: Optional[str] = None):
super().__init__(name)
prefix = self._name
self.channel_params = {
f"{prefix}_gNa": 0.12,
f"{prefix}_gK": 0.036,
f"{prefix}_gLeak": 0.0003,
f"{prefix}_eNa": 50.0,
f"{prefix}_eK": -77.0,
f"{prefix}_eLeak": -54.3,
}
self.channel_states = {
f"{prefix}_m": 0.2,
f"{prefix}_h": 0.2,
f"{prefix}_n": 0.2,
}

def update_states(
u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
self, u: Dict[str, jnp.ndarray], dt, voltages, params: Dict[str, jnp.ndarray]
):
"""Return updated HH channel state."""
ms, hs, ns = u["HH_m"], u["HH_h"], u["HH_n"]
prefix = self._name
ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"]
new_m = solve_gate_exponential(ms, dt, *_m_gate(voltages))
new_h = solve_gate_exponential(hs, dt, *_h_gate(voltages))
new_n = solve_gate_exponential(ns, dt, *_n_gate(voltages))
return {"HH_m": new_m, "HH_h": new_h, "HH_n": new_n}
return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n}

@staticmethod
def compute_current(
u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
self, u: Dict[str, jnp.ndarray], voltages, params: Dict[str, jnp.ndarray]
):
"""Return current through HH channels."""
ms, hs, ns = u["HH_m"], u["HH_h"], u["HH_n"]
prefix = self._name
ms, hs, ns = u[f"{prefix}_m"], u[f"{prefix}_h"], u[f"{prefix}_n"]

# Multiply with 1000 to convert Siemens to milli Siemens.
na_conds = params["HH_gNa"] * (ms**3) * hs * 1000 # mS/cm^2
kd_conds = params["HH_gK"] * ns**4 * 1000 # mS/cm^2
leak_conds = params["HH_gLeak"] * 1000 # mS/cm^2
na_conds = params[f"{prefix}_gNa"] * (ms**3) * hs * 1000 # mS/cm^2
kd_conds = params[f"{prefix}_gK"] * ns**4 * 1000 # mS/cm^2
leak_conds = params[f"{prefix}_gLeak"] * 1000 # mS/cm^2

return (
na_conds * (voltages - params["HH_eNa"])
+ kd_conds * (voltages - params["HH_eK"])
+ leak_conds * (voltages - params["HH_eLeak"])
na_conds * (voltages - params[f"{prefix}_eNa"])
+ kd_conds * (voltages - params[f"{prefix}_eK"])
+ leak_conds * (voltages - params[f"{prefix}_eLeak"])
)


Expand Down
Loading

0 comments on commit a17e686

Please sign in to comment.