Skip to content

Commit

Permalink
mv: rename channel and synapse param attrs
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 22, 2024
1 parent ac5026d commit cdbc510
Show file tree
Hide file tree
Showing 13 changed files with 106 additions and 118 deletions.
8 changes: 4 additions & 4 deletions jaxley/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,22 @@ def change_name(self, new_name: str):
new_prefix = new_name + "_"

self._name = new_name
self.channel_params = {
self.params = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.channel_params.items()
for key, value in self.params.items()
}

self.channel_states = {
self.states = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.channel_states.items()
for key, value in self.states.items()
}
return self

Expand Down
4 changes: 2 additions & 2 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gNa": 0.12,
f"{prefix}_gK": 0.036,
f"{prefix}_gLeak": 0.0003,
f"{prefix}_eNa": 50.0,
f"{prefix}_eK": -77.0,
f"{prefix}_eLeak": -54.3,
}
self.channel_states = {
self.states = {
f"{prefix}_m": 0.2,
f"{prefix}_h": 0.2,
f"{prefix}_n": 0.2,
Expand Down
24 changes: 12 additions & 12 deletions jaxley/channels/pospischil.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gLeak": 1e-4,
f"{prefix}_eLeak": -70.0,
}
self.channel_states = {}
self.states = {}
self.current_name = f"i_{prefix}"

def update_states(
Expand Down Expand Up @@ -77,12 +77,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gNa": 50e-3,
"eNa": 50.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
}
self.channel_states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2}
self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2}
self.current_name = f"i_Na"

def update_states(
Expand Down Expand Up @@ -148,12 +148,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gK": 5e-3,
"eK": -90.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
}
self.channel_states = {f"{prefix}_n": 0.2}
self.states = {f"{prefix}_n": 0.2}
self.current_name = f"i_K"

def update_states(
Expand Down Expand Up @@ -204,12 +204,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gKm": 0.004e-3,
f"{prefix}_taumax": 4000.0,
f"eK": -90.0,
}
self.channel_states = {f"{prefix}_p": 0.2}
self.states = {f"{prefix}_p": 0.2}
self.current_name = f"i_K"

def update_states(
Expand Down Expand Up @@ -261,11 +261,11 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gCaL": 0.1e-3,
"eCa": 120.0,
}
self.channel_states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2}
self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2}
self.current_name = f"i_Ca"

def update_states(
Expand Down Expand Up @@ -329,12 +329,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gCaT": 0.4e-4,
f"{prefix}_vx": 2.0,
"eCa": 120.0, # Global parameter, not prefixed with `CaT`.
}
self.channel_states = {f"{prefix}_u": 0.2}
self.states = {f"{prefix}_u": 0.2}
self.current_name = f"i_Ca"

def update_states(
Expand Down
68 changes: 28 additions & 40 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,9 @@ def to_jax(self):
edges = self.base.edges.to_dict(orient="list")
for i, synapse in enumerate(self.base.synapses):
condition = np.asarray(edges["type_ind"]) == i
for key in synapse.synapse_params:
for key in synapse.params:
self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])
for key in synapse.synapse_states:
for key in synapse.states:
self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])

def show(
Expand Down Expand Up @@ -782,12 +782,8 @@ def show(
inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else []
cols += inds
cols += [ch._name for ch in self.channels] if channel_names else []
cols += (
sum([list(ch.channel_params) for ch in self.channels], []) if params else []
)
cols += (
sum([list(ch.channel_states) for ch in self.channels], []) if states else []
)
cols += sum([list(ch.params) for ch in self.channels], []) if params else []
cols += sum([list(ch.states) for ch in self.channels], []) if states else []

if not param_names is None:
cols = (
Expand Down Expand Up @@ -916,12 +912,8 @@ def set_ncomp(
start_idx = self.nodes["global_comp_index"].to_numpy()[0]
ncomp_per_branch = self.base.ncomp_per_branch
channel_names = [c._name for c in self.base.channels]
channel_param_names = list(
chain(*[c.channel_params for c in self.base.channels])
)
channel_state_names = list(
chain(*[c.channel_states for c in self.base.channels])
)
channel_param_names = list(chain(*[c.params for c in self.base.channels]))
channel_state_names = list(chain(*[c.states for c in self.base.channels]))
radius_generating_fns = self.base._radius_generating_fns

within_branch_radiuses = view["radius"].to_numpy()
Expand Down Expand Up @@ -1166,9 +1158,9 @@ def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):
edges = self.base.edges.to_dict(orient="list")
for i, synapse in enumerate(self.base.synapses):
condition = np.asarray(edges["type_ind"]) == i
for key in list(synapse.synapse_params.keys()):
for key in list(synapse.params.keys()):
self.base.edges.loc[condition, key] = all_params[key]
for key in list(synapse.synapse_states.keys()):
for key in list(synapse.states.keys()):
self.base.edges.loc[condition, key] = all_states[key]

def distance(self, endpoint: "View") -> float:
Expand Down Expand Up @@ -1221,9 +1213,9 @@ def _get_state_names(self) -> Tuple[List, List]:
"""Collect all recordable / clampable states in the membrane and synapses.
Returns states seperated by comps and edges."""
channel_states = [name for c in self.channels for name in c.channel_states]
channel_states = [name for c in self.channels for name in c.states]
synapse_states = [
name for s in self.synapses if s is not None for name in s.synapse_states
name for s in self.synapses if s is not None for name in s.states
]
membrane_states = ["v", "i"] + self.membrane_current_names
return (
Expand Down Expand Up @@ -1283,7 +1275,7 @@ def get_all_parameters(
params[key] = self.base.jaxnodes[key]

for channel in self.base.channels:
for channel_params in channel.channel_params:
for channel_params in channel.params:
params[channel_params] = self.base.jaxnodes[channel_params]

for synapse_params in self.base.synapse_param_names:
Expand Down Expand Up @@ -1327,7 +1319,7 @@ def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:
states = {"v": self.base.jaxnodes["v"]}
# Join node and edge states into a single state dictionary.
for channel in self.base.channels:
for channel_states in channel.channel_states:
for channel_states in channel.states:
states[channel_states] = self.base.jaxnodes[channel_states]
for synapse_states in self.base.synapse_state_names:
states[synapse_states] = self.base.jaxedges[synapse_states]
Expand Down Expand Up @@ -1410,8 +1402,8 @@ def init_states(self, delta_t: float = 0.025):
].to_numpy()
voltages = channel_nodes.loc[channel_indices, "v"].to_numpy()

channel_param_names = list(channel.channel_params.keys())
channel_state_names = list(channel.channel_states.keys())
channel_param_names = list(channel.params.keys())
channel_state_names = list(channel.states.keys())
channel_states = query_channel_states_and_params(
states, channel_state_names, channel_indices
)
Expand Down Expand Up @@ -1748,12 +1740,12 @@ def insert(self, channel: Channel):
self.base.nodes.loc[self._nodes_in_view, name] = True

# Loop over all new parameters, e.g. gNa, eNa.
for key in channel.channel_params:
self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]
for key in channel.params:
self.base.nodes.loc[self._nodes_in_view, key] = channel.params[key]

# Loop over all new parameters, e.g. gNa, eNa.
for key in channel.channel_states:
self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]
for key in channel.states:
self.base.nodes.loc[self._nodes_in_view, key] = channel.states[key]

def delete_channel(self, channel: Channel):
"""Remove a channel from the module.
Expand All @@ -1764,8 +1756,8 @@ def delete_channel(self, channel: Channel):
channel_names = [c._name for c in self.channels]
all_channel_names = [c._name for c in self.base.channels]
if name in channel_names:
channel_cols = list(channel.channel_params.keys())
channel_cols += list(channel.channel_states.keys())
channel_cols = list(channel.params.keys())
channel_cols += list(channel.states.keys())
self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan")
self.base.nodes.loc[self._nodes_in_view, name] = False

Expand Down Expand Up @@ -1948,14 +1940,14 @@ def _step_channels_state(
# Update states of the channels.
indices = channel_nodes["global_comp_index"].to_numpy()
for channel in channels:
channel_param_names = list(channel.channel_params)
channel_param_names = list(channel.params)
channel_param_names += [
"radius",
"length",
"axial_resistivity",
"capacitance",
]
channel_state_names = list(channel.channel_states)
channel_state_names = list(channel.states)
channel_state_names += self.membrane_current_names
channel_indices = indices[channel_nodes[channel._name].astype(bool)]

Expand Down Expand Up @@ -2003,8 +1995,8 @@ def _channel_currents(

for channel in channels:
name = channel._name
channel_param_names = list(channel.channel_params.keys())
channel_state_names = list(channel.channel_states.keys())
channel_param_names = list(channel.params.keys())
channel_state_names = list(channel.states.keys())
indices = channel_nodes.loc[channel_nodes[name]][
"global_comp_index"
].to_numpy()
Expand Down Expand Up @@ -2599,13 +2591,9 @@ def _filter_trainables(
):
pkey, pval = next(iter(params.items()))
trainable_inds_in_view = None
if pkey in sum(
[list(c.channel_params.keys()) for c in self.base.channels], []
):
if pkey in sum([list(c.params.keys()) for c in self.base.channels], []):
trainable_inds_in_view = np.intersect1d(inds, self._nodes_in_view)
elif pkey in sum(
[list(s.synapse_params.keys()) for s in self.base.synapses], []
):
elif pkey in sum([list(s.params.keys()) for s in self.base.synapses], []):
trainable_inds_in_view = np.intersect1d(inds, self._edges_in_view)

in_view = is_viewed == np.isin(inds, trainable_inds_in_view)
Expand Down Expand Up @@ -2668,8 +2656,8 @@ def _set_synapses_in_view(self, pointer: Union[Module, View]):
viewed_synapses += (
[syn] if in_view else [None]
) # padded with None to keep indices consistent
viewed_params += list(syn.synapse_params.keys()) if in_view else []
viewed_states += list(syn.synapse_states.keys()) if in_view else []
viewed_params += list(syn.params.keys()) if in_view else []
viewed_states += list(syn.states.keys()) if in_view else []
self.synapses = viewed_synapses
self.synapse_param_names = viewed_params
self.synapse_state_names = viewed_states
Expand Down
16 changes: 8 additions & 8 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def _step_synapse_state(
assert (
synapse_names[i] == synapse_type._name
), "Mixup in the ordering of synapses. Please create an issue on Github."
synapse_param_names = list(synapse_type.synapse_params.keys())
synapse_state_names = list(synapse_type.synapse_states.keys())
synapse_param_names = list(synapse_type.params.keys())
synapse_state_names = list(synapse_type.states.keys())

synapse_params = {}
for p in synapse_param_names:
Expand Down Expand Up @@ -325,8 +325,8 @@ def _synapse_currents(
assert (
synapse_names[i] == synapse_type._name
), "Mixup in the ordering of synapses. Please create an issue on Github."
synapse_param_names = list(synapse_type.synapse_params.keys())
synapse_state_names = list(synapse_type.synapse_states.keys())
synapse_param_names = list(synapse_type.params.keys())
synapse_state_names = list(synapse_type.states.keys())

synapse_params = {}
for p in synapse_param_names:
Expand Down Expand Up @@ -514,8 +514,8 @@ def _infer_synapse_type_ind(self, synapse_name):
def _update_synapse_state_names(self, synapse_type):
# (Potentially) update variables that track meta information about synapses.
self.base.synapse_names.append(synapse_type._name)
self.base.synapse_param_names += list(synapse_type.synapse_params.keys())
self.base.synapse_state_names += list(synapse_type.synapse_states.keys())
self.base.synapse_param_names += list(synapse_type.params.keys())
self.base.synapse_state_names += list(synapse_type.states.keys())
self.base.synapses.append(synapse_type)

def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):
Expand Down Expand Up @@ -567,9 +567,9 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):

def _add_params_to_edges(self, synapse_type, indices):
# Add parameters and states to the `.edges` table.
for key, param_val in synapse_type.synapse_params.items():
for key, param_val in synapse_type.params.items():
self.base.edges.loc[indices, key] = param_val

# Update synaptic state array.
for key, state_val in synapse_type.synapse_states.items():
for key, state_val in synapse_type.states.items():
self.base.edges.loc[indices, key] = state_val
4 changes: 2 additions & 2 deletions jaxley/synapses/ionotropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ class IonotropicSynapse(Synapse):
def __init__(self, name: Optional[str] = None):
super().__init__(name)
prefix = self._name
self.synapse_params = {
self.params = {
f"{prefix}_gS": 1e-4,
f"{prefix}_e_syn": 0.0,
f"{prefix}_k_minus": 0.025,
}
self.synapse_states = {f"{prefix}_s": 0.2}
self.states = {f"{prefix}_s": 0.2}

def update_states(
self,
Expand Down
8 changes: 4 additions & 4 deletions jaxley/synapses/synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,22 @@ def change_name(self, new_name: str):
new_prefix = new_name + "_"

self._name = new_name
self.synapse_params = {
self.params = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.synapse_params.items()
for key, value in self.params.items()
}

self.synapse_states = {
self.states = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.synapse_states.items()
for key, value in self.states.items()
}
return self

Expand Down
Loading

0 comments on commit cdbc510

Please sign in to comment.