Skip to content

Commit

Permalink
wip: rename channel and synapse params and enh to_jax
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 4, 2024
1 parent efa1e28 commit fadeb5d
Show file tree
Hide file tree
Showing 13 changed files with 120 additions and 121 deletions.
8 changes: 4 additions & 4 deletions jaxley/channels/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,22 +59,22 @@ def change_name(self, new_name: str):
new_prefix = new_name + "_"

self._name = new_name
self.channel_params = {
self.params = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.channel_params.items()
for key, value in self.params.items()
}

self.channel_states = {
self.states = {
(
new_prefix + key[len(old_prefix) :]
if key.startswith(old_prefix)
else key
): value
for key, value in self.channel_states.items()
for key, value in self.states.items()
}
return self

Expand Down
4 changes: 2 additions & 2 deletions jaxley/channels/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gNa": 0.12,
f"{prefix}_gK": 0.036,
f"{prefix}_gLeak": 0.0003,
f"{prefix}_eNa": 50.0,
f"{prefix}_eK": -77.0,
f"{prefix}_eLeak": -54.3,
}
self.channel_states = {
self.states = {
f"{prefix}_m": 0.2,
f"{prefix}_h": 0.2,
f"{prefix}_n": 0.2,
Expand Down
24 changes: 12 additions & 12 deletions jaxley/channels/pospischil.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gLeak": 1e-4,
f"{prefix}_eLeak": -70.0,
}
self.channel_states = {}
self.states = {}
self.current_name = f"i_{prefix}"

def update_states(
Expand Down Expand Up @@ -77,12 +77,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gNa": 50e-3,
"eNa": 50.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
}
self.channel_states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2}
self.states = {f"{prefix}_m": 0.2, f"{prefix}_h": 0.2}
self.current_name = f"i_Na"

def update_states(
Expand Down Expand Up @@ -148,12 +148,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gK": 5e-3,
"eK": -90.0,
"vt": -60.0, # Global parameter, not prefixed with `Na`.
}
self.channel_states = {f"{prefix}_n": 0.2}
self.states = {f"{prefix}_n": 0.2}
self.current_name = f"i_K"

def update_states(
Expand Down Expand Up @@ -204,12 +204,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gKm": 0.004e-3,
f"{prefix}_taumax": 4000.0,
f"eK": -90.0,
}
self.channel_states = {f"{prefix}_p": 0.2}
self.states = {f"{prefix}_p": 0.2}
self.current_name = f"i_K"

def update_states(
Expand Down Expand Up @@ -261,11 +261,11 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gCaL": 0.1e-3,
"eCa": 120.0,
}
self.channel_states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2}
self.states = {f"{prefix}_q": 0.2, f"{prefix}_r": 0.2}
self.current_name = f"i_Ca"

def update_states(
Expand Down Expand Up @@ -329,12 +329,12 @@ def __init__(self, name: Optional[str] = None):

super().__init__(name)
prefix = self._name
self.channel_params = {
self.params = {
f"{prefix}_gCaT": 0.4e-4,
f"{prefix}_vx": 2.0,
"eCa": 120.0, # Global parameter, not prefixed with `CaT`.
}
self.channel_states = {f"{prefix}_u": 0.2}
self.states = {f"{prefix}_u": 0.2}
self.current_name = f"i_Ca"

def update_states(
Expand Down
89 changes: 44 additions & 45 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ def __getattr__(self, key):
view._set_controlled_by_param(key) # overwrites param set by edge
# Ensure synapse param sharing works with `edge`
# `edge` will be removed as part of #463
view.edges["local_edge_index"] = np.arange(len(view.edges))
return view

def _childviews(self) -> List[str]:
Expand Down Expand Up @@ -710,17 +709,30 @@ def to_jax(self):
they can be processed on GPU/TPU and such that the simulation can be
differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.
"""
jaxnodes = self.base.jaxnodes = {}
nodes = self.base.nodes.to_dict(orient="list")
jaxnodes, jaxedges = {}, {}

jaxedges = self.base.jaxedges = {}
edges = self.base.edges.to_dict(orient="list")
edges.pop("type") # drop since column type is string
for jax_arrays, data, mechs in zip(
[jaxnodes, jaxedges],
[self.nodes, self.edges],
[self.channels, self.synapses],
):
jax_arrays.update({"index": data.index.to_numpy()})
all_inds = jax_arrays["index"]
for mech in mechs:
inds = (
all_inds[data["type"] == mech._name]
if "type" in data.columns
else all_inds[self.nodes[mech._name]]
)
states_params = list(mech.params) + list(mech.states)
params = data[states_params].loc[inds]
jax_arrays.update({mech._name: inds})
jax_arrays.update(params.to_dict(orient="list"))

for jax_array, params in zip([jaxnodes, jaxedges], [nodes, edges]):
for key, value in params.items():
inds = jnp.arange(len(value))
jax_array[key] = jnp.asarray(value)[inds]
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
jaxnodes.update(self.nodes[["v"]+morph_params].to_dict(orient="list"))
jaxnodes = {k: jnp.asarray(v) for k, v in jaxnodes.items()}
jaxedges = {k: jnp.asarray(v) for k, v in jaxedges.items()}

def show(
self,
Expand Down Expand Up @@ -753,12 +765,8 @@ def show(
inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else []
cols += inds
cols += [ch._name for ch in self.channels] if channel_names else []
cols += (
sum([list(ch.channel_params) for ch in self.channels], []) if params else []
)
cols += (
sum([list(ch.channel_states) for ch in self.channels], []) if states else []
)
cols += sum([list(ch.params) for ch in self.channels], []) if params else []
cols += sum([list(ch.states) for ch in self.channels], []) if states else []

if not param_names is None:
cols = (
Expand Down Expand Up @@ -887,12 +895,8 @@ def set_ncomp(
start_idx = self.nodes["global_comp_index"].to_numpy()[0]
nseg_per_branch = self.base.nseg_per_branch
channel_names = [c._name for c in self.base.channels]
channel_param_names = list(
chain(*[c.channel_params for c in self.base.channels])
)
channel_state_names = list(
chain(*[c.channel_states for c in self.base.channels])
)
channel_param_names = list(chain(*[c.params for c in self.base.channels]))
channel_state_names = list(chain(*[c.states for c in self.base.channels]))
radius_generating_fns = self.base._radius_generating_fns

within_branch_radiuses = view["radius"].to_numpy()
Expand Down Expand Up @@ -1213,11 +1217,12 @@ def get_all_parameters(
A dictionary of all module parameters.
"""
params = {}
for key in ["radius", "length", "axial_resistivity", "capacitance"]:
morph_params = ["radius", "length", "axial_resistivity", "capacitance"]
for key in ["v"] + morph_params:
params[key] = self.base.jaxnodes[key]

for channel in self.base.channels:
for channel_params in channel.channel_params:
for channel_params in channel.params:
params[channel_params] = self.base.jaxnodes[channel_params]

for synapse_params in self.base.synapse_param_names:
Expand Down Expand Up @@ -1250,7 +1255,7 @@ def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:
states = {"v": self.base.jaxnodes["v"]}
# Join node and edge states into a single state dictionary.
for channel in self.base.channels:
for channel_states in channel.channel_states:
for channel_states in channel.states:
states[channel_states] = self.base.jaxnodes[channel_states]
for synapse_states in self.base.synapse_state_names:
states[synapse_states] = self.base.jaxedges[synapse_states]
Expand Down Expand Up @@ -1333,8 +1338,8 @@ def init_states(self, delta_t: float = 0.025):
].to_numpy()
voltages = channel_nodes.loc[channel_indices, "v"].to_numpy()

channel_param_names = list(channel.channel_params.keys())
channel_state_names = list(channel.channel_states.keys())
channel_param_names = list(channel.params.keys())
channel_state_names = list(channel.states.keys())
channel_states = query_states_and_params(
states, channel_state_names, channel_indices
)
Expand Down Expand Up @@ -1653,12 +1658,12 @@ def insert(self, channel: Channel):
self.base.nodes.loc[self._nodes_in_view, name] = True

# Loop over all new parameters, e.g. gNa, eNa.
for key in channel.channel_params:
self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]
for key in channel.params:
self.base.nodes.loc[self._nodes_in_view, key] = channel.params[key]

# Loop over all new parameters, e.g. gNa, eNa.
for key in channel.channel_states:
self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]
for key in channel.states:
self.base.nodes.loc[self._nodes_in_view, key] = channel.states[key]

@only_allow_module
def step(
Expand Down Expand Up @@ -1836,11 +1841,9 @@ def _step_channels_state(
query_channel = lambda d, names: query_states_and_params(
d, names, channel_inds
)
channel_param_names = list(channel.channel_params) + morph_params
channel_param_names = list(channel.params) + morph_params
channel_params = query_channel(params, channel_param_names)
channel_state_names = (
list(channel.channel_states) + self.membrane_current_names
)
channel_state_names = list(channel.states) + self.membrane_current_names
channel_states = query_channel(states, channel_state_names)

# States updates.
Expand Down Expand Up @@ -1884,9 +1887,9 @@ def _channel_currents(
query_channel = lambda d, names: query_states_and_params(
d, names, channel_inds
)
channel_param_names = list(channel.channel_params) + morph_params
channel_param_names = list(channel.params) + morph_params
channel_params = query_channel(params, channel_param_names)
channel_states = query_channel(states, channel.channel_states)
channel_states = query_channel(states, channel.states)

v_channel = voltages[channel_inds]
v_and_perturbed = jnp.array([v_channel, v_channel + diff])
Expand Down Expand Up @@ -2406,13 +2409,9 @@ def _filter_trainables(
):
pkey, pval = next(iter(params.items()))
trainable_inds_in_view = None
if pkey in sum(
[list(c.channel_params.keys()) for c in self.base.channels], []
):
if pkey in sum([list(c.params.keys()) for c in self.base.channels], []):
trainable_inds_in_view = np.intersect1d(inds, self._nodes_in_view)
elif pkey in sum(
[list(s.synapse_params.keys()) for s in self.base.synapses], []
):
elif pkey in sum([list(s.params.keys()) for s in self.base.synapses], []):
trainable_inds_in_view = np.intersect1d(inds, self._edges_in_view)

in_view = is_viewed == np.isin(inds, trainable_inds_in_view)
Expand Down Expand Up @@ -2475,8 +2474,8 @@ def _set_synapses_in_view(self, pointer: Union[Module, View]):
viewed_synapses += (
[syn] if in_view else [None]
) # padded with None to keep indices consistent
viewed_params += list(syn.synapse_params.keys()) if in_view else []
viewed_states += list(syn.synapse_states.keys()) if in_view else []
viewed_params += list(syn.params.keys()) if in_view else []
viewed_states += list(syn.states.keys()) if in_view else []
self.synapses = viewed_synapses
self.synapse_param_names = viewed_params
self.synapse_state_names = viewed_states
Expand Down
16 changes: 8 additions & 8 deletions jaxley/modules/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ def _step_synapse_state(
edge_inds = group.index.to_numpy()

query_syn = lambda d, names: query_states_and_params(d, names, edge_inds)
synapse_params = query_syn(params, synapse.synapse_params)
synapse_states = query_syn(states, synapse.synapse_states)
synapse_params = query_syn(params, synapse.params)
synapse_states = query_syn(states, synapse.states)

# State updates.
states_updated = synapse.update_states(
Expand Down Expand Up @@ -313,8 +313,8 @@ def _synapse_currents(
edge_inds = group.index.to_numpy()

query_syn = lambda d, names: query_states_and_params(d, names, edge_inds)
synapse_params = query_syn(params, synapse.synapse_params)
synapse_states = query_syn(states, synapse.synapse_states)
synapse_params = query_syn(params, synapse.params)
synapse_states = query_syn(states, synapse.states)

v_pre, v_post = voltages[pre_inds], voltages[post_inds]
pre_v_and_perturbed = jnp.array([v_pre, v_pre + diff])
Expand Down Expand Up @@ -535,8 +535,8 @@ def _infer_synapse_type_ind(self, synapse_name):
def _update_synapse_state_names(self, synapse_type):
# (Potentially) update variables that track meta information about synapses.
self.base.synapse_names.append(synapse_type._name)
self.base.synapse_param_names += list(synapse_type.synapse_params.keys())
self.base.synapse_state_names += list(synapse_type.synapse_states.keys())
self.base.synapse_param_names += list(synapse_type.params.keys())
self.base.synapse_state_names += list(synapse_type.states.keys())
self.base.synapses.append(synapse_type)

def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):
Expand Down Expand Up @@ -586,9 +586,9 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):

def _add_params_to_edges(self, synapse_type, indices):
# Add parameters and states to the `.edges` table.
for key, param_val in synapse_type.synapse_params.items():
for key, param_val in synapse_type.params.items():
self.base.edges.loc[indices, key] = param_val

# Update synaptic state array.
for key, state_val in synapse_type.synapse_states.items():
for key, state_val in synapse_type.states.items():
self.base.edges.loc[indices, key] = state_val
4 changes: 2 additions & 2 deletions jaxley/synapses/ionotropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ class IonotropicSynapse(Synapse):
def __init__(self, name: Optional[str] = None):
super().__init__(name)
prefix = self._name
self.synapse_params = {
self.params = {
f"{prefix}_gS": 1e-4,
f"{prefix}_e_syn": 0.0,
f"{prefix}_k_minus": 0.025,
}
self.synapse_states = {f"{prefix}_s": 0.2}
self.states = {f"{prefix}_s": 0.2}

def update_states(
self,
Expand Down
Loading

0 comments on commit fadeb5d

Please sign in to comment.