diff --git a/jaxley/channels/pospischil.py b/jaxley/channels/pospischil.py index 8602a72c..5df21fa6 100644 --- a/jaxley/channels/pospischil.py +++ b/jaxley/channels/pospischil.py @@ -39,13 +39,12 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gLeak": 1e-4, - f"{prefix}_eLeak": -70.0, + "gLeak": 1e-4, + "eLeak": -70.0, } self.states = {} - self.current_name = f"i_{prefix}" + # self.current_name = f"i_Leak" def update_states( self, @@ -61,9 +60,8 @@ def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - gLeak = params[f"{prefix}_gLeak"] # S/cm^2 - return gLeak * (v - params[f"{prefix}_eLeak"]) + gLeak = params["gLeak"] # S/cm^2 + return gLeak * (v - params["eLeak"]) def init_state(self, states, v, params, delta_t): return {} @@ -76,14 +74,13 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gNa": 50e-3, - "eNa": 50.0, - "vt": -60.0, # Global parameter, not prefixed with `Na`. + "gNa": 50e-3, + # "eNa": 50.0, + # "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} - self.current_name = f"i_Na" + self.states = {"m": 0.2, "h": 0.2} + # self.current_name = f"i_Na" def update_states( self, @@ -93,32 +90,29 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] + m, h = states["m"], states["h"] new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params["vt"])) new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params["vt"])) - return {f"{prefix}_m": new_m, f"{prefix}_h": new_h} + return {"m": new_m, "h": new_h} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] + m, h = states["m"], states["h"] - gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2 + gNa = params["gNa"] * (m**3) * h # S/cm^2 current = gNa * (v - params["eNa"]) return current def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_m, beta_m = self.m_gate(v, params["vt"]) alpha_h, beta_h = self.h_gate(v, params["vt"]) return { - f"{prefix}_m": alpha_m / (alpha_m + beta_m), - f"{prefix}_h": alpha_h / (alpha_h + beta_h), + "m": alpha_m / (alpha_m + beta_m), + "h": alpha_h / (alpha_h + beta_h), } @staticmethod @@ -147,14 +141,13 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gK": 5e-3, - "eK": -90.0, - "vt": -60.0, # Global parameter, not prefixed with `Na`. + "gK": 5e-3, + # "eK": -90.0, + # "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.states = {f"{prefix}_n": 0.2} - self.current_name = f"i_K" + self.states = {"n": 0.2} + # self.current_name = f"i_K" def update_states( self, @@ -164,27 +157,24 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - n = states[f"{prefix}_n"] + n = states["n"] new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params["vt"])) - return {f"{prefix}_n": new_n} + return {"n": new_n} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - n = states[f"{prefix}_n"] + n = states["n"] - gK = params[f"{prefix}_gK"] * (n**4) # S/cm^2 + gK = params["gK"] * (n**4) # S/cm^2 return gK * (v - params["eK"]) def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_n, beta_n = self.n_gate(v, params["vt"]) - return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)} + return {"n": alpha_n / (alpha_n + beta_n)} @staticmethod def n_gate(v, vt): @@ -203,14 +193,13 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gKm": 0.004e-3, - f"{prefix}_taumax": 4000.0, - f"eK": -90.0, + "gKm": 0.004e-3, + "taumax": 4000.0, + # f"eK": -90.0, } - self.states = {f"{prefix}_p": 0.2} - self.current_name = f"i_K" + self.states = {"p": 0.2} + # self.current_name = f"i_K" def update_states( self, @@ -220,28 +209,23 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - p = states[f"{prefix}_p"] - new_p = solve_inf_gate_exponential( - p, dt, *self.p_gate(v, params[f"{prefix}_taumax"]) - ) - return {f"{prefix}_p": new_p} + p = states["p"] + new_p = solve_inf_gate_exponential(p, dt, *self.p_gate(v, params["taumax"])) + return {"p": new_p} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - p = states[f"{prefix}_p"] + p = states["p"] - gKm = params[f"{prefix}_gKm"] * p # S/cm^2 + gKm = params["gKm"] * p # S/cm^2 return gKm * (v - params["eK"]) def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name - alpha_p, beta_p = self.p_gate(v, params[f"{prefix}_taumax"]) - return {f"{prefix}_p": alpha_p / (alpha_p + beta_p)} + alpha_p, beta_p = self.p_gate(v, params["taumax"]) + return {"p": alpha_p / (alpha_p + beta_p)} @staticmethod def p_gate(v, taumax): @@ -260,13 +244,12 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gCaL": 0.1e-3, - "eCa": 120.0, + "gCaL": 0.1e-3, + # "eCa": 120.0, } - self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} - self.current_name = f"i_Ca" + self.states = {"q": 0.2, "r": 0.2} + # self.current_name = f"i_Ca" def update_states( self, @@ -276,30 +259,27 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - q, r = states[f"{prefix}_q"], states[f"{prefix}_r"] + q, r = states["q"], states["r"] new_q = solve_gate_exponential(q, dt, *self.q_gate(v)) new_r = solve_gate_exponential(r, dt, *self.r_gate(v)) - return {f"{prefix}_q": new_q, f"{prefix}_r": new_r} + return {"q": new_q, "r": new_r} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - q, r = states[f"{prefix}_q"], states[f"{prefix}_r"] - gCaL = params[f"{prefix}_gCaL"] * (q**2) * r # S/cm^2 + q, r = states["q"], states["r"] + gCaL = params["gCaL"] * (q**2) * r # S/cm^2 return gCaL * (v - params["eCa"]) def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_q, beta_q = self.q_gate(v) alpha_r, beta_r = self.r_gate(v) return { - f"{prefix}_q": alpha_q / (alpha_q + beta_q), - f"{prefix}_r": alpha_r / (alpha_r + beta_r), + "q": alpha_q / (alpha_q + beta_q), + "r": alpha_r / (alpha_r + beta_r), } @staticmethod @@ -328,14 +308,13 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gCaT": 0.4e-4, - f"{prefix}_vx": 2.0, - "eCa": 120.0, # Global parameter, not prefixed with `CaT`. + "gCaT": 0.4e-4, + "vx": 2.0, + # "eCa": 120.0, # Global parameter, not prefixed with `CaT`. } - self.states = {f"{prefix}_u": 0.2} - self.current_name = f"i_Ca" + self.states = {"u": 0.2} + # self.current_name = f"i_Ca" def update_states( self, @@ -345,30 +324,25 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - u = states[f"{prefix}_u"] - new_u = solve_inf_gate_exponential( - u, dt, *self.u_gate(v, params[f"{prefix}_vx"]) - ) - return {f"{prefix}_u": new_u} + u = states["u"] + new_u = solve_inf_gate_exponential(u, dt, *self.u_gate(v, params["vx"])) + return {"u": new_u} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - u = states[f"{prefix}_u"] - s_inf = 1.0 / (1.0 + save_exp(-(v + params[f"{prefix}_vx"] + 57.0) / 6.2)) + u = states["u"] + s_inf = 1.0 / (1.0 + save_exp(-(v + params["vx"] + 57.0) / 6.2)) - gCaT = params[f"{prefix}_gCaT"] * (s_inf**2) * u # S/cm^2 + gCaT = params["gCaT"] * (s_inf**2) * u # S/cm^2 return gCaT * (v - params["eCa"]) def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name - alpha_u, beta_u = self.u_gate(v, params[f"{prefix}_vx"]) - return {f"{prefix}_u": alpha_u / (alpha_u + beta_u)} + alpha_u, beta_u = self.u_gate(v, params["vx"]) + return {"u": alpha_u / (alpha_u + beta_u)} @staticmethod def u_gate(v, vx): diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index a4683f0d..cf02cb00 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1725,6 +1725,7 @@ def insert(self, channel: Channel): Args: channel: The channel to insert.""" + channel = AutoPrefix(channel) name = channel._name # Channel does not yet exist in the `jx.Module` at all. @@ -2746,7 +2747,11 @@ def __init__(self, mech: Union[Channel, Synapse]): self._wrapped = mech self.prefix = f"{self._wrapped.name}_" self._name = self._wrapped._name - self.current_name = "i_" + self.prefix[:-1] + self.current_name = ( + self._wrapped.current_name + if hasattr(self._wrapped, "current_name") + else "i_" + self.prefix[:-1] + ) # Make this class pretend to be the wrapped class self.__class__.__name__ = mech.__class__.__name__ @@ -2840,3 +2845,35 @@ def init_state( init_states = self._wrapped.init_state(states, v, params, dt) return self._transform_dict_keys(init_states, add_prefix=True) + + +def infer_global_params_states(mech: Union[Channel, Synapse]) -> List[str]: + """Infer the global parameters and states of a channel or synapse. + + Infers global params and states by testing for KeyErrors in the `update_states`, + `compute_current`, and `init_state` methods. + + Args: + mech: The channel or synapse to infer the global params and states of. + + Returns: + A list of the inferred global parameters and states. + """ + global_state_params = {} + mech_states = mech.states + mech_params = mech.params + + while True: + try: + states = {**global_state_params, **mech_states} + params = {**global_state_params, **mech_params} + + mech.update_states(states, 0.025, -70, params) + mech.compute_current(states, -70, params) + if isinstance(mech, Channel): + mech.init_state(states, -70, params, 0.025) + break + except KeyError as e: + missing_key = e.args[0] + global_state_params[missing_key] = 0 + return list(global_state_params) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 5727446a..0019fdb4 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -13,7 +13,7 @@ from matplotlib import pyplot as plt from matplotlib.axes import Axes -from jaxley.modules.base import Module +from jaxley.modules.base import AutoPrefix, Module from jaxley.modules.cell import Cell from jaxley.utils.cell_utils import ( build_branchpoint_group_inds, @@ -520,6 +520,7 @@ def _update_synapse_state_names(self, synapse_type): def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): # Add synapse types to the module and infer their unique identifier. + synapse_type = AutoPrefix(synapse_type) synapse_name = synapse_type._name synapse_current_name = f"i_{synapse_name}" type_ind, is_new = self._infer_synapse_type_ind(synapse_name)