Skip to content

Commit

Permalink
wip: add infer global param method and other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 22, 2024
1 parent a0874b4 commit 5c11a41
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 89 deletions.
148 changes: 61 additions & 87 deletions jaxley/channels/pospischil.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,12 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gLeak": 1e-4,
f"{prefix}_eLeak": -70.0,
"gLeak": 1e-4,
"eLeak": -70.0,
}
self.states = {}
self.current_name = f"i_{prefix}"
# self.current_name = f"i_Leak"

def update_states(
self,
Expand All @@ -61,9 +60,8 @@ def compute_current(
self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
"""Return current."""
prefix = self._name
gLeak = params[f"{prefix}_gLeak"] # S/cm^2
return gLeak * (v - params[f"{prefix}_eLeak"])
gLeak = params["gLeak"] # S/cm^2
return gLeak * (v - params["eLeak"])

def init_state(self, states, v, params, delta_t):
return {}
Expand All @@ -76,14 +74,13 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gNa": 50e-3,
"eNa": 50.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
"gNa": 50e-3,
# "eNa": 50.0,
# "vt": -60.0, # Global parameter, not prefixed with `Na`.
}
self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2}
self.current_name = f"i_Na"
self.states = {"m": 0.2, "h": 0.2}
# self.current_name = f"i_Na"

def update_states(
self,
Expand All @@ -93,32 +90,29 @@ def update_states(
params: Dict[str, jnp.ndarray],
):
"""Update state."""
prefix = self._name
m, h = states[f"{prefix}_m"], states[f"{prefix}_h"]
m, h = states["m"], states["h"]
new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params["vt"]))
new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params["vt"]))
return {f"{prefix}_m": new_m, f"{prefix}_h": new_h}
return {"m": new_m, "h": new_h}

def compute_current(
self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
"""Return current."""
prefix = self._name
m, h = states[f"{prefix}_m"], states[f"{prefix}_h"]
m, h = states["m"], states["h"]

gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2
gNa = params["gNa"] * (m**3) * h # S/cm^2

current = gNa * (v - params["eNa"])
return current

def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_m, beta_m = self.m_gate(v, params["vt"])
alpha_h, beta_h = self.h_gate(v, params["vt"])
return {
f"{prefix}_m": alpha_m / (alpha_m + beta_m),
f"{prefix}_h": alpha_h / (alpha_h + beta_h),
"m": alpha_m / (alpha_m + beta_m),
"h": alpha_h / (alpha_h + beta_h),
}

@staticmethod
Expand Down Expand Up @@ -147,14 +141,13 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gK": 5e-3,
"eK": -90.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
"gK": 5e-3,
# "eK": -90.0,
# "vt": -60.0, # Global parameter, not prefixed with `Na`.
}
self.states = {f"{prefix}_n": 0.2}
self.current_name = f"i_K"
self.states = {"n": 0.2}
# self.current_name = f"i_K"

def update_states(
self,
Expand All @@ -164,27 +157,24 @@ def update_states(
params: Dict[str, jnp.ndarray],
):
"""Update state."""
prefix = self._name
n = states[f"{prefix}_n"]
n = states["n"]
new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params["vt"]))
return {f"{prefix}_n": new_n}
return {"n": new_n}

def compute_current(
self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
"""Return current."""
prefix = self._name
n = states[f"{prefix}_n"]
n = states["n"]

gK = params[f"{prefix}_gK"] * (n**4) # S/cm^2
gK = params["gK"] * (n**4) # S/cm^2

return gK * (v - params["eK"])

def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_n, beta_n = self.n_gate(v, params["vt"])
return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)}
return {"n": alpha_n / (alpha_n + beta_n)}

@staticmethod
def n_gate(v, vt):
Expand All @@ -203,14 +193,13 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gKm": 0.004e-3,
f"{prefix}_taumax": 4000.0,
f"eK": -90.0,
"gKm": 0.004e-3,
"taumax": 4000.0,
# f"eK": -90.0,
}
self.states = {f"{prefix}_p": 0.2}
self.current_name = f"i_K"
self.states = {"p": 0.2}
# self.current_name = f"i_K"

def update_states(
self,
Expand All @@ -220,28 +209,23 @@ def update_states(
params: Dict[str, jnp.ndarray],
):
"""Update state."""
prefix = self._name
p = states[f"{prefix}_p"]
new_p = solve_inf_gate_exponential(
p, dt, *self.p_gate(v, params[f"{prefix}_taumax"])
)
return {f"{prefix}_p": new_p}
p = states["p"]
new_p = solve_inf_gate_exponential(p, dt, *self.p_gate(v, params["taumax"]))
return {"p": new_p}

def compute_current(
self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
"""Return current."""
prefix = self._name
p = states[f"{prefix}_p"]
p = states["p"]

gKm = params[f"{prefix}_gKm"] * p # S/cm^2
gKm = params["gKm"] * p # S/cm^2
return gKm * (v - params["eK"])

def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_p, beta_p = self.p_gate(v, params[f"{prefix}_taumax"])
return {f"{prefix}_p": alpha_p / (alpha_p + beta_p)}
alpha_p, beta_p = self.p_gate(v, params["taumax"])
return {"p": alpha_p / (alpha_p + beta_p)}

@staticmethod
def p_gate(v, taumax):
Expand All @@ -260,13 +244,12 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gCaL": 0.1e-3,
"eCa": 120.0,
"gCaL": 0.1e-3,
# "eCa": 120.0,
}
self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2}
self.current_name = f"i_Ca"
self.states = {"q": 0.2, "r": 0.2}
# self.current_name = f"i_Ca"

def update_states(
self,
Expand All @@ -276,30 +259,27 @@ def update_states(
params: Dict[str, jnp.ndarray],
):
"""Update state."""
prefix = self._name
q, r = states[f"{prefix}_q"], states[f"{prefix}_r"]
q, r = states["q"], states["r"]
new_q = solve_gate_exponential(q, dt, *self.q_gate(v))
new_r = solve_gate_exponential(r, dt, *self.r_gate(v))
return {f"{prefix}_q": new_q, f"{prefix}_r": new_r}
return {"q": new_q, "r": new_r}

def compute_current(
self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
"""Return current."""
prefix = self._name
q, r = states[f"{prefix}_q"], states[f"{prefix}_r"]
gCaL = params[f"{prefix}_gCaL"] * (q**2) * r # S/cm^2
q, r = states["q"], states["r"]
gCaL = params["gCaL"] * (q**2) * r # S/cm^2

return gCaL * (v - params["eCa"])

def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_q, beta_q = self.q_gate(v)
alpha_r, beta_r = self.r_gate(v)
return {
f"{prefix}_q": alpha_q / (alpha_q + beta_q),
f"{prefix}_r": alpha_r / (alpha_r + beta_r),
"q": alpha_q / (alpha_q + beta_q),
"r": alpha_r / (alpha_r + beta_r),
}

@staticmethod
Expand Down Expand Up @@ -328,14 +308,13 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gCaT": 0.4e-4,
f"{prefix}_vx": 2.0,
"eCa": 120.0, # Global parameter, not prefixed with `CaT`.
"gCaT": 0.4e-4,
"vx": 2.0,
# "eCa": 120.0, # Global parameter, not prefixed with `CaT`.
}
self.states = {f"{prefix}_u": 0.2}
self.current_name = f"i_Ca"
self.states = {"u": 0.2}
# self.current_name = f"i_Ca"

def update_states(
self,
Expand All @@ -345,30 +324,25 @@ def update_states(
params: Dict[str, jnp.ndarray],
):
"""Update state."""
prefix = self._name
u = states[f"{prefix}_u"]
new_u = solve_inf_gate_exponential(
u, dt, *self.u_gate(v, params[f"{prefix}_vx"])
)
return {f"{prefix}_u": new_u}
u = states["u"]
new_u = solve_inf_gate_exponential(u, dt, *self.u_gate(v, params["vx"]))
return {"u": new_u}

def compute_current(
self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]
):
"""Return current."""
prefix = self._name
u = states[f"{prefix}_u"]
s_inf = 1.0 / (1.0 + save_exp(-(v + params[f"{prefix}_vx"] + 57.0) / 6.2))
u = states["u"]
s_inf = 1.0 / (1.0 + save_exp(-(v + params["vx"] + 57.0) / 6.2))

gCaT = params[f"{prefix}_gCaT"] * (s_inf**2) * u # S/cm^2
gCaT = params["gCaT"] * (s_inf**2) * u # S/cm^2

return gCaT * (v - params["eCa"])

def init_state(self, states, v, params, delta_t):
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_u, beta_u = self.u_gate(v, params[f"{prefix}_vx"])
return {f"{prefix}_u": alpha_u / (alpha_u + beta_u)}
alpha_u, beta_u = self.u_gate(v, params["vx"])
return {"u": alpha_u / (alpha_u + beta_u)}

@staticmethod
def u_gate(v, vx):
Expand Down
39 changes: 38 additions & 1 deletion jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1725,6 +1725,7 @@ def insert(self, channel: Channel):
Args:
channel: The channel to insert."""
channel = AutoPrefix(channel)
name = channel._name

# Channel does not yet exist in the `jx.Module` at all.
Expand Down Expand Up @@ -2746,7 +2747,11 @@ def __init__(self, mech: Union[Channel, Synapse]):
self._wrapped = mech
self.prefix = f"{self._wrapped.name}_"
self._name = self._wrapped._name
self.current_name = "i_" + self.prefix[:-1]
self.current_name = (
self._wrapped.current_name
if hasattr(self._wrapped, "current_name")
else "i_" + self.prefix[:-1]
)

# Make this class pretend to be the wrapped class
self.__class__.__name__ = mech.__class__.__name__
Expand Down Expand Up @@ -2840,3 +2845,35 @@ def init_state(
init_states = self._wrapped.init_state(states, v, params, dt)

return self._transform_dict_keys(init_states, add_prefix=True)


def infer_global_params_states(mech: Union[Channel, Synapse]) -> List[str]:
"""Infer the global parameters and states of a channel or synapse.
Infers global params and states by testing for KeyErrors in the `update_states`,
`compute_current`, and `init_state` methods.
Args:
mech: The channel or synapse to infer the global params and states of.
Returns:
A list of the inferred global parameters and states.
"""
global_state_params = {}
mech_states = mech.states
mech_params = mech.params

while True:
try:
states = {**global_state_params, **mech_states}
params = {**global_state_params, **mech_params}

mech.update_states(states, 0.025, -70, params)
mech.compute_current(states, -70, params)
if isinstance(mech, Channel):
mech.init_state(states, -70, params, 0.025)
break
except KeyError as e:
missing_key = e.args[0]
global_state_params[missing_key] = 0
return list(global_state_params)
3 changes: 2 additions & 1 deletion jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

from jaxley.modules.base import Module
from jaxley.modules.base import AutoPrefix, Module
from jaxley.modules.cell import Cell
from jaxley.utils.cell_utils import (
build_branchpoint_group_inds,
Expand Down Expand Up @@ -520,6 +520,7 @@ def _update_synapse_state_names(self, synapse_type):

def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):
# Add synapse types to the module and infer their unique identifier.
synapse_type = AutoPrefix(synapse_type)
synapse_name = synapse_type._name
synapse_current_name = f"i_{synapse_name}"
type_ind, is_new = self._infer_synapse_type_ind(synapse_name)
Expand Down

0 comments on commit 5c11a41

Please sign in to comment.