Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New structure for synapses #209

Closed
wants to merge 13 commits into from
8 changes: 3 additions & 5 deletions jaxley/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,20 @@ def _body_fun(state, i_stim):
dummy_stimulus = jnp.zeros((size_difference, i_current.shape[1]))
i_current = jnp.concatenate([i_current, dummy_stimulus])

# Join node and edge states.
# Join node and edge states into a single state dictionary.
states = {"voltages": module.jaxnodes["voltages"]}
for channel in module.channels:
for channel_states in list(channel.channel_states.keys()):
states[channel_states] = module.jaxnodes[channel_states]
for synapse_states in module.synapse_state_names:
states[synapse_states] = module.jaxedges[synapse_states]

# Override with the initial states set by `.make_trainable()`.
for inds, set_param in zip(module.indices_set_by_trainables, params):
for key in set_param.keys():
if key in list(states.keys()): # Only initial states, not parameters.
states[key] = states[key].at[inds].set(set_param[key])

# Write synaptic states. TODO move above when new interface for synapses.
for key in module.syn_states:
states[key] = module.syn_states[key]

# Run simulation.
_, recordings = nested_checkpoint_scan(
_body_fun, states, i_current, length=length, nested_lengths=checkpoint_lengths
Expand Down
79 changes: 41 additions & 38 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from jaxley.channels import Channel
from jaxley.solver_voltage import step_voltage_explicit, step_voltage_implicit
from jaxley.synapses import Synapse
from jaxley.utils.cell_utils import (
_compute_index_of_child,
_compute_num_children,
Expand All @@ -26,19 +25,16 @@ def __init__(self):
self.nseg: int = None
self.total_nbranches: int = 0
self.nbranches_per_cell: List[int] = None

self.conns: List[Synapse] = None
self.group_views = {}

self.nodes: Optional[pd.DataFrame] = None

self.syn_edges = pd.DataFrame(
self.edges = pd.DataFrame(
columns=[
"pre_locs",
"post_locs",
"pre_branch_index",
"post_branch_index",
"pre_cell_index",
"post_locs",
"post_branch_index",
"post_cell_index",
"type",
"type_ind",
Expand All @@ -48,8 +44,6 @@ def __init__(self):
"global_post_branch_index",
]
)
self.branch_edges: Optional[pd.DataFrame] = None

self.cumsum_nbranches: Optional[jnp.ndarray] = None

self.comb_parents: jnp.ndarray = jnp.asarray([-1])
Expand All @@ -58,11 +52,12 @@ def __init__(self):
self.initialized_morph: bool = False
self.initialized_syns: bool = False

self.syn_params: Dict[str, jnp.ndarray] = {}
self.syn_states: Dict[str, jnp.ndarray] = {}
self.syn_classes: List = []
# List of all types of `jx.Synapse`s.
self.synapses: List = []
self.synapse_param_names = []
self.synapse_state_names = []

# List of all `jx.Channel`s.
# List of all types of `jx.Channel`s.
self.channels: List[Channel] = []

# For trainable parameters.
Expand Down Expand Up @@ -116,10 +111,19 @@ def _gather_channels_from_constituents(self, constituents: List) -> None:
self.nodes.loc[self.nodes[name].isna(), name] = False

def to_jax(self):
"""Generates Dict[jnp.ndarray] from the pd.DataFrames for nodes and edges."""
self.jaxnodes = {}
for key, value in self.nodes.to_dict(orient="list").items():
self.jaxnodes[key] = jnp.asarray(value)

# TODO(@michaeldeistler): if we wanted to reduce memory footprint, we could here
# remove NaN from jaxedges parameters and states. Then we only have to fix
# step_synapse and make_trainable with corresponding index updates.
self.jaxedges = {}
for key, value in self.edges.to_dict(orient="list").items():
if key != "type":
self.jaxedges[key] = jnp.asarray(value)

def show(
self,
param_names: Optional[Union[str, List[str]]] = None, # TODO.
Expand Down Expand Up @@ -198,17 +202,17 @@ def _append_channel_to_nodes(self, view, channel: "jx.Channel"):
def set(self, key, val):
"""Set parameter."""
# Alternatively, we could do `assert key not in self.syn_params`.
nodes = self.syn_edges if key in self.syn_params else self.nodes
self._set(key, val, nodes)

def _set(self, key, val, view):
if key in self.syn_params:
self.syn_params[key] = self.syn_params[key].at[view.index.values].set(val)
elif key in self.syn_states:
self.syn_states[key] = self.syn_states[key].at[view.index.values].set(val)
elif key in view.columns:
view = (
self.edges
if key in self.synapse_param_names or key in self.synapse_state_names
else self.nodes
)
self._set(key, val, view, view)

def _set(self, key, val, view, table_to_update):
if key in view.columns:
view = view[~np.isnan(view[key])]
self.nodes.loc[view.index.values, key] = val
table_to_update.loc[view.index.values, key] = val
else:
raise KeyError("Key not recognized.")

Expand All @@ -230,7 +234,12 @@ def make_trainable(
verbose: Whether to print the number of parameters that are added and the
total number of parameters.
"""
view = deepcopy(self.nodes.assign(controlled_by_param=0))
view = (
self.edges
if key in self.synapse_param_names or key in self.synapse_state_names
else self.nodes
)
view = deepcopy(view.assign(controlled_by_param=0))
self._make_trainable(view, key, init_val, verbose=verbose)

def _make_trainable(
Expand All @@ -244,12 +253,7 @@ def _make_trainable(
self.allow_make_trainable
), "network.cell('all').make_trainable() is not supported. Use a for-loop over cells."

if key in self.syn_params:
grouped_view = view.groupby("controlled_by_param")
inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))
indices_per_param = jnp.stack(inds_of_comps)
param_vals = self.syn_params[key][indices_per_param]
elif key in view.columns:
if key in view.columns:
view = view[~np.isnan(view[key])]
grouped_view = view.groupby("controlled_by_param")
inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))
Expand Down Expand Up @@ -313,8 +317,8 @@ def get_all_parameters(self, trainable_params):
for channel_params in list(channel.channel_params.keys()):
params[channel_params] = self.jaxnodes[channel_params]

for key, val in self.syn_params.items():
params[key] = val
for synapse_params in self.synapse_param_names:
params[synapse_params] = self.jaxedges[synapse_params]

# Override with those parameters set by `.make_trainable()`.
for inds, set_param in zip(self.indices_set_by_trainables, trainable_params):
Expand All @@ -337,7 +341,6 @@ def initialized(self):
def initialize(self):
"""Initialize the module."""
self.init_morph()
self.init_syns()
return self

def record(self):
Expand Down Expand Up @@ -382,7 +385,7 @@ def insert(self, channel):
def _insert(self, channel, view):
self._append_channel_to_nodes(view, channel)

def init_syns(self):
def init_syns(self, connectivities):
self.initialized_syns = True

def init_morph(self):
Expand Down Expand Up @@ -414,10 +417,10 @@ def step(
# Step of the synapse.
u, syn_voltage_terms, syn_constant_terms = self._step_synapse(
u,
self.syn_classes,
self.synapses,
params,
delta_t,
self.syn_edges,
self.edges,
)

# Voltage steps.
Expand Down Expand Up @@ -460,7 +463,7 @@ def _step_channels(
states,
delta_t,
channels: List[Channel],
channel_nodes: List[pd.DataFrame],
channel_nodes: pd.DataFrame,
params: Dict[str, jnp.ndarray],
):
"""One step of integration of the channels."""
Expand Down Expand Up @@ -741,7 +744,7 @@ def stimulate(self, current):

def set(self, key: str, val: float):
"""Set parameters of the pointer."""
self.pointer._set(key, val, self.view)
self.pointer._set(key, val, self.view, self.pointer.nodes)

def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None):
"""Make a parameter trainable."""
Expand Down
1 change: 1 addition & 0 deletions jaxley/modules/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
dict(parent_branch_index=[], child_branch_index=[])
)
self.initialize()
self.init_syns(None)
self.initialized_conds = False

def __getattr__(self, key):
Expand Down
1 change: 1 addition & 0 deletions jaxley/modules/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
)

self.initialize()
self.init_syns(None)
self.initialized_conds = False

def __getattr__(self, key):
Expand Down
56 changes: 17 additions & 39 deletions jaxley/modules/compartment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self):

# Initialize the module.
self.initialize()
self.init_syns(None)
self.initialized_conds = True

def init_conds(self, params):
Expand Down Expand Up @@ -74,7 +75,7 @@ def connect(self, post: "CompartmentView", synapse_type):
we need to register it as a new synapse in a bunch of dictionaries which track
synapse parameters, state and meta information.

Next, we register the new connection in the synapse dataframe (`.syn_edges`).
Next, we register the new connection in the synapse dataframe (`.edges`).
Then, we update synapse parameter and state arrays with the new connection.
Finally, we update synapse meta information.
"""
Expand All @@ -84,15 +85,15 @@ def connect(self, post: "CompartmentView", synapse_type):
if is_new_type:
# New type: index for the synapse type is one more than the currently
# highest index.
max_ind = self.pointer.syn_edges["type_ind"].max() + 1
max_ind = self.pointer.edges["type_ind"].max() + 1
type_ind = 0 if jnp.isnan(max_ind) else max_ind
else:
# Not a new type: search for the index that this type has previously had.
type_ind = self.pointer.syn_edges.query(f"type == '{synapse_name}'")[
type_ind = self.pointer.edges.query(f"type == '{synapse_name}'")[
"type_ind"
].to_numpy()[0]

# The `syn_edges` dataframe expects the compartment as continuous `loc`, not
# The `edges` dataframe expects the compartment as continuous `loc`, not
# as discrete compartment index (because the continuous `loc` is used for
# plotting). Below, we cast the compartment index to its (rough) location.
pre_comp = loc_of_index(
Expand All @@ -101,11 +102,12 @@ def connect(self, post: "CompartmentView", synapse_type):
post_comp = loc_of_index(
post.view["global_comp_index"].to_numpy(), self.pointer.nseg
)
index = len(self.pointer.edges)

# Update edges.
self.pointer.syn_edges = pd.concat(
self.pointer.edges = pd.concat(
[
self.pointer.syn_edges,
self.pointer.edges,
pd.DataFrame(
dict(
pre_locs=pre_comp,
Expand All @@ -131,42 +133,18 @@ def connect(self, post: "CompartmentView", synapse_type):
],
ignore_index=True,
)

# We add a column called index which is used by `adjust_view` of the
# `SynapseView` (see `network.py`).
self.pointer.syn_edges["index"] = list(self.pointer.syn_edges.index)

# Update synaptic parameter array.
# Add parameters and states to the `.edges` table.
indices = list(range(index, index + 1))
for key in synapse_type.synapse_params:
param_vals = jnp.asarray([synapse_type.synapse_params[key]])
if is_new_type:
# Register parameter array for new synapse type.
self.pointer.syn_params[key] = param_vals
else:
# Append to synaptic parameter array.
self.pointer.syn_params[key] = jnp.concatenate(
[self.pointer.syn_params[key], param_vals]
)

# Update synaptic state array.
param_val = synapse_type.synapse_params[key]
self.pointer.edges.loc[indices, key] = param_val
for key in synapse_type.synapse_states:
state_vals = jnp.asarray([synapse_type.synapse_states[key]])
if is_new_type:
# Register parameter array for new synapse type.
self.pointer.syn_states[key] = state_vals
else:
# Append to synaptic parameter array.
self.pointer.syn_states[key] = jnp.concatenate(
[self.pointer.syn_states[key], state_vals]
)
state_val = synapse_type.synapse_states[key]
self.pointer.edges.loc[indices, key] = state_val

# (Potentially) update variables that track meta information about synapses.
if is_new_type:
self.pointer.synapse_names.append(type(synapse_type).__name__)
self.pointer.synapse_param_names.append(
list(synapse_type.synapse_params.keys())
)
self.pointer.synapse_state_names.append(
list(synapse_type.synapse_states.keys())
)
self.pointer.syn_classes.append(synapse_type)
self.pointer.synapse_param_names += list(synapse_type.synapse_params.keys())
self.pointer.synapse_state_names += list(synapse_type.synapse_states.keys())
self.pointer.synapses.append(synapse_type)
Loading