From a0874b484fc002dc832d5ad00a2554df2af3b6b1 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sun, 22 Dec 2024 18:53:11 +0100 Subject: [PATCH] wip: started work on new channel API --- jaxley/channels/channel.py | 5 +- jaxley/channels/hh.py | 76 ++++++++---------------- jaxley/modules/base.py | 106 ++++++++++++++++++++++++++++++++++ jaxley/synapses/ionotropic.py | 21 +++---- jaxley/synapses/synapse.py | 4 +- jaxley/synapses/tanh_rate.py | 14 ++--- jaxley/synapses/test.py | 13 ++--- 7 files changed, 152 insertions(+), 87 deletions(-) diff --git a/jaxley/channels/channel.py b/jaxley/channels/channel.py index b8a1dc41..81d48fdb 100644 --- a/jaxley/channels/channel.py +++ b/jaxley/channels/channel.py @@ -16,9 +16,8 @@ class Channel: `uA/cm2`.""" _name = None - channel_params = None - channel_states = None - current_name = None + params = None + states = None def __init__(self, name: Optional[str] = None): contact = ( diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index 8f9072c2..a06dc5a5 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -16,16 +16,8 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name - self.params = { - f"{prefix}_gNa": 0.12, - f"{prefix}_eNa": 50.0, - } - self.states = { - f"{prefix}_m": 0.2, - f"{prefix}_h": 0.2, - } - self.current_name = f"i_Na" + self.params = {"gNa": 0.12, "eNa": 50.0} + self.states = {"m": 0.2, "h": 0.2} def update_states( self, @@ -35,11 +27,11 @@ def update_states( params: Dict[str, jnp.ndarray], ) -> Dict[str, jnp.ndarray]: """Return updated HH channel state.""" - prefix = self._name - m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] + m, h = states["m"], states["h"] + new_m = solve_gate_exponential(m, dt, *self.m_gate(v)) new_h = solve_gate_exponential(h, dt, *self.h_gate(v)) - return {f"{prefix}_m": new_m, f"{prefix}_h": new_h} + return {"m": new_m, "h": new_h} def compute_current( self, @@ -48,12 +40,11 @@ def compute_current( params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Return current through HH channels.""" - prefix = self._name - m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] - - gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2 + m, h = states["m"], states["h"] + gNa, eNa = params["gNa"], params["eNa"] - return gNa * (v - params[f"{prefix}_eNa"]) + gNa = gNa * (m**3) * h # S/cm^2 + return gNa * (v - eNa) def init_state( self, @@ -63,13 +54,9 @@ def init_state( dt: float, ) -> Dict[str, jnp.ndarray]: """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_m, beta_m = self.m_gate(v) alpha_h, beta_h = self.h_gate(v) - return { - f"{prefix}_m": alpha_m / (alpha_m + beta_m), - f"{prefix}_h": alpha_h / (alpha_h + beta_h), - } + return {"m": alpha_m / (alpha_m + beta_m), "h": alpha_h / (alpha_h + beta_h)} @staticmethod def m_gate(v): @@ -91,15 +78,8 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name - self.params = { - f"{prefix}_gK": 0.036, - f"{prefix}_eK": -77.0, - } - self.states = { - f"{prefix}_n": 0.2, - } - self.current_name = f"i_K" + self.params = {"gK": 0.036, "eK": -77.0} + self.states = {"n": 0.2} def update_states( self, @@ -109,10 +89,10 @@ def update_states( params: Dict[str, jnp.ndarray], ) -> Dict[str, jnp.ndarray]: """Return updated HH channel state.""" - prefix = self._name - n = states[f"{prefix}_n"] + n = states["n"] + new_n = solve_gate_exponential(n, dt, *self.n_gate(v)) - return {f"{prefix}_n": new_n} + return {"n": new_n} def compute_current( self, @@ -121,12 +101,11 @@ def compute_current( params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Return current through HH channels.""" - prefix = self._name - n = states[f"{prefix}_n"] + n = states["n"] + gK, eK = params["gK"], params["eK"] - gK = params[f"{prefix}_gK"] * n**4 # S/cm^2 - - return gK * (v - params[f"{prefix}_eK"]) + gK = gK * n**4 # S/cm^2 + return gK * (v - eK) def init_state( self, @@ -136,9 +115,8 @@ def init_state( dt: float, ) -> Dict[str, jnp.ndarray]: """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_n, beta_n = self.n_gate(v) - return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)} + return {"n": alpha_n / (alpha_n + beta_n)} @staticmethod def n_gate(v): @@ -154,13 +132,8 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name - self.params = { - f"{prefix}_gLeak": 0.0003, - f"{prefix}_eLeak": -54.3, - } + self.params = {"gLeak": 0.0003, "eLeak": -54.3} self.states = {} - self.current_name = f"i_Leak" def update_states( self, @@ -179,10 +152,9 @@ def compute_current( params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Return current through HH channels.""" - prefix = self._name - gLeak = params[f"{prefix}_gLeak"] # S/cm^2 + gLeak, eLeak = params["gLeak"], params["eLeak"] - return gLeak * (v - params[f"{prefix}_eLeak"]) + return gLeak * (v - eLeak) def init_state( self, @@ -219,8 +191,6 @@ def __init__(self, name: Optional[str] = None): **self.Leak.states, } - self.current_name = f"i_HH" - def change_name(self, new_name: str): self._name = new_name for channel in self.channels: diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index d344061a..a4683f0d 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -11,6 +11,7 @@ import jax.numpy as jnp import numpy as np +from optree import tree_map_with_path import pandas as pd from jax import jit, vmap from jax.lax import ScatterDimensionNumbers, scatter_add @@ -2734,3 +2735,108 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): pass + + +class AutoPrefix: + """Wrapper for Channel classes that transparently handles name prefixing using pytrees.""" + + def __init__(self, mech: Union[Channel, Synapse]): + """Initialize wrapper with a channel instance.""" + # Store the wrapped channel + self._wrapped = mech + self.prefix = f"{self._wrapped.name}_" + self._name = self._wrapped._name + self.current_name = "i_" + self.prefix[:-1] + + # Make this class pretend to be the wrapped class + self.__class__.__name__ = mech.__class__.__name__ + self.__class__.__qualname__ = mech.__class__.__qualname__ + self.__class__.__module__ = mech.__class__.__module__ + + if isinstance(self._wrapped, Synapse): + delattr(self.__class__, "init_state") + + def __getattr__(self, name: str): + """Delegate unknown attributes to the wrapped channel.""" + return getattr(self._wrapped, name) + + def __repr__(self): + """Return the same representation as the wrapped channel.""" + return repr(self._wrapped) + + def __str__(self): + """Return the same string representation as the wrapped channel.""" + return str(self._wrapped) + + def _transform_dict_keys( + self, d: Dict[str, Any], add_prefix: bool = True + ) -> Dict[str, Any]: + """Transform dictionary keys using pytree operations.""" + + def transform_key(path, value): + key = path[-1] # Get the leaf key + if isinstance(key, str): + if add_prefix: + return f"{self.prefix}{key}", value + elif key.startswith(self.prefix): + return key[len(self.prefix) :], value + return key, value + + transformed = tree_map_with_path( + lambda p, v: transform_key(p, v)[1], + d, + is_leaf=lambda x: isinstance(x, (jnp.ndarray, float, int)), + ) + return transformed + + @property + def params(self) -> Dict[str, Any]: + """Get prefixed parameters.""" + return self._transform_dict_keys(self._wrapped.params, add_prefix=True) + + @property + def states(self) -> Dict[str, Any]: + """Get prefixed states.""" + return self._transform_dict_keys(self._wrapped.states, add_prefix=True) + + def update_states( + self, + states: Dict[str, Any], + dt: float, + v: jnp.ndarray, + params: Dict[str, Any], + ) -> Dict[str, Any]: + """Update states with automatic prefix handling using pytrees.""" + states = self._transform_dict_keys(states, add_prefix=False) + params = self._transform_dict_keys(params, add_prefix=False) + + states = self._wrapped.update_states(states, dt, v, params) + + return self._transform_dict_keys(states, add_prefix=True) + + def compute_current( + self, + states: Dict[str, Any], + v: jnp.ndarray, + params: Dict[str, Any], + ) -> jnp.ndarray: + """Compute current with automatic prefix handling using pytrees.""" + states = self._transform_dict_keys(states, add_prefix=False) + params = self._transform_dict_keys(params, add_prefix=False) + + return self._wrapped.compute_current(states, v, params) + + def init_state( + self, + states: Dict[str, Any], + v: jnp.ndarray, + params: Dict[str, Any], + dt: float, + ) -> Dict[str, Any]: + """Initialize states with automatic prefix handling using pytrees.""" + states = self._transform_dict_keys(states, add_prefix=False) + params = self._transform_dict_keys(params, add_prefix=False) + + init_states = self._wrapped.init_state(states, v, params, dt) + + return self._transform_dict_keys(init_states, add_prefix=True) diff --git a/jaxley/synapses/ionotropic.py b/jaxley/synapses/ionotropic.py index 101dd95b..58514c04 100644 --- a/jaxley/synapses/ionotropic.py +++ b/jaxley/synapses/ionotropic.py @@ -31,13 +31,12 @@ class IonotropicSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gS": 1e-4, - f"{prefix}_e_syn": 0.0, - f"{prefix}_k_minus": 0.025, + "gS": 1e-4, + "e_syn": 0.0, + "k_minus": 0.025, } - self.states = {f"{prefix}_s": 0.2} + self.states = {"s": 0.2} def update_states( self, @@ -48,21 +47,19 @@ def update_states( params: Dict, ) -> Dict: """Return updated synapse state and current.""" - prefix = self._name v_th = -35.0 # mV delta = 10.0 # mV s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta)) - tau_s = (1.0 - s_inf) / params[f"{prefix}_k_minus"] + tau_s = (1.0 - s_inf) / params["k_minus"] slope = -1.0 / tau_s exp_term = save_exp(slope * delta_t) - new_s = states[f"{prefix}_s"] * exp_term + s_inf * (1.0 - exp_term) - return {f"{prefix}_s": new_s} + new_s = states["s"] * exp_term + s_inf * (1.0 - exp_term) + return {"s": new_s} def compute_current( self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict ) -> float: - prefix = self._name - g_syn = params[f"{prefix}_gS"] * states[f"{prefix}_s"] - return g_syn * (post_voltage - params[f"{prefix}_e_syn"]) + g_syn = params["gS"] * states["s"] + return g_syn * (post_voltage - params["e_syn"]) diff --git a/jaxley/synapses/synapse.py b/jaxley/synapses/synapse.py index 38cd7d3f..ea460512 100644 --- a/jaxley/synapses/synapse.py +++ b/jaxley/synapses/synapse.py @@ -15,8 +15,8 @@ class Synapse: """ _name = None - synapse_params = None - synapse_states = None + params = None + states = None def __init__(self, name: Optional[str] = None): self._name = name if name else self.__class__.__name__ diff --git a/jaxley/synapses/tanh_rate.py b/jaxley/synapses/tanh_rate.py index 6bbd49cc..8a95e79d 100644 --- a/jaxley/synapses/tanh_rate.py +++ b/jaxley/synapses/tanh_rate.py @@ -15,11 +15,10 @@ class TanhRateSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gS": 1e-4, - f"{prefix}_x_offset": -70.0, - f"{prefix}_slope": 1.0, + "gS": 1e-4, + "x_offset": -70.0, + "slope": 1.0, } self.states = {} @@ -38,12 +37,9 @@ def compute_current( self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict ) -> float: """Return updated synapse state and current.""" - prefix = self._name current = ( -1 - * params[f"{prefix}_gS"] - * jnp.tanh( - (pre_voltage - params[f"{prefix}_x_offset"]) * params[f"{prefix}_slope"] - ) + * params["gS"] + * jnp.tanh((pre_voltage - params["x_offset"]) * params["slope"]) ) return current diff --git a/jaxley/synapses/test.py b/jaxley/synapses/test.py index 84cb5d4d..95b7d0fa 100644 --- a/jaxley/synapses/test.py +++ b/jaxley/synapses/test.py @@ -18,9 +18,8 @@ class TestSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) - prefix = self._name - self.params = {f"{prefix}_gC": 1e-4} - self.states = {f"{prefix}_c": 0.2} + self.params = {"gC": 1e-4} + self.states = {"c": 0.2} def update_states( self, @@ -31,7 +30,6 @@ def update_states( params: Dict, ) -> Dict: """Return updated synapse state and current.""" - prefix = self._name v_th = -35.0 delta = 10.0 k_minus = 1.0 / 40.0 @@ -42,13 +40,12 @@ def update_states( s_inf = s_bar slope = -1.0 / tau_s exp_term = save_exp(slope * delta_t) - new_s = states[f"{prefix}_c"] * exp_term + s_inf * (1.0 - exp_term) - return {f"{prefix}_c": new_s} + new_s = states["c"] * exp_term + s_inf * (1.0 - exp_term) + return {"c": new_s} def compute_current( self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict ) -> float: - prefix = self._name e_syn = 0.0 - g_syn = params[f"{prefix}_gC"] * states[f"{prefix}_c"] + g_syn = params["gC"] * states["c"] return g_syn * (post_voltage - e_syn)