From cdbc5108bc74d5396890861ba7b22e35a808d73c Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sun, 22 Dec 2024 17:40:45 +0100 Subject: [PATCH 1/4] mv: rename channel and synapse param attrs --- jaxley/channels/channel.py | 8 ++-- jaxley/channels/hh.py | 4 +- jaxley/channels/pospischil.py | 24 ++++++------ jaxley/modules/base.py | 68 ++++++++++++++-------------------- jaxley/modules/network.py | 16 ++++---- jaxley/synapses/ionotropic.py | 4 +- jaxley/synapses/synapse.py | 8 ++-- jaxley/synapses/tanh_rate.py | 4 +- jaxley/synapses/test.py | 4 +- tests/test_channels.py | 60 +++++++++++++++--------------- tests/test_shared_state.py | 16 ++++---- tests/test_syn.py | 2 +- tests/test_synapse_indexing.py | 6 +-- 13 files changed, 106 insertions(+), 118 deletions(-) diff --git a/jaxley/channels/channel.py b/jaxley/channels/channel.py index 678b1e1e..b8a1dc41 100644 --- a/jaxley/channels/channel.py +++ b/jaxley/channels/channel.py @@ -59,22 +59,22 @@ def change_name(self, new_name: str): new_prefix = new_name + "_" self._name = new_name - self.channel_params = { + self.params = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.channel_params.items() + for key, value in self.params.items() } - self.channel_states = { + self.states = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.channel_states.items() + for key, value in self.states.items() } return self diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index c19bf002..70fc72b5 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -17,7 +17,7 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gNa": 0.12, f"{prefix}_gK": 0.036, f"{prefix}_gLeak": 0.0003, @@ -25,7 +25,7 @@ def __init__(self, name: Optional[str] = None): f"{prefix}_eK": -77.0, f"{prefix}_eLeak": -54.3, } - self.channel_states = { + self.states = { f"{prefix}_m": 0.2, f"{prefix}_h": 0.2, f"{prefix}_n": 0.2, diff --git a/jaxley/channels/pospischil.py b/jaxley/channels/pospischil.py index 5884deac..8602a72c 100644 --- a/jaxley/channels/pospischil.py +++ b/jaxley/channels/pospischil.py @@ -40,11 +40,11 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gLeak": 1e-4, f"{prefix}_eLeak": -70.0, } - self.channel_states = {} + self.states = {} self.current_name = f"i_{prefix}" def update_states( @@ -77,12 +77,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gNa": 50e-3, "eNa": 50.0, "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.channel_states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} + self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} self.current_name = f"i_Na" def update_states( @@ -148,12 +148,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gK": 5e-3, "eK": -90.0, "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.channel_states = {f"{prefix}_n": 0.2} + self.states = {f"{prefix}_n": 0.2} self.current_name = f"i_K" def update_states( @@ -204,12 +204,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gKm": 0.004e-3, f"{prefix}_taumax": 4000.0, f"eK": -90.0, } - self.channel_states = {f"{prefix}_p": 0.2} + self.states = {f"{prefix}_p": 0.2} self.current_name = f"i_K" def update_states( @@ -261,11 +261,11 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gCaL": 0.1e-3, "eCa": 120.0, } - self.channel_states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} + self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} self.current_name = f"i_Ca" def update_states( @@ -329,12 +329,12 @@ def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_gCaT": 0.4e-4, f"{prefix}_vx": 2.0, "eCa": 120.0, # Global parameter, not prefixed with `CaT`. } - self.channel_states = {f"{prefix}_u": 0.2} + self.states = {f"{prefix}_u": 0.2} self.current_name = f"i_Ca" def update_states( diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 2893f983..d344061a 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -746,9 +746,9 @@ def to_jax(self): edges = self.base.edges.to_dict(orient="list") for i, synapse in enumerate(self.base.synapses): condition = np.asarray(edges["type_ind"]) == i - for key in synapse.synapse_params: + for key in synapse.params: self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) - for key in synapse.synapse_states: + for key in synapse.states: self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) def show( @@ -782,12 +782,8 @@ def show( inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else [] cols += inds cols += [ch._name for ch in self.channels] if channel_names else [] - cols += ( - sum([list(ch.channel_params) for ch in self.channels], []) if params else [] - ) - cols += ( - sum([list(ch.channel_states) for ch in self.channels], []) if states else [] - ) + cols += sum([list(ch.params) for ch in self.channels], []) if params else [] + cols += sum([list(ch.states) for ch in self.channels], []) if states else [] if not param_names is None: cols = ( @@ -916,12 +912,8 @@ def set_ncomp( start_idx = self.nodes["global_comp_index"].to_numpy()[0] ncomp_per_branch = self.base.ncomp_per_branch channel_names = [c._name for c in self.base.channels] - channel_param_names = list( - chain(*[c.channel_params for c in self.base.channels]) - ) - channel_state_names = list( - chain(*[c.channel_states for c in self.base.channels]) - ) + channel_param_names = list(chain(*[c.params for c in self.base.channels])) + channel_state_names = list(chain(*[c.states for c in self.base.channels])) radius_generating_fns = self.base._radius_generating_fns within_branch_radiuses = view["radius"].to_numpy() @@ -1166,9 +1158,9 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]): edges = self.base.edges.to_dict(orient="list") for i, synapse in enumerate(self.base.synapses): condition = np.asarray(edges["type_ind"]) == i - for key in list(synapse.synapse_params.keys()): + for key in list(synapse.params.keys()): self.base.edges.loc[condition, key] = all_params[key] - for key in list(synapse.synapse_states.keys()): + for key in list(synapse.states.keys()): self.base.edges.loc[condition, key] = all_states[key] def distance(self, endpoint: "View") -> float: @@ -1221,9 +1213,9 @@ def _get_state_names(self) -> Tuple[List, List]: """Collect all recordable / clampable states in the membrane and synapses. Returns states seperated by comps and edges.""" - channel_states = [name for c in self.channels for name in c.channel_states] + channel_states = [name for c in self.channels for name in c.states] synapse_states = [ - name for s in self.synapses if s is not None for name in s.synapse_states + name for s in self.synapses if s is not None for name in s.states ] membrane_states = ["v", "i"] + self.membrane_current_names return ( @@ -1283,7 +1275,7 @@ def get_all_parameters( params[key] = self.base.jaxnodes[key] for channel in self.base.channels: - for channel_params in channel.channel_params: + for channel_params in channel.params: params[channel_params] = self.base.jaxnodes[channel_params] for synapse_params in self.base.synapse_param_names: @@ -1327,7 +1319,7 @@ def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: states = {"v": self.base.jaxnodes["v"]} # Join node and edge states into a single state dictionary. for channel in self.base.channels: - for channel_states in channel.channel_states: + for channel_states in channel.states: states[channel_states] = self.base.jaxnodes[channel_states] for synapse_states in self.base.synapse_state_names: states[synapse_states] = self.base.jaxedges[synapse_states] @@ -1410,8 +1402,8 @@ def init_states(self, delta_t: float = 0.025): ].to_numpy() voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() - channel_param_names = list(channel.channel_params.keys()) - channel_state_names = list(channel.channel_states.keys()) + channel_param_names = list(channel.params.keys()) + channel_state_names = list(channel.states.keys()) channel_states = query_channel_states_and_params( states, channel_state_names, channel_indices ) @@ -1748,12 +1740,12 @@ def insert(self, channel: Channel): self.base.nodes.loc[self._nodes_in_view, name] = True # Loop over all new parameters, e.g. gNa, eNa. - for key in channel.channel_params: - self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key] + for key in channel.params: + self.base.nodes.loc[self._nodes_in_view, key] = channel.params[key] # Loop over all new parameters, e.g. gNa, eNa. - for key in channel.channel_states: - self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key] + for key in channel.states: + self.base.nodes.loc[self._nodes_in_view, key] = channel.states[key] def delete_channel(self, channel: Channel): """Remove a channel from the module. @@ -1764,8 +1756,8 @@ def delete_channel(self, channel: Channel): channel_names = [c._name for c in self.channels] all_channel_names = [c._name for c in self.base.channels] if name in channel_names: - channel_cols = list(channel.channel_params.keys()) - channel_cols += list(channel.channel_states.keys()) + channel_cols = list(channel.params.keys()) + channel_cols += list(channel.states.keys()) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") self.base.nodes.loc[self._nodes_in_view, name] = False @@ -1948,14 +1940,14 @@ def _step_channels_state( # Update states of the channels. indices = channel_nodes["global_comp_index"].to_numpy() for channel in channels: - channel_param_names = list(channel.channel_params) + channel_param_names = list(channel.params) channel_param_names += [ "radius", "length", "axial_resistivity", "capacitance", ] - channel_state_names = list(channel.channel_states) + channel_state_names = list(channel.states) channel_state_names += self.membrane_current_names channel_indices = indices[channel_nodes[channel._name].astype(bool)] @@ -2003,8 +1995,8 @@ def _channel_currents( for channel in channels: name = channel._name - channel_param_names = list(channel.channel_params.keys()) - channel_state_names = list(channel.channel_states.keys()) + channel_param_names = list(channel.params.keys()) + channel_state_names = list(channel.states.keys()) indices = channel_nodes.loc[channel_nodes[name]][ "global_comp_index" ].to_numpy() @@ -2599,13 +2591,9 @@ def _filter_trainables( ): pkey, pval = next(iter(params.items())) trainable_inds_in_view = None - if pkey in sum( - [list(c.channel_params.keys()) for c in self.base.channels], [] - ): + if pkey in sum([list(c.params.keys()) for c in self.base.channels], []): trainable_inds_in_view = np.intersect1d(inds, self._nodes_in_view) - elif pkey in sum( - [list(s.synapse_params.keys()) for s in self.base.synapses], [] - ): + elif pkey in sum([list(s.params.keys()) for s in self.base.synapses], []): trainable_inds_in_view = np.intersect1d(inds, self._edges_in_view) in_view = is_viewed == np.isin(inds, trainable_inds_in_view) @@ -2668,8 +2656,8 @@ def _set_synapses_in_view(self, pointer: Union[Module, View]): viewed_synapses += ( [syn] if in_view else [None] ) # padded with None to keep indices consistent - viewed_params += list(syn.synapse_params.keys()) if in_view else [] - viewed_states += list(syn.synapse_states.keys()) if in_view else [] + viewed_params += list(syn.params.keys()) if in_view else [] + viewed_states += list(syn.states.keys()) if in_view else [] self.synapses = viewed_synapses self.synapse_param_names = viewed_params self.synapse_state_names = viewed_states diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 15183bd6..5727446a 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -273,8 +273,8 @@ def _step_synapse_state( assert ( synapse_names[i] == synapse_type._name ), "Mixup in the ordering of synapses. Please create an issue on Github." - synapse_param_names = list(synapse_type.synapse_params.keys()) - synapse_state_names = list(synapse_type.synapse_states.keys()) + synapse_param_names = list(synapse_type.params.keys()) + synapse_state_names = list(synapse_type.states.keys()) synapse_params = {} for p in synapse_param_names: @@ -325,8 +325,8 @@ def _synapse_currents( assert ( synapse_names[i] == synapse_type._name ), "Mixup in the ordering of synapses. Please create an issue on Github." - synapse_param_names = list(synapse_type.synapse_params.keys()) - synapse_state_names = list(synapse_type.synapse_states.keys()) + synapse_param_names = list(synapse_type.params.keys()) + synapse_state_names = list(synapse_type.states.keys()) synapse_params = {} for p in synapse_param_names: @@ -514,8 +514,8 @@ def _infer_synapse_type_ind(self, synapse_name): def _update_synapse_state_names(self, synapse_type): # (Potentially) update variables that track meta information about synapses. self.base.synapse_names.append(synapse_type._name) - self.base.synapse_param_names += list(synapse_type.synapse_params.keys()) - self.base.synapse_state_names += list(synapse_type.synapse_states.keys()) + self.base.synapse_param_names += list(synapse_type.params.keys()) + self.base.synapse_state_names += list(synapse_type.states.keys()) self.base.synapses.append(synapse_type) def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): @@ -567,9 +567,9 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): def _add_params_to_edges(self, synapse_type, indices): # Add parameters and states to the `.edges` table. - for key, param_val in synapse_type.synapse_params.items(): + for key, param_val in synapse_type.params.items(): self.base.edges.loc[indices, key] = param_val # Update synaptic state array. - for key, state_val in synapse_type.synapse_states.items(): + for key, state_val in synapse_type.states.items(): self.base.edges.loc[indices, key] = state_val diff --git a/jaxley/synapses/ionotropic.py b/jaxley/synapses/ionotropic.py index da89113f..101dd95b 100644 --- a/jaxley/synapses/ionotropic.py +++ b/jaxley/synapses/ionotropic.py @@ -32,12 +32,12 @@ class IonotropicSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = { + self.params = { f"{prefix}_gS": 1e-4, f"{prefix}_e_syn": 0.0, f"{prefix}_k_minus": 0.025, } - self.synapse_states = {f"{prefix}_s": 0.2} + self.states = {f"{prefix}_s": 0.2} def update_states( self, diff --git a/jaxley/synapses/synapse.py b/jaxley/synapses/synapse.py index a3b4752f..38cd7d3f 100644 --- a/jaxley/synapses/synapse.py +++ b/jaxley/synapses/synapse.py @@ -38,22 +38,22 @@ def change_name(self, new_name: str): new_prefix = new_name + "_" self._name = new_name - self.synapse_params = { + self.params = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.synapse_params.items() + for key, value in self.params.items() } - self.synapse_states = { + self.states = { ( new_prefix + key[len(old_prefix) :] if key.startswith(old_prefix) else key ): value - for key, value in self.synapse_states.items() + for key, value in self.states.items() } return self diff --git a/jaxley/synapses/tanh_rate.py b/jaxley/synapses/tanh_rate.py index e006a278..6bbd49cc 100644 --- a/jaxley/synapses/tanh_rate.py +++ b/jaxley/synapses/tanh_rate.py @@ -16,12 +16,12 @@ class TanhRateSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = { + self.params = { f"{prefix}_gS": 1e-4, f"{prefix}_x_offset": -70.0, f"{prefix}_slope": 1.0, } - self.synapse_states = {} + self.states = {} def update_states( self, diff --git a/jaxley/synapses/test.py b/jaxley/synapses/test.py index 49a7311e..84cb5d4d 100644 --- a/jaxley/synapses/test.py +++ b/jaxley/synapses/test.py @@ -19,8 +19,8 @@ class TestSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) prefix = self._name - self.synapse_params = {f"{prefix}_gC": 1e-4} - self.synapse_states = {f"{prefix}_c": 0.2} + self.params = {f"{prefix}_gC": 1e-4} + self.states = {f"{prefix}_c": 0.2} def update_states( self, diff --git a/tests/test_channels.py b/tests/test_channels.py index 4063fd3e..c850707d 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -25,13 +25,13 @@ def __init__( ): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) f"{self._name}_decay": 80, # Rate of removal of calcium in ms f"{self._name}_depth": 0.1, # Depth of shell in um f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM } - self.channel_states = { + self.states = { f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM } self.current_name = f"i_Ca" @@ -84,8 +84,8 @@ def __init__( "T": 279.45, # Kelvin (temperature) "R": 8.314, # J/(mol K) (gas constant) } - self.channel_params = {} - self.channel_states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} + self.params = {} + self.states = {"eCa": 0.0, "CaCon_i": 5e-05, "CaCon_e": 2.0} self.current_name = f"i_Ca" def update_states(self, u, dt, voltages, params): @@ -117,21 +117,21 @@ def test_channel_set_name(): # channel name can be set in the constructor na = Na(name="NaPospischil") assert na.name == "NaPospischil" - assert "NaPospischil_gNa" in na.channel_params.keys() - assert "eNa" in na.channel_params.keys() - assert "NaPospischil_h" in na.channel_states.keys() - assert "NaPospischil_m" in na.channel_states.keys() - assert "NaPospischil_vt" not in na.channel_params.keys() - assert "vt" in na.channel_params.keys() + assert "NaPospischil_gNa" in na.params.keys() + assert "eNa" in na.params.keys() + assert "NaPospischil_h" in na.states.keys() + assert "NaPospischil_m" in na.states.keys() + assert "NaPospischil_vt" not in na.params.keys() + assert "vt" in na.params.keys() # channel name can not be changed directly k = K() with pytest.raises(AttributeError): k.name = "KPospischil" - assert "KPospischil_gNa" not in k.channel_params.keys() - assert "eNa" not in k.channel_params.keys() - assert "KPospischil_h" not in k.channel_states.keys() - assert "KPospischil_m" not in k.channel_states.keys() + assert "KPospischil_gNa" not in k.params.keys() + assert "eNa" not in k.params.keys() + assert "KPospischil_h" not in k.states.keys() + assert "KPospischil_m" not in k.states.keys() def test_channel_change_name(): @@ -139,12 +139,12 @@ def test_channel_change_name(): # (and only this way after initialization) na = Na().change_name("NaPospischil") assert na.name == "NaPospischil" - assert "NaPospischil_gNa" in na.channel_params.keys() - assert "eNa" in na.channel_params.keys() - assert "NaPospischil_h" in na.channel_states.keys() - assert "NaPospischil_m" in na.channel_states.keys() - assert "NaPospischil_vt" not in na.channel_params.keys() - assert "vt" in na.channel_params.keys() + assert "NaPospischil_gNa" in na.params.keys() + assert "eNa" in na.params.keys() + assert "NaPospischil_h" in na.states.keys() + assert "NaPospischil_m" in na.states.keys() + assert "NaPospischil_vt" not in na.params.keys() + assert "vt" in na.params.keys() def test_integration_with_renamed_channels(): @@ -200,12 +200,12 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) prefix = self._name - self.channel_params = { + self.params = { f"{prefix}_q10_ch": 3, f"{prefix}_q10_ch0": 22, "celsius": 22, } - self.channel_states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} + self.states = {f"{prefix}_m": 0.02, "CaCon_i": 1e-4} self.current_name = f"i_K" def update_states( @@ -291,8 +291,8 @@ class User(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"cumulative": 0.0} + self.params = {} + self.states = {"cumulative": 0.0} self.current_name = f"i_User" def update_states(self, states, dt, v, params): @@ -307,8 +307,8 @@ class Dummy1(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {} + self.params = {} + self.states = {} self.current_name = f"i_Dummy" def update_states(self, states, dt, v, params): @@ -321,8 +321,8 @@ class Dummy2(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {} + self.params = {} + self.states = {} self.current_name = f"i_Dummy" def update_states(self, states, dt, v, params): @@ -365,8 +365,8 @@ def test_delete_channel(SimpleBranch): branch3.delete_channel(K()) def channel_present(view, channel, partial=False): - states_and_params = list(channel.channel_states.keys()) + list( - channel.channel_params.keys() + states_and_params = list(channel.states.keys()) + list( + channel.params.keys() ) # none of the states or params should be in nodes cols = view.nodes.columns.to_list() diff --git a/tests/test_shared_state.py b/tests/test_shared_state.py index 0de88bb5..83541acd 100644 --- a/tests/test_shared_state.py +++ b/tests/test_shared_state.py @@ -22,8 +22,8 @@ class Dummy1(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"Dummy_s": 0.0} + self.params = {} + self.states = {"Dummy_s": 0.0} self.current_name = f"i_Dummy1" @staticmethod @@ -45,8 +45,8 @@ class Dummy2(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = {} - self.channel_states = {"Dummy_s": 0.0} + self.params = {} + self.states = {"Dummy_s": 0.0} self.current_name = f"i_Dummy2" @staticmethod @@ -68,10 +68,10 @@ class CaHVA(Channel): def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gCaHVA": 0.00001, # S/cm^2 } - self.channel_states = { + self.states = { f"{self._name}_m": 0.1, # Initial value for m gating variable f"{self._name}_h": 0.1, # Initial value for h gating variable "eCa": 0.0, # mV, assuming eca for demonstration @@ -140,13 +140,13 @@ def __init__( ): self.current_is_in_mA_per_cm2 = True super().__init__(name) - self.channel_params = { + self.params = { f"{self._name}_gamma": 0.05, # Fraction of free calcium (not buffered) f"{self._name}_decay": 80, # Rate of removal of calcium in ms f"{self._name}_depth": 0.1, # Depth of shell in um f"{self._name}_minCai": 1e-4, # Minimum intracellular calcium concentration in mM } - self.channel_states = { + self.states = { f"CaCon_i": 5e-05, # Initial internal calcium concentration in mM } self.current_name = f"i_Ca" diff --git a/tests/test_syn.py b/tests/test_syn.py index 3159e036..840fb341 100644 --- a/tests/test_syn.py +++ b/tests/test_syn.py @@ -27,7 +27,7 @@ def test_set_and_querying_params_one_type(SimpleNet): connect(pre, post, IonotropicSynapse()) # Get the synapse parameters to test setting - syn_params = list(IonotropicSynapse().synapse_params.keys()) + syn_params = list(IonotropicSynapse().params.keys()) for p in syn_params: net.set(p, 0.15) assert np.all(net.edges[p].to_numpy() == 0.15) diff --git a/tests/test_synapse_indexing.py b/tests/test_synapse_indexing.py index 150a5d83..d61934c4 100644 --- a/tests/test_synapse_indexing.py +++ b/tests/test_synapse_indexing.py @@ -68,7 +68,7 @@ def test_set_and_querying_params_one_type(synapse_type, SimpleNet): connect(pre, post, synapse_type) # Get the synapse parameters to test setting - syn_params = list(synapse_type.synapse_params.keys()) + syn_params = list(synapse_type.params.keys()) for p in syn_params: net.set(p, 0.15) assert np.all(net.edges[p].to_numpy() == 0.15) @@ -105,8 +105,8 @@ def test_set_and_querying_params_two_types(synapse_type, SimpleNet): post = net.cell(post_ind).branch(0).loc(0.0) connect(pre, post, synapse) - type1_params = list(IonotropicSynapse().synapse_params.keys()) - synapse_type_params = list(synapse_type.synapse_params.keys()) + type1_params = list(IonotropicSynapse().params.keys()) + synapse_type_params = list(synapse_type.params.keys()) default_synapse_type = net.edges[synapse_type_params[0]].to_numpy()[[1, 3]] From 6ec896272a16fd1983ffecabe026e076e8977abe Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sun, 22 Dec 2024 17:41:44 +0100 Subject: [PATCH 2/4] enh: refactor HH channels --- jaxley/channels/hh.py | 225 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 198 insertions(+), 27 deletions(-) diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index 70fc72b5..8f9072c2 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -9,8 +9,8 @@ from jaxley.solver_gate import save_exp, solve_gate_exponential -class HH(Channel): - """Hodgkin-Huxley channel.""" +class Na(Channel): + """Hodgkin-Huxley Sodium channel.""" def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True @@ -19,61 +19,56 @@ def __init__(self, name: Optional[str] = None): prefix = self._name self.params = { f"{prefix}_gNa": 0.12, - f"{prefix}_gK": 0.036, - f"{prefix}_gLeak": 0.0003, f"{prefix}_eNa": 50.0, - f"{prefix}_eK": -77.0, - f"{prefix}_eLeak": -54.3, } self.states = { f"{prefix}_m": 0.2, f"{prefix}_h": 0.2, - f"{prefix}_n": 0.2, } - self.current_name = f"i_HH" + self.current_name = f"i_Na" def update_states( self, states: Dict[str, jnp.ndarray], - dt, - v, + dt: float, + v: jnp.ndarray, params: Dict[str, jnp.ndarray], - ): + ) -> Dict[str, jnp.ndarray]: """Return updated HH channel state.""" prefix = self._name - m, h, n = states[f"{prefix}_m"], states[f"{prefix}_h"], states[f"{prefix}_n"] + m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] new_m = solve_gate_exponential(m, dt, *self.m_gate(v)) new_h = solve_gate_exponential(h, dt, *self.h_gate(v)) - new_n = solve_gate_exponential(n, dt, *self.n_gate(v)) - return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n} + return {f"{prefix}_m": new_m, f"{prefix}_h": new_h} def compute_current( - self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] - ): + self, + states: Dict[str, jnp.ndarray], + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: """Return current through HH channels.""" prefix = self._name - m, h, n = states[f"{prefix}_m"], states[f"{prefix}_h"], states[f"{prefix}_n"] + m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2 - gK = params[f"{prefix}_gK"] * n**4 # S/cm^2 - gLeak = params[f"{prefix}_gLeak"] # S/cm^2 - return ( - gNa * (v - params[f"{prefix}_eNa"]) - + gK * (v - params[f"{prefix}_eK"]) - + gLeak * (v - params[f"{prefix}_eLeak"]) - ) + return gNa * (v - params[f"{prefix}_eNa"]) - def init_state(self, states, v, params, delta_t): + def init_state( + self, + states: Dict[str, jnp.ndarray], + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + dt: float, + ) -> Dict[str, jnp.ndarray]: """Initialize the state such at fixed point of gate dynamics.""" prefix = self._name alpha_m, beta_m = self.m_gate(v) alpha_h, beta_h = self.h_gate(v) - alpha_n, beta_n = self.n_gate(v) return { f"{prefix}_m": alpha_m / (alpha_m + beta_m), f"{prefix}_h": alpha_h / (alpha_h + beta_h), - f"{prefix}_n": alpha_n / (alpha_n + beta_n), } @staticmethod @@ -88,6 +83,63 @@ def h_gate(v): beta = 1.0 / (save_exp(-(v + 35) / 10) + 1) return alpha, beta + +class K(Channel): + """Hodgkin-Huxley Potassium channel.""" + + def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + + super().__init__(name) + prefix = self._name + self.params = { + f"{prefix}_gK": 0.036, + f"{prefix}_eK": -77.0, + } + self.states = { + f"{prefix}_n": 0.2, + } + self.current_name = f"i_K" + + def update_states( + self, + states: Dict[str, jnp.ndarray], + dt: float, + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """Return updated HH channel state.""" + prefix = self._name + n = states[f"{prefix}_n"] + new_n = solve_gate_exponential(n, dt, *self.n_gate(v)) + return {f"{prefix}_n": new_n} + + def compute_current( + self, + states: Dict[str, jnp.ndarray], + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Return current through HH channels.""" + prefix = self._name + n = states[f"{prefix}_n"] + + gK = params[f"{prefix}_gK"] * n**4 # S/cm^2 + + return gK * (v - params[f"{prefix}_eK"]) + + def init_state( + self, + states: Dict[str, jnp.ndarray], + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + dt: float, + ) -> Dict[str, jnp.ndarray]: + """Initialize the state such at fixed point of gate dynamics.""" + prefix = self._name + alpha_n, beta_n = self.n_gate(v) + return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)} + @staticmethod def n_gate(v): alpha = 0.01 * _vtrap(-(v + 55), 10) @@ -95,5 +147,124 @@ def n_gate(v): return alpha, beta +class Leak(Channel): + """Hodgkin-Huxley Leak channel.""" + + def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + + super().__init__(name) + prefix = self._name + self.params = { + f"{prefix}_gLeak": 0.0003, + f"{prefix}_eLeak": -54.3, + } + self.states = {} + self.current_name = f"i_Leak" + + def update_states( + self, + states: Dict[str, jnp.ndarray], + dt: float, + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """Return updated HH channel state.""" + return {} + + def compute_current( + self, + states: Dict[str, jnp.ndarray], + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Return current through HH channels.""" + prefix = self._name + gLeak = params[f"{prefix}_gLeak"] # S/cm^2 + + return gLeak * (v - params[f"{prefix}_eLeak"]) + + def init_state( + self, + states: Dict[str, jnp.ndarray], + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + dt: float, + ) -> Dict[str, jnp.ndarray]: + """Initialize the state such at fixed point of gate dynamics.""" + return {} + + +class HH(Channel): + """Hodgkin-Huxley channel.""" + + def __init__(self, name: Optional[str] = None): + self.current_is_in_mA_per_cm2 = True + + super().__init__(name) + self.Na = Na(self._name) + self.K = K(self._name) + self.Leak = Leak(self._name) + self.channels = [self.Na, self.K, self.Leak] + + self.params = { + **self.Na.params, + **self.K.params, + **self.Leak.params, + } + + self.states = { + **self.Na.states, + **self.K.states, + **self.Leak.states, + } + + self.current_name = f"i_HH" + + def change_name(self, new_name: str): + self._name = new_name + for channel in self.channels: + channel.change_name(new_name) + return self + + def update_states( + self, + states: Dict[str, jnp.ndarray], + dt: float, + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + ) -> Dict[str, jnp.ndarray]: + """Return updated HH channel state.""" + new_states = {} + for channel in self.channels: + new_states.update(channel.update_states(states, dt, v, params)) + return new_states + + def compute_current( + self, + states: Dict[str, jnp.ndarray], + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Return current through HH channels.""" + current = 0 + for channel in self.channels: + current += channel.compute_current(states, v, params) + return current + + def init_state( + self, + states: Dict[str, jnp.ndarray], + v: jnp.ndarray, + params: Dict[str, jnp.ndarray], + dt: float, + ) -> Dict[str, jnp.ndarray]: + """Initialize the state such at fixed point of gate dynamics.""" + init_states = {} + for channel in self.channels: + init_states.update(channel.init_state(states, v, params, dt)) + return init_states + + def _vtrap(x, y): return x / (save_exp(x / y) - 1.0) From a0874b484fc002dc832d5ad00a2554df2af3b6b1 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sun, 22 Dec 2024 18:53:11 +0100 Subject: [PATCH 3/4] wip: started work on new channel API --- jaxley/channels/channel.py | 5 +- jaxley/channels/hh.py | 76 ++++++++---------------- jaxley/modules/base.py | 106 ++++++++++++++++++++++++++++++++++ jaxley/synapses/ionotropic.py | 21 +++---- jaxley/synapses/synapse.py | 4 +- jaxley/synapses/tanh_rate.py | 14 ++--- jaxley/synapses/test.py | 13 ++--- 7 files changed, 152 insertions(+), 87 deletions(-) diff --git a/jaxley/channels/channel.py b/jaxley/channels/channel.py index b8a1dc41..81d48fdb 100644 --- a/jaxley/channels/channel.py +++ b/jaxley/channels/channel.py @@ -16,9 +16,8 @@ class Channel: `uA/cm2`.""" _name = None - channel_params = None - channel_states = None - current_name = None + params = None + states = None def __init__(self, name: Optional[str] = None): contact = ( diff --git a/jaxley/channels/hh.py b/jaxley/channels/hh.py index 8f9072c2..a06dc5a5 100644 --- a/jaxley/channels/hh.py +++ b/jaxley/channels/hh.py @@ -16,16 +16,8 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name - self.params = { - f"{prefix}_gNa": 0.12, - f"{prefix}_eNa": 50.0, - } - self.states = { - f"{prefix}_m": 0.2, - f"{prefix}_h": 0.2, - } - self.current_name = f"i_Na" + self.params = {"gNa": 0.12, "eNa": 50.0} + self.states = {"m": 0.2, "h": 0.2} def update_states( self, @@ -35,11 +27,11 @@ def update_states( params: Dict[str, jnp.ndarray], ) -> Dict[str, jnp.ndarray]: """Return updated HH channel state.""" - prefix = self._name - m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] + m, h = states["m"], states["h"] + new_m = solve_gate_exponential(m, dt, *self.m_gate(v)) new_h = solve_gate_exponential(h, dt, *self.h_gate(v)) - return {f"{prefix}_m": new_m, f"{prefix}_h": new_h} + return {"m": new_m, "h": new_h} def compute_current( self, @@ -48,12 +40,11 @@ def compute_current( params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Return current through HH channels.""" - prefix = self._name - m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] - - gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2 + m, h = states["m"], states["h"] + gNa, eNa = params["gNa"], params["eNa"] - return gNa * (v - params[f"{prefix}_eNa"]) + gNa = gNa * (m**3) * h # S/cm^2 + return gNa * (v - eNa) def init_state( self, @@ -63,13 +54,9 @@ def init_state( dt: float, ) -> Dict[str, jnp.ndarray]: """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_m, beta_m = self.m_gate(v) alpha_h, beta_h = self.h_gate(v) - return { - f"{prefix}_m": alpha_m / (alpha_m + beta_m), - f"{prefix}_h": alpha_h / (alpha_h + beta_h), - } + return {"m": alpha_m / (alpha_m + beta_m), "h": alpha_h / (alpha_h + beta_h)} @staticmethod def m_gate(v): @@ -91,15 +78,8 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name - self.params = { - f"{prefix}_gK": 0.036, - f"{prefix}_eK": -77.0, - } - self.states = { - f"{prefix}_n": 0.2, - } - self.current_name = f"i_K" + self.params = {"gK": 0.036, "eK": -77.0} + self.states = {"n": 0.2} def update_states( self, @@ -109,10 +89,10 @@ def update_states( params: Dict[str, jnp.ndarray], ) -> Dict[str, jnp.ndarray]: """Return updated HH channel state.""" - prefix = self._name - n = states[f"{prefix}_n"] + n = states["n"] + new_n = solve_gate_exponential(n, dt, *self.n_gate(v)) - return {f"{prefix}_n": new_n} + return {"n": new_n} def compute_current( self, @@ -121,12 +101,11 @@ def compute_current( params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Return current through HH channels.""" - prefix = self._name - n = states[f"{prefix}_n"] + n = states["n"] + gK, eK = params["gK"], params["eK"] - gK = params[f"{prefix}_gK"] * n**4 # S/cm^2 - - return gK * (v - params[f"{prefix}_eK"]) + gK = gK * n**4 # S/cm^2 + return gK * (v - eK) def init_state( self, @@ -136,9 +115,8 @@ def init_state( dt: float, ) -> Dict[str, jnp.ndarray]: """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_n, beta_n = self.n_gate(v) - return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)} + return {"n": alpha_n / (alpha_n + beta_n)} @staticmethod def n_gate(v): @@ -154,13 +132,8 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name - self.params = { - f"{prefix}_gLeak": 0.0003, - f"{prefix}_eLeak": -54.3, - } + self.params = {"gLeak": 0.0003, "eLeak": -54.3} self.states = {} - self.current_name = f"i_Leak" def update_states( self, @@ -179,10 +152,9 @@ def compute_current( params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Return current through HH channels.""" - prefix = self._name - gLeak = params[f"{prefix}_gLeak"] # S/cm^2 + gLeak, eLeak = params["gLeak"], params["eLeak"] - return gLeak * (v - params[f"{prefix}_eLeak"]) + return gLeak * (v - eLeak) def init_state( self, @@ -219,8 +191,6 @@ def __init__(self, name: Optional[str] = None): **self.Leak.states, } - self.current_name = f"i_HH" - def change_name(self, new_name: str): self._name = new_name for channel in self.channels: diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index d344061a..a4683f0d 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -11,6 +11,7 @@ import jax.numpy as jnp import numpy as np +from optree import tree_map_with_path import pandas as pd from jax import jit, vmap from jax.lax import ScatterDimensionNumbers, scatter_add @@ -2734,3 +2735,108 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): pass + + +class AutoPrefix: + """Wrapper for Channel classes that transparently handles name prefixing using pytrees.""" + + def __init__(self, mech: Union[Channel, Synapse]): + """Initialize wrapper with a channel instance.""" + # Store the wrapped channel + self._wrapped = mech + self.prefix = f"{self._wrapped.name}_" + self._name = self._wrapped._name + self.current_name = "i_" + self.prefix[:-1] + + # Make this class pretend to be the wrapped class + self.__class__.__name__ = mech.__class__.__name__ + self.__class__.__qualname__ = mech.__class__.__qualname__ + self.__class__.__module__ = mech.__class__.__module__ + + if isinstance(self._wrapped, Synapse): + delattr(self.__class__, "init_state") + + def __getattr__(self, name: str): + """Delegate unknown attributes to the wrapped channel.""" + return getattr(self._wrapped, name) + + def __repr__(self): + """Return the same representation as the wrapped channel.""" + return repr(self._wrapped) + + def __str__(self): + """Return the same string representation as the wrapped channel.""" + return str(self._wrapped) + + def _transform_dict_keys( + self, d: Dict[str, Any], add_prefix: bool = True + ) -> Dict[str, Any]: + """Transform dictionary keys using pytree operations.""" + + def transform_key(path, value): + key = path[-1] # Get the leaf key + if isinstance(key, str): + if add_prefix: + return f"{self.prefix}{key}", value + elif key.startswith(self.prefix): + return key[len(self.prefix) :], value + return key, value + + transformed = tree_map_with_path( + lambda p, v: transform_key(p, v)[1], + d, + is_leaf=lambda x: isinstance(x, (jnp.ndarray, float, int)), + ) + return transformed + + @property + def params(self) -> Dict[str, Any]: + """Get prefixed parameters.""" + return self._transform_dict_keys(self._wrapped.params, add_prefix=True) + + @property + def states(self) -> Dict[str, Any]: + """Get prefixed states.""" + return self._transform_dict_keys(self._wrapped.states, add_prefix=True) + + def update_states( + self, + states: Dict[str, Any], + dt: float, + v: jnp.ndarray, + params: Dict[str, Any], + ) -> Dict[str, Any]: + """Update states with automatic prefix handling using pytrees.""" + states = self._transform_dict_keys(states, add_prefix=False) + params = self._transform_dict_keys(params, add_prefix=False) + + states = self._wrapped.update_states(states, dt, v, params) + + return self._transform_dict_keys(states, add_prefix=True) + + def compute_current( + self, + states: Dict[str, Any], + v: jnp.ndarray, + params: Dict[str, Any], + ) -> jnp.ndarray: + """Compute current with automatic prefix handling using pytrees.""" + states = self._transform_dict_keys(states, add_prefix=False) + params = self._transform_dict_keys(params, add_prefix=False) + + return self._wrapped.compute_current(states, v, params) + + def init_state( + self, + states: Dict[str, Any], + v: jnp.ndarray, + params: Dict[str, Any], + dt: float, + ) -> Dict[str, Any]: + """Initialize states with automatic prefix handling using pytrees.""" + states = self._transform_dict_keys(states, add_prefix=False) + params = self._transform_dict_keys(params, add_prefix=False) + + init_states = self._wrapped.init_state(states, v, params, dt) + + return self._transform_dict_keys(init_states, add_prefix=True) diff --git a/jaxley/synapses/ionotropic.py b/jaxley/synapses/ionotropic.py index 101dd95b..58514c04 100644 --- a/jaxley/synapses/ionotropic.py +++ b/jaxley/synapses/ionotropic.py @@ -31,13 +31,12 @@ class IonotropicSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gS": 1e-4, - f"{prefix}_e_syn": 0.0, - f"{prefix}_k_minus": 0.025, + "gS": 1e-4, + "e_syn": 0.0, + "k_minus": 0.025, } - self.states = {f"{prefix}_s": 0.2} + self.states = {"s": 0.2} def update_states( self, @@ -48,21 +47,19 @@ def update_states( params: Dict, ) -> Dict: """Return updated synapse state and current.""" - prefix = self._name v_th = -35.0 # mV delta = 10.0 # mV s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta)) - tau_s = (1.0 - s_inf) / params[f"{prefix}_k_minus"] + tau_s = (1.0 - s_inf) / params["k_minus"] slope = -1.0 / tau_s exp_term = save_exp(slope * delta_t) - new_s = states[f"{prefix}_s"] * exp_term + s_inf * (1.0 - exp_term) - return {f"{prefix}_s": new_s} + new_s = states["s"] * exp_term + s_inf * (1.0 - exp_term) + return {"s": new_s} def compute_current( self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict ) -> float: - prefix = self._name - g_syn = params[f"{prefix}_gS"] * states[f"{prefix}_s"] - return g_syn * (post_voltage - params[f"{prefix}_e_syn"]) + g_syn = params["gS"] * states["s"] + return g_syn * (post_voltage - params["e_syn"]) diff --git a/jaxley/synapses/synapse.py b/jaxley/synapses/synapse.py index 38cd7d3f..ea460512 100644 --- a/jaxley/synapses/synapse.py +++ b/jaxley/synapses/synapse.py @@ -15,8 +15,8 @@ class Synapse: """ _name = None - synapse_params = None - synapse_states = None + params = None + states = None def __init__(self, name: Optional[str] = None): self._name = name if name else self.__class__.__name__ diff --git a/jaxley/synapses/tanh_rate.py b/jaxley/synapses/tanh_rate.py index 6bbd49cc..8a95e79d 100644 --- a/jaxley/synapses/tanh_rate.py +++ b/jaxley/synapses/tanh_rate.py @@ -15,11 +15,10 @@ class TanhRateSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gS": 1e-4, - f"{prefix}_x_offset": -70.0, - f"{prefix}_slope": 1.0, + "gS": 1e-4, + "x_offset": -70.0, + "slope": 1.0, } self.states = {} @@ -38,12 +37,9 @@ def compute_current( self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict ) -> float: """Return updated synapse state and current.""" - prefix = self._name current = ( -1 - * params[f"{prefix}_gS"] - * jnp.tanh( - (pre_voltage - params[f"{prefix}_x_offset"]) * params[f"{prefix}_slope"] - ) + * params["gS"] + * jnp.tanh((pre_voltage - params["x_offset"]) * params["slope"]) ) return current diff --git a/jaxley/synapses/test.py b/jaxley/synapses/test.py index 84cb5d4d..95b7d0fa 100644 --- a/jaxley/synapses/test.py +++ b/jaxley/synapses/test.py @@ -18,9 +18,8 @@ class TestSynapse(Synapse): def __init__(self, name: Optional[str] = None): super().__init__(name) - prefix = self._name - self.params = {f"{prefix}_gC": 1e-4} - self.states = {f"{prefix}_c": 0.2} + self.params = {"gC": 1e-4} + self.states = {"c": 0.2} def update_states( self, @@ -31,7 +30,6 @@ def update_states( params: Dict, ) -> Dict: """Return updated synapse state and current.""" - prefix = self._name v_th = -35.0 delta = 10.0 k_minus = 1.0 / 40.0 @@ -42,13 +40,12 @@ def update_states( s_inf = s_bar slope = -1.0 / tau_s exp_term = save_exp(slope * delta_t) - new_s = states[f"{prefix}_c"] * exp_term + s_inf * (1.0 - exp_term) - return {f"{prefix}_c": new_s} + new_s = states["c"] * exp_term + s_inf * (1.0 - exp_term) + return {"c": new_s} def compute_current( self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict ) -> float: - prefix = self._name e_syn = 0.0 - g_syn = params[f"{prefix}_gC"] * states[f"{prefix}_c"] + g_syn = params["gC"] * states["c"] return g_syn * (post_voltage - e_syn) From 5c11a41ed59a9732ea1bcb05e0387ecc57e91005 Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Sun, 22 Dec 2024 23:11:05 +0100 Subject: [PATCH 4/4] wip: add infer global param method and other fixes --- jaxley/channels/pospischil.py | 148 ++++++++++++++-------------------- jaxley/modules/base.py | 39 ++++++++- jaxley/modules/network.py | 3 +- 3 files changed, 101 insertions(+), 89 deletions(-) diff --git a/jaxley/channels/pospischil.py b/jaxley/channels/pospischil.py index 8602a72c..5df21fa6 100644 --- a/jaxley/channels/pospischil.py +++ b/jaxley/channels/pospischil.py @@ -39,13 +39,12 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gLeak": 1e-4, - f"{prefix}_eLeak": -70.0, + "gLeak": 1e-4, + "eLeak": -70.0, } self.states = {} - self.current_name = f"i_{prefix}" + # self.current_name = f"i_Leak" def update_states( self, @@ -61,9 +60,8 @@ def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - gLeak = params[f"{prefix}_gLeak"] # S/cm^2 - return gLeak * (v - params[f"{prefix}_eLeak"]) + gLeak = params["gLeak"] # S/cm^2 + return gLeak * (v - params["eLeak"]) def init_state(self, states, v, params, delta_t): return {} @@ -76,14 +74,13 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gNa": 50e-3, - "eNa": 50.0, - "vt": -60.0, # Global parameter, not prefixed with `Na`. + "gNa": 50e-3, + # "eNa": 50.0, + # "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2} - self.current_name = f"i_Na" + self.states = {"m": 0.2, "h": 0.2} + # self.current_name = f"i_Na" def update_states( self, @@ -93,32 +90,29 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] + m, h = states["m"], states["h"] new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params["vt"])) new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params["vt"])) - return {f"{prefix}_m": new_m, f"{prefix}_h": new_h} + return {"m": new_m, "h": new_h} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - m, h = states[f"{prefix}_m"], states[f"{prefix}_h"] + m, h = states["m"], states["h"] - gNa = params[f"{prefix}_gNa"] * (m**3) * h # S/cm^2 + gNa = params["gNa"] * (m**3) * h # S/cm^2 current = gNa * (v - params["eNa"]) return current def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_m, beta_m = self.m_gate(v, params["vt"]) alpha_h, beta_h = self.h_gate(v, params["vt"]) return { - f"{prefix}_m": alpha_m / (alpha_m + beta_m), - f"{prefix}_h": alpha_h / (alpha_h + beta_h), + "m": alpha_m / (alpha_m + beta_m), + "h": alpha_h / (alpha_h + beta_h), } @staticmethod @@ -147,14 +141,13 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gK": 5e-3, - "eK": -90.0, - "vt": -60.0, # Global parameter, not prefixed with `Na`. + "gK": 5e-3, + # "eK": -90.0, + # "vt": -60.0, # Global parameter, not prefixed with `Na`. } - self.states = {f"{prefix}_n": 0.2} - self.current_name = f"i_K" + self.states = {"n": 0.2} + # self.current_name = f"i_K" def update_states( self, @@ -164,27 +157,24 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - n = states[f"{prefix}_n"] + n = states["n"] new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params["vt"])) - return {f"{prefix}_n": new_n} + return {"n": new_n} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - n = states[f"{prefix}_n"] + n = states["n"] - gK = params[f"{prefix}_gK"] * (n**4) # S/cm^2 + gK = params["gK"] * (n**4) # S/cm^2 return gK * (v - params["eK"]) def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_n, beta_n = self.n_gate(v, params["vt"]) - return {f"{prefix}_n": alpha_n / (alpha_n + beta_n)} + return {"n": alpha_n / (alpha_n + beta_n)} @staticmethod def n_gate(v, vt): @@ -203,14 +193,13 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gKm": 0.004e-3, - f"{prefix}_taumax": 4000.0, - f"eK": -90.0, + "gKm": 0.004e-3, + "taumax": 4000.0, + # f"eK": -90.0, } - self.states = {f"{prefix}_p": 0.2} - self.current_name = f"i_K" + self.states = {"p": 0.2} + # self.current_name = f"i_K" def update_states( self, @@ -220,28 +209,23 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - p = states[f"{prefix}_p"] - new_p = solve_inf_gate_exponential( - p, dt, *self.p_gate(v, params[f"{prefix}_taumax"]) - ) - return {f"{prefix}_p": new_p} + p = states["p"] + new_p = solve_inf_gate_exponential(p, dt, *self.p_gate(v, params["taumax"])) + return {"p": new_p} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - p = states[f"{prefix}_p"] + p = states["p"] - gKm = params[f"{prefix}_gKm"] * p # S/cm^2 + gKm = params["gKm"] * p # S/cm^2 return gKm * (v - params["eK"]) def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name - alpha_p, beta_p = self.p_gate(v, params[f"{prefix}_taumax"]) - return {f"{prefix}_p": alpha_p / (alpha_p + beta_p)} + alpha_p, beta_p = self.p_gate(v, params["taumax"]) + return {"p": alpha_p / (alpha_p + beta_p)} @staticmethod def p_gate(v, taumax): @@ -260,13 +244,12 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gCaL": 0.1e-3, - "eCa": 120.0, + "gCaL": 0.1e-3, + # "eCa": 120.0, } - self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2} - self.current_name = f"i_Ca" + self.states = {"q": 0.2, "r": 0.2} + # self.current_name = f"i_Ca" def update_states( self, @@ -276,30 +259,27 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - q, r = states[f"{prefix}_q"], states[f"{prefix}_r"] + q, r = states["q"], states["r"] new_q = solve_gate_exponential(q, dt, *self.q_gate(v)) new_r = solve_gate_exponential(r, dt, *self.r_gate(v)) - return {f"{prefix}_q": new_q, f"{prefix}_r": new_r} + return {"q": new_q, "r": new_r} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - q, r = states[f"{prefix}_q"], states[f"{prefix}_r"] - gCaL = params[f"{prefix}_gCaL"] * (q**2) * r # S/cm^2 + q, r = states["q"], states["r"] + gCaL = params["gCaL"] * (q**2) * r # S/cm^2 return gCaL * (v - params["eCa"]) def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name alpha_q, beta_q = self.q_gate(v) alpha_r, beta_r = self.r_gate(v) return { - f"{prefix}_q": alpha_q / (alpha_q + beta_q), - f"{prefix}_r": alpha_r / (alpha_r + beta_r), + "q": alpha_q / (alpha_q + beta_q), + "r": alpha_r / (alpha_r + beta_r), } @staticmethod @@ -328,14 +308,13 @@ def __init__(self, name: Optional[str] = None): self.current_is_in_mA_per_cm2 = True super().__init__(name) - prefix = self._name self.params = { - f"{prefix}_gCaT": 0.4e-4, - f"{prefix}_vx": 2.0, - "eCa": 120.0, # Global parameter, not prefixed with `CaT`. + "gCaT": 0.4e-4, + "vx": 2.0, + # "eCa": 120.0, # Global parameter, not prefixed with `CaT`. } - self.states = {f"{prefix}_u": 0.2} - self.current_name = f"i_Ca" + self.states = {"u": 0.2} + # self.current_name = f"i_Ca" def update_states( self, @@ -345,30 +324,25 @@ def update_states( params: Dict[str, jnp.ndarray], ): """Update state.""" - prefix = self._name - u = states[f"{prefix}_u"] - new_u = solve_inf_gate_exponential( - u, dt, *self.u_gate(v, params[f"{prefix}_vx"]) - ) - return {f"{prefix}_u": new_u} + u = states["u"] + new_u = solve_inf_gate_exponential(u, dt, *self.u_gate(v, params["vx"])) + return {"u": new_u} def compute_current( self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray] ): """Return current.""" - prefix = self._name - u = states[f"{prefix}_u"] - s_inf = 1.0 / (1.0 + save_exp(-(v + params[f"{prefix}_vx"] + 57.0) / 6.2)) + u = states["u"] + s_inf = 1.0 / (1.0 + save_exp(-(v + params["vx"] + 57.0) / 6.2)) - gCaT = params[f"{prefix}_gCaT"] * (s_inf**2) * u # S/cm^2 + gCaT = params["gCaT"] * (s_inf**2) * u # S/cm^2 return gCaT * (v - params["eCa"]) def init_state(self, states, v, params, delta_t): """Initialize the state such at fixed point of gate dynamics.""" - prefix = self._name - alpha_u, beta_u = self.u_gate(v, params[f"{prefix}_vx"]) - return {f"{prefix}_u": alpha_u / (alpha_u + beta_u)} + alpha_u, beta_u = self.u_gate(v, params["vx"]) + return {"u": alpha_u / (alpha_u + beta_u)} @staticmethod def u_gate(v, vx): diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index a4683f0d..cf02cb00 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -1725,6 +1725,7 @@ def insert(self, channel: Channel): Args: channel: The channel to insert.""" + channel = AutoPrefix(channel) name = channel._name # Channel does not yet exist in the `jx.Module` at all. @@ -2746,7 +2747,11 @@ def __init__(self, mech: Union[Channel, Synapse]): self._wrapped = mech self.prefix = f"{self._wrapped.name}_" self._name = self._wrapped._name - self.current_name = "i_" + self.prefix[:-1] + self.current_name = ( + self._wrapped.current_name + if hasattr(self._wrapped, "current_name") + else "i_" + self.prefix[:-1] + ) # Make this class pretend to be the wrapped class self.__class__.__name__ = mech.__class__.__name__ @@ -2840,3 +2845,35 @@ def init_state( init_states = self._wrapped.init_state(states, v, params, dt) return self._transform_dict_keys(init_states, add_prefix=True) + + +def infer_global_params_states(mech: Union[Channel, Synapse]) -> List[str]: + """Infer the global parameters and states of a channel or synapse. + + Infers global params and states by testing for KeyErrors in the `update_states`, + `compute_current`, and `init_state` methods. + + Args: + mech: The channel or synapse to infer the global params and states of. + + Returns: + A list of the inferred global parameters and states. + """ + global_state_params = {} + mech_states = mech.states + mech_params = mech.params + + while True: + try: + states = {**global_state_params, **mech_states} + params = {**global_state_params, **mech_params} + + mech.update_states(states, 0.025, -70, params) + mech.compute_current(states, -70, params) + if isinstance(mech, Channel): + mech.init_state(states, -70, params, 0.025) + break + except KeyError as e: + missing_key = e.args[0] + global_state_params[missing_key] = 0 + return list(global_state_params) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 5727446a..0019fdb4 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -13,7 +13,7 @@ from matplotlib import pyplot as plt from matplotlib.axes import Axes -from jaxley.modules.base import Module +from jaxley.modules.base import AutoPrefix, Module from jaxley.modules.cell import Cell from jaxley.utils.cell_utils import ( build_branchpoint_group_inds, @@ -520,6 +520,7 @@ def _update_synapse_state_names(self, synapse_type): def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): # Add synapse types to the module and infer their unique identifier. + synapse_type = AutoPrefix(synapse_type) synapse_name = synapse_type._name synapse_current_name = f"i_{synapse_name}" type_ind, is_new = self._infer_synapse_type_ind(synapse_name)