Skip to content

Commit

Permalink
wip: started work on new channel API
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 22, 2024
1 parent 6ec8962 commit a0874b4
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 87 deletions.
5 changes: 2 additions & 3 deletions jaxley/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ class Channel:
`uA/cm2`."""

_name = None
channel_params = None
channel_states = None
current_name = None
params = None
states = None

def __init__(self, name: Optional[str] = None):
contact = (
Expand Down
76 changes: 23 additions & 53 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,8 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gNa": 0.12,
f"{prefix}_eNa": 50.0,
}
self.states = {
f"{prefix}_m": 0.2,
f"{prefix}_h": 0.2,
}
self.current_name = f"i_Na"
self.params = {"gNa": 0.12, "eNa": 50.0}
self.states = {"m": 0.2, "h": 0.2}

def update_states(
self,
Expand All @@ -35,11 +27,11 @@ def update_states(
params: Dict[str, jnp.ndarray],
) -> Dict[str, jnp.ndarray]:
"""Return updated HH channel state."""
prefix = self._name
m, h = states[f"{prefix}_m"], states[f"{prefix}_h"]
m, h = states["m"], states["h"]

new_m = solve_gate_exponential(m, dt, *self.m_gate(v))
new_h = solve_gate_exponential(h, dt, *self.h_gate(v))
return {f"{prefix}_m": new_m, f"{prefix}_h": new_h}
return {"m": new_m, "h": new_h}

def compute_current(
self,
Expand All @@ -48,12 +40,11 @@ def compute_current(
params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
"""Return current through HH channels."""
prefix = self._name
m, h = states[f"{prefix}_m"], states[f"{prefix}_h"]

gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2
m, h = states["m"], states["h"]
gNa, eNa = params["gNa"], params["eNa"]

return gNa * (v - params[f"{prefix}_eNa"])
gNa = gNa * (m**3) * h # S/cm^2
return gNa * (v - eNa)

def init_state(
self,
Expand All @@ -63,13 +54,9 @@ def init_state(
dt: float,
) -> Dict[str, jnp.ndarray]:
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_m, beta_m = self.m_gate(v)
alpha_h, beta_h = self.h_gate(v)
return {
f"{prefix}_m": alpha_m / (alpha_m + beta_m),
f"{prefix}_h": alpha_h / (alpha_h + beta_h),
}
return {"m": alpha_m / (alpha_m + beta_m), "h": alpha_h / (alpha_h + beta_h)}

@staticmethod
def m_gate(v):
Expand All @@ -91,15 +78,8 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gK": 0.036,
f"{prefix}_eK": -77.0,
}
self.states = {
f"{prefix}_n": 0.2,
}
self.current_name = f"i_K"
self.params = {"gK": 0.036, "eK": -77.0}
self.states = {"n": 0.2}

def update_states(
self,
Expand All @@ -109,10 +89,10 @@ def update_states(
params: Dict[str, jnp.ndarray],
) -> Dict[str, jnp.ndarray]:
"""Return updated HH channel state."""
prefix = self._name
n = states[f"{prefix}_n"]
n = states["n"]

new_n = solve_gate_exponential(n, dt, *self.n_gate(v))
return {f"{prefix}_n": new_n}
return {"n": new_n}

def compute_current(
self,
Expand All @@ -121,12 +101,11 @@ def compute_current(
params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
"""Return current through HH channels."""
prefix = self._name
n = states[f"{prefix}_n"]
n = states["n"]
gK, eK = params["gK"], params["eK"]

gK = params[f"{prefix}_gK"] * n**4 # S/cm^2

return gK * (v - params[f"{prefix}_eK"])
gK = gK * n**4 # S/cm^2
return gK * (v - eK)

def init_state(
self,
Expand All @@ -136,9 +115,8 @@ def init_state(
dt: float,
) -> Dict[str, jnp.ndarray]:
"""Initialize the state such at fixed point of gate dynamics."""
prefix = self._name
alpha_n, beta_n = self.n_gate(v)
return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)}
return {"n": alpha_n / (alpha_n + beta_n)}

@staticmethod
def n_gate(v):
Expand All @@ -154,13 +132,8 @@ def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True

super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gLeak": 0.0003,
f"{prefix}_eLeak": -54.3,
}
self.params = {"gLeak": 0.0003, "eLeak": -54.3}
self.states = {}
self.current_name = f"i_Leak"

def update_states(
self,
Expand All @@ -179,10 +152,9 @@ def compute_current(
params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
"""Return current through HH channels."""
prefix = self._name
gLeak = params[f"{prefix}_gLeak"] # S/cm^2
gLeak, eLeak = params["gLeak"], params["eLeak"]

return gLeak * (v - params[f"{prefix}_eLeak"])
return gLeak * (v - eLeak)

def init_state(
self,
Expand Down Expand Up @@ -219,8 +191,6 @@ def __init__(self, name: Optional[str] = None):
**self.Leak.states,
}

self.current_name = f"i_HH"

def change_name(self, new_name: str):
self._name = new_name
for channel in self.channels:
Expand Down
106 changes: 106 additions & 0 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import jax.numpy as jnp
import numpy as np
from optree import tree_map_with_path
import pandas as pd
from jax import jit, vmap
from jax.lax import ScatterDimensionNumbers, scatter_add
Expand Down Expand Up @@ -2734,3 +2735,108 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, exc_traceback):
pass


class AutoPrefix:
"""Wrapper for Channel classes that transparently handles name prefixing using pytrees."""

def __init__(self, mech: Union[Channel, Synapse]):
"""Initialize wrapper with a channel instance."""
# Store the wrapped channel
self._wrapped = mech
self.prefix = f"{self._wrapped.name}_"
self._name = self._wrapped._name
self.current_name = "i_" + self.prefix[:-1]

# Make this class pretend to be the wrapped class
self.__class__.__name__ = mech.__class__.__name__
self.__class__.__qualname__ = mech.__class__.__qualname__
self.__class__.__module__ = mech.__class__.__module__

if isinstance(self._wrapped, Synapse):
delattr(self.__class__, "init_state")

def __getattr__(self, name: str):
"""Delegate unknown attributes to the wrapped channel."""
return getattr(self._wrapped, name)

def __repr__(self):
"""Return the same representation as the wrapped channel."""
return repr(self._wrapped)

def __str__(self):
"""Return the same string representation as the wrapped channel."""
return str(self._wrapped)

def _transform_dict_keys(
self, d: Dict[str, Any], add_prefix: bool = True
) -> Dict[str, Any]:
"""Transform dictionary keys using pytree operations."""

def transform_key(path, value):
key = path[-1] # Get the leaf key
if isinstance(key, str):
if add_prefix:
return f"{self.prefix}{key}", value
elif key.startswith(self.prefix):
return key[len(self.prefix) :], value
return key, value

transformed = tree_map_with_path(
lambda p, v: transform_key(p, v)[1],
d,
is_leaf=lambda x: isinstance(x, (jnp.ndarray, float, int)),
)
return transformed

@property
def params(self) -> Dict[str, Any]:
"""Get prefixed parameters."""
return self._transform_dict_keys(self._wrapped.params, add_prefix=True)

@property
def states(self) -> Dict[str, Any]:
"""Get prefixed states."""
return self._transform_dict_keys(self._wrapped.states, add_prefix=True)

def update_states(
self,
states: Dict[str, Any],
dt: float,
v: jnp.ndarray,
params: Dict[str, Any],
) -> Dict[str, Any]:
"""Update states with automatic prefix handling using pytrees."""
states = self._transform_dict_keys(states, add_prefix=False)
params = self._transform_dict_keys(params, add_prefix=False)

states = self._wrapped.update_states(states, dt, v, params)

return self._transform_dict_keys(states, add_prefix=True)

def compute_current(
self,
states: Dict[str, Any],
v: jnp.ndarray,
params: Dict[str, Any],
) -> jnp.ndarray:
"""Compute current with automatic prefix handling using pytrees."""
states = self._transform_dict_keys(states, add_prefix=False)
params = self._transform_dict_keys(params, add_prefix=False)

return self._wrapped.compute_current(states, v, params)

def init_state(
self,
states: Dict[str, Any],
v: jnp.ndarray,
params: Dict[str, Any],
dt: float,
) -> Dict[str, Any]:
"""Initialize states with automatic prefix handling using pytrees."""
states = self._transform_dict_keys(states, add_prefix=False)
params = self._transform_dict_keys(params, add_prefix=False)

init_states = self._wrapped.init_state(states, v, params, dt)

return self._transform_dict_keys(init_states, add_prefix=True)
21 changes: 9 additions & 12 deletions jaxley/synapses/ionotropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,12 @@ class IonotropicSynapse(Synapse):

def __init__(self, name: Optional[str] = None):
super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gS": 1e-4,
f"{prefix}_e_syn": 0.0,
f"{prefix}_k_minus": 0.025,
"gS": 1e-4,
"e_syn": 0.0,
"k_minus": 0.025,
}
self.states = {f"{prefix}_s": 0.2}
self.states = {"s": 0.2}

def update_states(
self,
Expand All @@ -48,21 +47,19 @@ def update_states(
params: Dict,
) -> Dict:
"""Return updated synapse state and current."""
prefix = self._name
v_th = -35.0 # mV
delta = 10.0 # mV

s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))
tau_s = (1.0 - s_inf) / params[f"{prefix}_k_minus"]
tau_s = (1.0 - s_inf) / params["k_minus"]

slope = -1.0 / tau_s
exp_term = save_exp(slope * delta_t)
new_s = states[f"{prefix}_s"] * exp_term + s_inf * (1.0 - exp_term)
return {f"{prefix}_s": new_s}
new_s = states["s"] * exp_term + s_inf * (1.0 - exp_term)
return {"s": new_s}

def compute_current(
self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict
) -> float:
prefix = self._name
g_syn = params[f"{prefix}_gS"] * states[f"{prefix}_s"]
return g_syn * (post_voltage - params[f"{prefix}_e_syn"])
g_syn = params["gS"] * states["s"]
return g_syn * (post_voltage - params["e_syn"])
4 changes: 2 additions & 2 deletions jaxley/synapses/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class Synapse:
"""

_name = None
synapse_params = None
synapse_states = None
params = None
states = None

def __init__(self, name: Optional[str] = None):
self._name = name if name else self.__class__.__name__
Expand Down
14 changes: 5 additions & 9 deletions jaxley/synapses/tanh_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ class TanhRateSynapse(Synapse):

def __init__(self, name: Optional[str] = None):
super().__init__(name)
prefix = self._name
self.params = {
f"{prefix}_gS": 1e-4,
f"{prefix}_x_offset": -70.0,
f"{prefix}_slope": 1.0,
"gS": 1e-4,
"x_offset": -70.0,
"slope": 1.0,
}
self.states = {}

Expand All @@ -38,12 +37,9 @@ def compute_current(
self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict
) -> float:
"""Return updated synapse state and current."""
prefix = self._name
current = (
-1
* params[f"{prefix}_gS"]
* jnp.tanh(
(pre_voltage - params[f"{prefix}_x_offset"]) * params[f"{prefix}_slope"]
)
* params["gS"]
* jnp.tanh((pre_voltage - params["x_offset"]) * params["slope"])
)
return current
Loading

0 comments on commit a0874b4

Please sign in to comment.