diff --git a/jaxley/integrate.py b/jaxley/integrate.py
index 40772a5f..54d8c51f 100644
--- a/jaxley/integrate.py
+++ b/jaxley/integrate.py
@@ -99,11 +99,13 @@ def _body_fun(state, i_stim):
         dummy_stimulus = jnp.zeros((size_difference, i_current.shape[1]))
         i_current = jnp.concatenate([i_current, dummy_stimulus])
 
-    # Join node and edge states.
+    # Join node and edge states into a single state dictionary.
     states = {"voltages": module.jaxnodes["voltages"]}
     for channel in module.channels:
        for channel_states in list(channel.channel_states.keys()):
             states[channel_states] = module.jaxnodes[channel_states]
+    for synapse_states in module.synapse_state_names:
+        states[synapse_states] = module.jaxedges[synapse_states]
 
     # Override with the initial states set by `.make_trainable()`.
     for inds, set_param in zip(module.indices_set_by_trainables, params):
@@ -111,10 +113,6 @@ def _body_fun(state, i_stim):
         if key in list(states.keys()):  # Only initial states, not parameters.
             states[key] = states[key].at[inds].set(set_param[key])
 
-    # Write synaptic states. TODO move above when new interface for synapses.
-    for key in module.syn_states:
-        states[key] = module.syn_states[key]
-
     # Run simulation.
     _, recordings = nested_checkpoint_scan(
         _body_fun, states, i_current, length=length, nested_lengths=checkpoint_lengths
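
For orientation, the state-joining step above can be mirrored outside of jaxley as follows. This is a minimal sketch with toy arrays; only the attribute names (`jaxnodes`, `jaxedges`, `synapse_state_names`) are taken from the patch, everything else is made up.

    import jax.numpy as jnp

    # Toy stand-ins for the module attributes referenced above.
    jaxnodes = {"voltages": jnp.full((8,), -70.0), "HH_m": jnp.zeros(8)}
    jaxedges = {"s": jnp.zeros(2)}  # one entry per synapse (edge)
    channel_state_names = ["HH_m"]
    synapse_state_names = ["s"]

    # Node states and edge (synapse) states end up in one flat dictionary,
    # mirroring the new loop over `module.synapse_state_names`.
    states = {"voltages": jaxnodes["voltages"]}
    for name in channel_state_names:
        states[name] = jaxnodes[name]
    for name in synapse_state_names:
        states[name] = jaxedges[name]
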
diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py
index 9dc7cdd7..e454786e 100644
--- a/jaxley/modules/base.py
+++ b/jaxley/modules/base.py
@@ -12,7 +12,6 @@
 
 from jaxley.channels import Channel
 from jaxley.solver_voltage import step_voltage_explicit, step_voltage_implicit
-from jaxley.synapses import Synapse
 from jaxley.utils.cell_utils import (
     _compute_index_of_child,
     _compute_num_children,
@@ -26,19 +25,16 @@ def __init__(self):
         self.nseg: int = None
         self.total_nbranches: int = 0
         self.nbranches_per_cell: List[int] = None
-
-        self.conns: List[Synapse] = None
         self.group_views = {}
 
         self.nodes: Optional[pd.DataFrame] = None
-
-        self.syn_edges = pd.DataFrame(
+        self.edges = pd.DataFrame(
             columns=[
                 "pre_locs",
-                "post_locs",
                 "pre_branch_index",
-                "post_branch_index",
                 "pre_cell_index",
+                "post_locs",
+                "post_branch_index",
                 "post_cell_index",
                 "type",
                 "type_ind",
@@ -48,8 +44,6 @@ def __init__(self):
                 "global_post_branch_index",
             ]
         )
-        self.branch_edges: Optional[pd.DataFrame] = None
-
         self.cumsum_nbranches: Optional[jnp.ndarray] = None
 
         self.comb_parents: jnp.ndarray = jnp.asarray([-1])
@@ -58,11 +52,12 @@ def __init__(self):
         self.initialized_morph: bool = False
         self.initialized_syns: bool = False
 
-        self.syn_params: Dict[str, jnp.ndarray] = {}
-        self.syn_states: Dict[str, jnp.ndarray] = {}
-        self.syn_classes: List = []
+        # List of all types of `jx.Synapse`s.
+        self.synapses: List = []
+        self.synapse_param_names = []
+        self.synapse_state_names = []
 
-        # List of all `jx.Channel`s.
+        # List of all types of `jx.Channel`s.
         self.channels: List[Channel] = []
 
         # For trainable parameters.
@@ -116,10 +111,19 @@ def _gather_channels_from_constituents(self, constituents: List) -> None:
            self.nodes.loc[self.nodes[name].isna(), name] = False
 
     def to_jax(self):
+        """Generates Dict[jnp.ndarray] from the pd.DataFrames for nodes and edges."""
         self.jaxnodes = {}
         for key, value in self.nodes.to_dict(orient="list").items():
             self.jaxnodes[key] = jnp.asarray(value)
 
+        # TODO(@michaeldeistler): if we wanted to reduce the memory footprint, we
+        # could remove NaN from the jaxedges parameters and states here. Then we only
+        # have to fix step_synapse and make_trainable with corresponding index updates.
+        self.jaxedges = {}
+        for key, value in self.edges.to_dict(orient="list").items():
+            if key != "type":
+                self.jaxedges[key] = jnp.asarray(value)
+
     def show(
         self,
         param_names: Optional[Union[str, List[str]]] = None,  # TODO.
@@ -198,17 +202,17 @@ def _append_channel_to_nodes(self, view, channel: "jx.Channel"):
 
     def set(self, key, val):
         """Set parameter."""
         # Alternatively, we could do `assert key not in self.syn_params`.
-        nodes = self.syn_edges if key in self.syn_params else self.nodes
-        self._set(key, val, nodes)
-
-    def _set(self, key, val, view):
-        if key in self.syn_params:
-            self.syn_params[key] = self.syn_params[key].at[view.index.values].set(val)
-        elif key in self.syn_states:
-            self.syn_states[key] = self.syn_states[key].at[view.index.values].set(val)
-        elif key in view.columns:
+        view = (
+            self.edges
+            if key in self.synapse_param_names or key in self.synapse_state_names
+            else self.nodes
+        )
+        self._set(key, val, view, view)
+
+    def _set(self, key, val, view, table_to_update):
+        if key in view.columns:
             view = view[~np.isnan(view[key])]
-            self.nodes.loc[view.index.values, key] = val
+            table_to_update.loc[view.index.values, key] = val
         else:
             raise KeyError("Key not recognized.")
@@ -230,7 +234,12 @@ def make_trainable(
             verbose: Whether to print the number of parameters that are added and the
                 total number of parameters.
         """
-        view = deepcopy(self.nodes.assign(controlled_by_param=0))
+        view = (
+            self.edges
+            if key in self.synapse_param_names or key in self.synapse_state_names
+            else self.nodes
+        )
+        view = deepcopy(view.assign(controlled_by_param=0))
         self._make_trainable(view, key, init_val, verbose=verbose)
 
     def _make_trainable(
@@ -244,12 +253,7 @@ def _make_trainable(
             self.allow_make_trainable
         ), "network.cell('all').make_trainable() is not supported. Use a for-loop over cells."
 
-        if key in self.syn_params:
-            grouped_view = view.groupby("controlled_by_param")
-            inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))
-            indices_per_param = jnp.stack(inds_of_comps)
-            param_vals = self.syn_params[key][indices_per_param]
-        elif key in view.columns:
+        if key in view.columns:
             view = view[~np.isnan(view[key])]
             grouped_view = view.groupby("controlled_by_param")
             inds_of_comps = list(grouped_view.apply(lambda x: x.index.values))
@@ -313,8 +317,8 @@ def get_all_parameters(self, trainable_params):
             for channel_params in list(channel.channel_params.keys()):
                 params[channel_params] = self.jaxnodes[channel_params]
 
-        for key, val in self.syn_params.items():
-            params[key] = val
+        for synapse_params in self.synapse_param_names:
+            params[synapse_params] = self.jaxedges[synapse_params]
 
         # Override with those parameters set by `.make_trainable()`.
         for inds, set_param in zip(self.indices_set_by_trainables, trainable_params):
@@ -337,7 +341,6 @@ def initialized(self):
     def initialize(self):
         """Initialize the module."""
         self.init_morph()
-        self.init_syns()
         return self
 
     def record(self):
@@ -382,7 +385,7 @@ def insert(self, channel):
     def _insert(self, channel, view):
         self._append_channel_to_nodes(view, channel)
 
-    def init_syns(self):
+    def init_syns(self, connectivities):
         self.initialized_syns = True
 
     def init_morph(self):
@@ -414,10 +417,10 @@ def step(
         # Step of the synapse.
         u, syn_voltage_terms, syn_constant_terms = self._step_synapse(
             u,
-            self.syn_classes,
+            self.synapses,
             params,
             delta_t,
-            self.syn_edges,
+            self.edges,
         )
 
         # Voltage steps.
@@ -460,7 +463,7 @@ def _step_channels(
         states,
         delta_t,
         channels: List[Channel],
-        channel_nodes: List[pd.DataFrame],
+        channel_nodes: pd.DataFrame,
         params: Dict[str, jnp.ndarray],
     ):
         """One step of integration of the channels."""
@@ -741,7 +744,7 @@ def stimulate(self, current):
 
     def set(self, key: str, val: float):
         """Set parameters of the pointer."""
-        self.pointer._set(key, val, self.view)
+        self.pointer._set(key, val, self.view, self.pointer.nodes)
 
     def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None):
         """Make a parameter trainable."""
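
To make the `.edges`-centric bookkeeping concrete, here is a self-contained sketch of what the new `to_jax()` produces for edges. The toy table below is hypothetical; only the column names and the `key != "type"` filter mirror the patch.

    import jax.numpy as jnp
    import pandas as pd

    # Hypothetical `.edges` table: one row per synapse, with parameter and state
    # columns (`gS`, `s`) stored next to the bookkeeping columns.
    edges = pd.DataFrame(
        {
            "type": ["GlutamateSynapse", "GlutamateSynapse"],
            "type_ind": [0, 0],
            "gS": [0.5, 0.5],
            "s": [0.2, 0.2],
        }
    )

    # Mirror of the new `to_jax()` loop: every column except `type` becomes an array.
    jaxedges = {}
    for key, value in edges.to_dict(orient="list").items():
        if key != "type":
            jaxedges[key] = jnp.asarray(value)

    print(jaxedges["gS"])  # [0.5 0.5]
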
diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py
index 06b0fbd4..60347b06 100644
--- a/jaxley/modules/branch.py
+++ b/jaxley/modules/branch.py
@@ -57,6 +57,7 @@ def __init__(
             dict(parent_branch_index=[], child_branch_index=[])
         )
         self.initialize()
+        self.init_syns(None)
         self.initialized_conds = False
 
     def __getattr__(self, key):
diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py
index 5d767e76..6709aac8 100644
--- a/jaxley/modules/cell.py
+++ b/jaxley/modules/cell.py
@@ -88,6 +88,7 @@ def __init__(
         )
 
         self.initialize()
+        self.init_syns(None)
         self.initialized_conds = False
 
     def __getattr__(self, key):
diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py
index 490a8825..9823569c 100644
--- a/jaxley/modules/compartment.py
+++ b/jaxley/modules/compartment.py
@@ -36,6 +36,7 @@ def __init__(self):
 
         # Initialize the module.
         self.initialize()
+        self.init_syns(None)
         self.initialized_conds = True
 
     def init_conds(self, params):
@@ -74,7 +75,7 @@ def connect(self, post: "CompartmentView", synapse_type):
         we need to register it as a new synapse in a bunch of dictionaries which track
         synapse parameters, state and meta information.
 
-        Next, we register the new connection in the synapse dataframe (`.syn_edges`).
+        Next, we register the new connection in the synapse dataframe (`.edges`).
         Then, we update synapse parameter and state arrays with the new connection.
         Finally, we update synapse meta information.
         """
@@ -84,15 +85,15 @@ def connect(self, post: "CompartmentView", synapse_type):
         if is_new_type:
             # New type: index for the synapse type is one more than the currently
             # highest index.
-            max_ind = self.pointer.syn_edges["type_ind"].max() + 1
+            max_ind = self.pointer.edges["type_ind"].max() + 1
             type_ind = 0 if jnp.isnan(max_ind) else max_ind
         else:
             # Not a new type: search for the index that this type has previously had.
-            type_ind = self.pointer.syn_edges.query(f"type == '{synapse_name}'")[
+            type_ind = self.pointer.edges.query(f"type == '{synapse_name}'")[
                 "type_ind"
             ].to_numpy()[0]
 
-        # The `syn_edges` dataframe expects the compartment as continuous `loc`, not
+        # The `edges` dataframe expects the compartment as continuous `loc`, not
         # as discrete compartment index (because the continuous `loc` is used for
         # plotting). Below, we cast the compartment index to its (rough) location.
         pre_comp = loc_of_index(
@@ -101,11 +102,12 @@ def connect(self, post: "CompartmentView", synapse_type):
         post_comp = loc_of_index(
             post.view["global_comp_index"].to_numpy(), self.pointer.nseg
         )
+        index = len(self.pointer.edges)
 
         # Update edges.
-        self.pointer.syn_edges = pd.concat(
+        self.pointer.edges = pd.concat(
             [
-                self.pointer.syn_edges,
+                self.pointer.edges,
                 pd.DataFrame(
                     dict(
                         pre_locs=pre_comp,
@@ -131,42 +133,18 @@ def connect(self, post: "CompartmentView", synapse_type):
             ],
             ignore_index=True,
         )
-
-        # We add a column called index which is used by `adjust_view` of the
-        # `SynapseView` (see `network.py`).
-        self.pointer.syn_edges["index"] = list(self.pointer.syn_edges.index)
-
-        # Update synaptic parameter array.
+        # Add parameters and states to the `.edges` table.
+        indices = list(range(index, index + 1))
         for key in synapse_type.synapse_params:
-            param_vals = jnp.asarray([synapse_type.synapse_params[key]])
-            if is_new_type:
-                # Register parameter array for new synapse type.
-                self.pointer.syn_params[key] = param_vals
-            else:
-                # Append to synaptic parameter array.
-                self.pointer.syn_params[key] = jnp.concatenate(
-                    [self.pointer.syn_params[key], param_vals]
-                )
-
-        # Update synaptic state array.
+            param_val = synapse_type.synapse_params[key]
+            self.pointer.edges.loc[indices, key] = param_val
         for key in synapse_type.synapse_states:
-            state_vals = jnp.asarray([synapse_type.synapse_states[key]])
-            if is_new_type:
-                # Register parameter array for new synapse type.
-                self.pointer.syn_states[key] = state_vals
-            else:
-                # Append to synaptic parameter array.
-                self.pointer.syn_states[key] = jnp.concatenate(
-                    [self.pointer.syn_states[key], state_vals]
-                )
+            state_val = synapse_type.synapse_states[key]
+            self.pointer.edges.loc[indices, key] = state_val
 
         # (Potentially) update variables that track meta information about synapses.
         if is_new_type:
             self.pointer.synapse_names.append(type(synapse_type).__name__)
-            self.pointer.synapse_param_names.append(
-                list(synapse_type.synapse_params.keys())
-            )
-            self.pointer.synapse_state_names.append(
-                list(synapse_type.synapse_states.keys())
-            )
-            self.pointer.syn_classes.append(synapse_type)
+            self.pointer.synapse_param_names += list(synapse_type.synapse_params.keys())
+            self.pointer.synapse_state_names += list(synapse_type.synapse_states.keys())
+            self.pointer.synapses.append(synapse_type)
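
The `connect()` changes above append one row to `.edges` and then write the synapse's default parameters and states into that row via `.loc`. A pandas-only sketch of that pattern, with hypothetical defaults:

    import pandas as pd

    edges = pd.DataFrame(columns=["type", "type_ind"])

    # Append one connection, then fill its defaults, mirroring
    # `indices = list(range(index, index + 1))` in the patch.
    index = len(edges)
    edges = pd.concat(
        [edges, pd.DataFrame(dict(type=["GlutamateSynapse"], type_ind=[0]))],
        ignore_index=True,
    )
    indices = list(range(index, index + 1))
    for key, default in {"gS": 0.5}.items():  # synapse_params (hypothetical)
        edges.loc[indices, key] = default
    for key, default in {"s": 0.2}.items():  # synapse_states (hypothetical)
        edges.loc[indices, key] = default

    print(edges)  # one row with columns: type, type_ind, gS, s
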
diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py
index 5e4a04cb..ec8bbabd 100644
--- a/jaxley/modules/network.py
+++ b/jaxley/modules/network.py
@@ -1,9 +1,9 @@
 import itertools
 from copy import deepcopy
+from itertools import chain
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import jax.numpy as jnp
-import matplotlib.pyplot as plt
 import networkx as nx
 import numpy as np
 import pandas as pd
@@ -37,23 +37,34 @@ def __init__(
         super().__init__()
         for cell in cells:
             self.xyzr += deepcopy(cell.xyzr)
 
-        self._append_synapses_to_params_and_state(connectivities)
         self.cells = cells
-        self.connectivities = connectivities
-        self.syn_classes = [
-            connectivity.synapse_type for connectivity in connectivities
-        ]
         self.nseg = cells[0].nseg
+
+        self.synapses = [connectivity.synapse_type for connectivity in connectivities]
+
+        # TODO(@michaeldeistler): should we also track this for channels?
         self.synapse_names = [type(c.synapse_type).__name__ for c in connectivities]
-        self.synapse_param_names = [
-            c.synapse_type.synapse_params.keys() for c in connectivities
-        ]
-        self.synapse_state_names = [
-            c.synapse_type.synapse_states.keys() for c in connectivities
-        ]
+        self.synapse_param_names = list(
+            chain.from_iterable(
+                [list(c.synapse_type.synapse_params.keys()) for c in connectivities]
+            )
+        )
+        self.synapse_state_names = list(
+            chain.from_iterable(
+                [list(c.synapse_type.synapse_states.keys()) for c in connectivities]
+            )
+        )
+
+        # Two columns: `parent_branch_index` and `child_branch_index`. One row per
+        # branch, apart from those branches which do not have a parent (i.e.
+        # -1 in parents). For every branch, tracks the global index of that branch
+        # (`child_branch_index`) and the global index of its parent
+        # (`parent_branch_index`). Needed by `init_syns()`.
+        self.branch_edges: Optional[pd.DataFrame] = None
 
         self.initialize()
+        self.init_syns(connectivities)
         self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)
         self._append_params_and_states(self.network_params, self.network_states)
 
@@ -71,25 +82,6 @@ def __init__(
         self._gather_channels_from_constituents(cells)
         self.initialized_conds = False
 
-    def _append_synapses_to_params_and_state(self, connectivities):
-        for connectivity in connectivities:
-            for key in connectivity.synapse_type.synapse_params:
-                param_vals = jnp.asarray(
-                    [
-                        connectivity.synapse_type.synapse_params[key]
-                        for _ in connectivity.conns
-                    ]
-                )
-                self.syn_params[key] = param_vals
-            for key in connectivity.synapse_type.synapse_states:
-                state_vals = jnp.asarray(
-                    [
-                        connectivity.synapse_type.synapse_states[key]
-                        for _ in connectivity.conns
-                    ]
-                )
-                self.syn_states[key] = state_vals
-
     def __getattr__(self, key):
         # Ensure that hidden methods such as `__deepcopy__` still work.
         if key.startswith("__"):
@@ -102,7 +94,8 @@ def __getattr__(self, key):
             view["global_cell_index"] = view["cell_index"]
             return CellView(self, view)
         elif key in self.synapse_names:
-            return SynapseView(self, self.syn_edges, key)
+            type_index = self.synapse_names.index(key)
+            return SynapseView(self, self.edges, key, self.synapses[type_index])
         elif key in self.group_views:
             return self.group_views[key]
         else:
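
The `chain.from_iterable` calls above flatten the per-connectivity name lists into one flat list (previously, `synapse_param_names` was a list of `dict_keys` objects, one per connectivity). A minimal illustration with hypothetical names:

    from itertools import chain

    per_type_param_names = [["gS"], ["gC"]]  # hypothetical, one list per synapse type
    synapse_param_names = list(chain.from_iterable(per_type_param_names))
    print(synapse_param_names)  # ['gS', 'gC']
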
@@ -176,7 +169,7 @@ def init_conds(self, params):
 
         return cond_params
 
-    def init_syns(self):
+    def init_syns(self, connectivities):
         global_pre_comp_inds = []
         global_post_comp_inds = []
         global_pre_branch_inds = []
@@ -187,82 +180,76 @@ def init_syns(self):
         post_branch_inds = []
         pre_cell_inds = []
         post_cell_inds = []
-        for connectivity in self.connectivities:
+        for i, connectivity in enumerate(connectivities):
             pre_cell_inds_, pre_inds, post_cell_inds_, post_inds = prepare_syn(
                 connectivity.conns, self.nseg
             )
             # Global compartment indizes.
-            global_pre_comp_inds.append(
+            global_pre_comp_inds = (
                 self.cumsum_nbranches[pre_cell_inds_] * self.nseg + pre_inds
             )
-            global_post_comp_inds.append(
+            global_post_comp_inds = (
                 self.cumsum_nbranches[post_cell_inds_] * self.nseg + post_inds
             )
-            global_pre_branch_inds.append(
-                [
-                    self.cumsum_nbranches[c.pre_cell_ind] + c.pre_branch_ind
-                    for c in connectivity.conns
-                ]
-            )
-            global_post_branch_inds.append(
-                [
-                    self.cumsum_nbranches[c.post_cell_ind] + c.post_branch_ind
-                    for c in connectivity.conns
-                ]
-            )
+            global_pre_branch_inds = [
+                self.cumsum_nbranches[c.pre_cell_ind] + c.pre_branch_ind
+                for c in connectivity.conns
+            ]
+            global_post_branch_inds = [
+                self.cumsum_nbranches[c.post_cell_ind] + c.post_branch_ind
+                for c in connectivity.conns
+            ]
             # Local compartment inds.
-            pre_locs.append(np.asarray([c.pre_loc for c in connectivity.conns]))
-            post_locs.append(np.asarray([c.post_loc for c in connectivity.conns]))
+            pre_locs = np.asarray([c.pre_loc for c in connectivity.conns])
+            post_locs = np.asarray([c.post_loc for c in connectivity.conns])
             # Local branch inds.
-            pre_branch_inds.append(
-                np.asarray([c.pre_branch_ind for c in connectivity.conns])
-            )
-            post_branch_inds.append(
-                np.asarray([c.post_branch_ind for c in connectivity.conns])
+            pre_branch_inds = np.asarray([c.pre_branch_ind for c in connectivity.conns])
+            post_branch_inds = np.asarray(
+                [c.post_branch_ind for c in connectivity.conns]
             )
-            pre_cell_inds.append(pre_cell_inds_)
-            post_cell_inds.append(post_cell_inds_)
+            pre_cell_inds = pre_cell_inds_
+            post_cell_inds = post_cell_inds_
+            # for key in connectivity.synapse_type.synapse_states:
 
-        # Prepare synapses.
-        self.syn_edges = pd.DataFrame(
-            columns=[
-                "pre_locs",
-                "pre_branch_index",
-                "pre_cell_index",
-                "post_locs",
-                "post_branch_index",
-                "post_cell_index",
-                "type",
-                "type_ind",
-                "global_pre_comp_index",
-                "global_post_comp_index",
-                "global_pre_branch_index",
-                "global_post_branch_index",
-            ]
-        )
-        for i, connectivity in enumerate(self.connectivities):
-            self.syn_edges = pd.concat(
+            self.edges = pd.concat(
                 [
-                    self.syn_edges,
+                    self.edges,
                     pd.DataFrame(
                         dict(
-                            pre_locs=pre_locs[i],
-                            pre_branch_index=pre_branch_inds[i],
-                            pre_cell_index=pre_cell_inds[i],
-                            post_locs=post_locs[i],
-                            post_branch_index=post_branch_inds[i],
-                            post_cell_index=post_cell_inds[i],
+                            pre_locs=pre_locs,
+                            pre_branch_index=pre_branch_inds,
+                            pre_cell_index=pre_cell_inds,
+                            post_locs=post_locs,
+                            post_branch_index=post_branch_inds,
+                            post_cell_index=post_cell_inds,
                             type=type(connectivity.synapse_type).__name__,
                             type_ind=i,
-                            global_pre_comp_index=global_pre_comp_inds[i],
-                            global_post_comp_index=global_post_comp_inds[i],
-                            global_pre_branch_index=global_pre_branch_inds[i],
-                            global_post_branch_index=global_post_branch_inds[i],
+                            global_pre_comp_index=global_pre_comp_inds,
+                            global_post_comp_index=global_post_comp_inds,
+                            global_pre_branch_index=global_pre_branch_inds,
+                            global_post_branch_index=global_post_branch_inds,
                         )
                     ),
                 ],
             )
-        self.syn_edges["index"] = list(self.syn_edges.index)
+
+        # Reset the index of the `.edges` table.
+        self.edges = self.edges.reset_index(drop=True)
+
+        # Add parameters and states to the `.edges` table.
+        index = 0
+        for i, connectivity in enumerate(connectivities):
+            for key in connectivity.synapse_type.synapse_params:
+                param_val = connectivity.synapse_type.synapse_params[key]
+                indices = np.arange(index, index + len(connectivity.conns))
+                self.edges.loc[indices, key] = param_val
+
+            for key in connectivity.synapse_type.synapse_states:
+                state_val = connectivity.synapse_type.synapse_states[key]
+                indices = np.arange(index, index + len(connectivity.conns))
+                self.edges.loc[indices, key] = state_val
+
+            index += len(connectivity.conns)
 
         self.branch_edges = pd.DataFrame(
             dict(
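
Because `init_syns` concatenates one block of rows per connectivity and then fills the parameter and state columns type by type, each type's defaults land in a consecutive block of `.edges` rows, tracked by the running `index` offset. A small pandas sketch of that bookkeeping, with toy connection counts and hypothetical defaults:

    import numpy as np
    import pandas as pd

    # Two synapse types with two connections each, rows already concatenated in order.
    edges = pd.DataFrame({"type": ["A", "A", "B", "B"]})
    defaults = [{"gS": 0.5}, {"gC": 0.5}]  # hypothetical per-type parameter defaults
    n_conns = [2, 2]

    index = 0
    for params, n in zip(defaults, n_conns):
        indices = np.arange(index, index + n)
        for key, val in params.items():
            edges.loc[indices, key] = val
        index += n

    print(edges)  # rows 0-1 carry gS, rows 2-3 carry gC, all other cells are NaN
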
@@ -275,14 +262,14 @@ def init_syns(self):
 
     @staticmethod
     def _step_synapse(
-        u,
+        states,
         syn_channels,
         params,
         delta_t,
         edges: pd.DataFrame,
     ):
         """Perform one step of the synapses and obtain their currents."""
-        voltages = u["voltages"]
+        voltages = states["voltages"]
 
         grouped_syns = edges.groupby("type", sort=False, group_keys=False)
         pre_syn_inds = grouped_syns["global_pre_comp_index"].apply(list)
@@ -291,13 +278,29 @@ def _step_synapse(
         syn_voltage_terms = jnp.zeros_like(voltages)
         syn_constant_terms = jnp.zeros_like(voltages)
-        new_syn_states = []
 
         for i, synapse_type in enumerate(syn_channels):
             assert (
                 synapse_names[i] == type(synapse_type).__name__
             ), "Mixup in the ordering of synapses. Please create an issue on Github."
-            synapse_states, synapse_current_terms = synapse_type.step(
-                u, delta_t, voltages, params, np.asarray(pre_syn_inds[synapse_names[i]])
+
+            name = type(synapse_type).__name__
+            synapse_param_names = list(synapse_type.synapse_params.keys())
+            synapse_state_names = list(synapse_type.synapse_states.keys())
+            indices = edges.loc[edges["type"] == name].index.values
+
+            synapse_params = {}
+            for p in synapse_param_names:
+                synapse_params[p] = params[p][indices]
+            synapse_states = {}
+            for s in synapse_state_names:
+                synapse_states[s] = states[s][indices]
+
+            states_updated, synapse_current_terms = synapse_type.step(
+                synapse_states,
+                delta_t,
+                voltages,
+                synapse_params,
+                np.asarray(pre_syn_inds[synapse_names[i]]),
             )
             synapse_current_terms = postsyn_voltage_updates(
                 voltages,
@@ -306,14 +309,12 @@ def _step_synapse(
             )
             syn_voltage_terms += synapse_current_terms[0]
             syn_constant_terms += synapse_current_terms[1]
-            new_syn_states.append(synapse_states)
 
-        # Rebuild synapse states.
-        for s in new_syn_states:
-            for key, val in s.items():
-                u[key] = val
+            # Rebuild state.
+            for key, val in states_updated.items():
+                states[key] = states[key].at[indices].set(val)
 
-        return u, syn_voltage_terms, syn_constant_terms
+        return states, syn_voltage_terms, syn_constant_terms
 
     def vis(
         self,
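
The reworked `_step_synapse` slices the flat parameter and state arrays down to the rows of one synapse type, steps only that type, and scatters the updated states back. A stripped-down sketch of the index pattern, where `indices` plays the role of `edges.loc[edges["type"] == name].index.values` and the "step" is a dummy decay:

    import jax.numpy as jnp
    import numpy as np

    states = {"s": jnp.asarray([0.1, 0.2, 0.3, 0.4])}  # all synapses, all types
    params = {"gS": jnp.asarray([0.5, 0.5, 0.5, 0.5])}
    indices = np.asarray([0, 2])  # rows of `.edges` that belong to one synapse type

    # Gather this type's slice, "step" it, and scatter the result back.
    synapse_states = {"s": states["s"][indices]}
    synapse_params = {"gS": params["gS"][indices]}
    states_updated = {"s": synapse_states["s"] * 0.9}  # placeholder for Synapse.step
    states["s"] = states["s"].at[indices].set(states_updated["s"])
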
@@ -359,10 +360,10 @@ def vis(
             morph_plot_kwargs=morph_plot_kwargs,
         )
 
-        pre_locs = self.syn_edges["pre_locs"].to_numpy()
-        post_locs = self.syn_edges["post_locs"].to_numpy()
-        pre_branch = self.syn_edges["global_pre_branch_index"].to_numpy()
-        post_branch = self.syn_edges["global_post_branch_index"].to_numpy()
+        pre_locs = self.edges["pre_locs"].to_numpy()
+        post_locs = self.edges["post_locs"].to_numpy()
+        pre_branch = self.edges["global_pre_branch_index"].to_numpy()
+        post_branch = self.edges["global_post_branch_index"].to_numpy()
 
         dims_np = np.asarray(dims)
 
@@ -423,8 +424,8 @@ def build_extents(*subset_sizes):
         else:
             graph.add_nodes_from(range(len(self.cells)))
 
-        pre_cell = self.syn_edges["pre_cell_index"].to_numpy()
-        post_cell = self.syn_edges["post_cell_index"].to_numpy()
+        pre_cell = self.edges["pre_cell_index"].to_numpy()
+        post_cell = self.edges["post_cell_index"].to_numpy()
 
         inds = np.stack([pre_cell, post_cell]).T
         graph.add_edges_from(inds)
@@ -435,11 +436,11 @@ def build_extents(*subset_sizes):
 class SynapseView(View):
     """SynapseView."""
 
-    def __init__(self, pointer, view, key):
-        view = view[view["type"] == key]
-        view = view.reset_index(drop=True)
-        view["index"] = list(view.index)
-        view = view.assign(controlled_by_param=view.index)
+    def __init__(self, pointer, view, key, synapse: "jx.Synapse"):
+        self.synapse = synapse
+        view = deepcopy(view[view["type"] == key])
+        view["index"] = list(range(len(view)))
+        view = view.assign(controlled_by_param=view["index"])
         super().__init__(pointer, view)
 
     def __call__(self, index: int):
@@ -453,33 +454,40 @@ def show(
         states: bool = True,
     ):
         """Show synapses."""
-        ind_of_params = self.view.index.values
-        nodes = deepcopy(self.view)
+        printable_nodes = deepcopy(self.view[["type", "type_ind"]])
 
-        if not indices:
-            for key in nodes:
-                nodes = nodes.drop(key, axis=1)
+        if indices:
+            names = [
+                "pre_locs",
+                "pre_branch_index",
+                "pre_cell_index",
+                "post_locs",
+                "post_branch_index",
+                "post_cell_index",
+            ]
+            printable_nodes[names] = self.view[names]
 
         if params:
-            for key, val in self.pointer.syn_params.items():
-                nodes[key] = val[ind_of_params]
-
+            for key in self.synapse.synapse_params.keys():
+                printable_nodes[key] = self.view[key]
         if states:
-            for key, val in self.pointer.syn_states.items():
-                nodes[key] = val[ind_of_params]
+            for key in self.synapse.synapse_states.keys():
+                printable_nodes[key] = self.view[key]
 
-        return nodes
+        printable_nodes["controlled_by_param"] = self.view["controlled_by_param"]
+        return printable_nodes
 
     def set(self, key: str, val: float):
         """Set parameters of the pointer."""
         assert (
             key in self.pointer.synapse_param_names[self.view["type_ind"].values[0]]
         ), f"Parameter {key} does not exist in synapse of type {self.view['type'].values[0]}."
-        self.pointer._set(key, val, self.view)
+        self.pointer._set(key, val, self.view, self.pointer.edges)
 
     def make_trainable(self, key: str, init_val: Optional[Union[float, list]] = None):
         """Make a parameter trainable."""
         assert (
             key in self.pointer.synapse_param_names[self.view["type_ind"].values[0]]
+            or key in self.pointer.synapse_state_names[self.view["type_ind"].values[0]]
         ), f"Parameter {key} does not exist in synapse of type {self.view['type'].values[0]}."
         self.pointer._make_trainable(self.view, key, init_val)
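
With the `SynapseView` changes above, synapse parameters are addressed through the same view machinery as node parameters. The calls below are the ones exercised by the updated tests; they assume a network `net` that already contains `GlutamateSynapse` connections (built as in `tests/test_synapse_indexing.py`, not shown in this patch):

    net.GlutamateSynapse.show()                   # table with a `gS` column per row
    net.GlutamateSynapse.set("gS", 0.32)          # writes into net.edges["gS"]
    net.GlutamateSynapse([2, 3]).set("gS", 0.12)  # only rows 2 and 3 of that type
    net.GlutamateSynapse(1).make_trainable("gS")
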
diff --git a/tests/test_make_trainable.py b/tests/test_make_trainable.py
index 9d1b9106..11f200f0 100644
--- a/tests/test_make_trainable.py
+++ b/tests/test_make_trainable.py
@@ -103,9 +103,9 @@ def test_diverse_synapse_types():
     assert np.all(all_parameters["length"] == 10.0)
     assert np.all(all_parameters["axial_resistivity"] == 5000.0)
     assert np.all(all_parameters["gS"][0] == 2.2)
-    assert np.all(all_parameters["gS"][1] == 2.2)
-    assert np.all(all_parameters["gC"][0] == 3.3)
-    assert np.all(all_parameters["gC"][1] == 4.4)
+    assert np.all(all_parameters["gS"][2] == 2.2)
+    assert np.all(all_parameters["gC"][1] == 3.3)
+    assert np.all(all_parameters["gC"][3] == 4.4)
 
     # Add another trainable parameter and test again.
     net.GlutamateSynapse(1).make_trainable("gS")
@@ -118,7 +118,7 @@ def test_diverse_synapse_types():
     net.to_jax()
     all_parameters = net.get_all_parameters(params)
     assert np.all(all_parameters["gS"][0] == 2.2)
-    assert np.all(all_parameters["gS"][1] == 5.5)
+    assert np.all(all_parameters["gS"][2] == 5.5)
 
 
 def test_make_all_trainable_corresponds_to_set():
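
The shifted indices in these tests follow directly from the new storage layout: parameters now live at the global row positions of their synapse type in `.edges` (and in the flat arrays returned by `get_all_parameters`), rather than in one dense array per type. With two types connected alternately, `gS` therefore sits at rows 0 and 2 and `gC` at rows 1 and 3. A toy illustration with hypothetical values:

    import pandas as pd

    # Connections created in the order Glutamate, Test, Glutamate, Test.
    edges = pd.DataFrame({"type": ["GlutamateSynapse", "TestSynapse"] * 2})
    edges.loc[edges["type"] == "GlutamateSynapse", "gS"] = 2.2
    edges.loc[edges["type"] == "TestSynapse", "gC"] = 3.3

    print(edges["gS"].to_numpy())  # [2.2 nan 2.2 nan] -> gS occupies rows 0 and 2
    print(edges["gC"].to_numpy())  # [nan 3.3 nan 3.3] -> gC occupies rows 1 and 3
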
diff --git a/tests/test_synapse_indexing.py b/tests/test_synapse_indexing.py
index dc9dfa11..ab11aa20 100644
--- a/tests/test_synapse_indexing.py
+++ b/tests/test_synapse_indexing.py
@@ -24,19 +24,19 @@ def test_set_and_querying_params_one_type():
     pre.connect(post, GlutamateSynapse())
 
     net.set("gS", 0.15)
-    assert np.all(net.syn_params["gS"] == 0.15)
+    assert np.all(net.edges["gS"].to_numpy() == 0.15)
 
     net.GlutamateSynapse.set("gS", 0.32)
-    assert np.all(net.syn_params["gS"] == 0.32)
+    assert np.all(net.edges["gS"].to_numpy() == 0.32)
 
     net.GlutamateSynapse(1).set("gS", 0.18)
-    assert net.syn_params["gS"][1] == 0.18
-    assert np.all(net.syn_params["gS"][np.asarray([0, 2, 3])] == 0.32)
+    assert net.edges["gS"].to_numpy()[1] == 0.18
+    assert np.all(net.edges["gS"].to_numpy()[np.asarray([0, 2, 3])] == 0.32)
 
     net.GlutamateSynapse([2, 3]).set("gS", 0.12)
-    assert net.syn_params["gS"][0] == 0.32
-    assert net.syn_params["gS"][1] == 0.18
-    assert np.all(net.syn_params["gS"][np.asarray([2, 3])] == 0.12)
+    assert net.edges["gS"][0] == 0.32
+    assert net.edges["gS"][1] == 0.18
+    assert np.all(net.edges["gS"].to_numpy()[np.asarray([2, 3])] == 0.12)
 
 
 def test_set_and_querying_params_two_types():
@@ -53,29 +53,33 @@ def test_set_and_querying_params_two_types():
         pre.connect(post, synapse)
 
     net.set("gS", 0.15)
-    assert np.all(net.syn_params["gS"] == 0.15)
-    assert np.all(net.syn_params["gC"] == 0.5)  # 0.5 is the default value.
+    assert np.all(net.edges["gS"].to_numpy()[[0, 2]] == 0.15)
+    assert np.all(
+        net.edges["gC"].to_numpy()[[1, 3]] == 0.5
+    )  # 0.5 is the default value.
 
     net.GlutamateSynapse.set("gS", 0.32)
-    assert np.all(net.syn_params["gS"] == 0.32)
-    assert np.all(net.syn_params["gC"] == 0.5)  # 0.5 is the default value.
+    assert np.all(net.edges["gS"].to_numpy()[[0, 2]] == 0.32)
+    assert np.all(
+        net.edges["gC"].to_numpy()[[1, 3]] == 0.5
+    )  # 0.5 is the default value.
 
     net.TestSynapse.set("gC", 0.18)
-    assert np.all(net.syn_params["gS"] == 0.32)
-    assert np.all(net.syn_params["gC"] == 0.18)
+    assert np.all(net.edges["gS"].to_numpy()[[0, 2]] == 0.32)
+    assert np.all(net.edges["gC"].to_numpy()[[1, 3]] == 0.18)
 
     net.GlutamateSynapse(1).set("gS", 0.24)
-    assert net.syn_params["gS"][0] == 0.32
-    assert net.syn_params["gS"][1] == 0.24
-    assert np.all(net.syn_params["gC"] == 0.18)
+    assert net.edges["gS"][0] == 0.32
+    assert net.edges["gS"][2] == 0.24
+    assert np.all(net.edges["gC"].to_numpy()[[1, 3]] == 0.18)
 
     net.GlutamateSynapse([0, 1]).set("gS", 0.27)
-    assert np.all(net.syn_params["gS"] == 0.27)
-    assert np.all(net.syn_params["gC"] == 0.18)
+    assert np.all(net.edges["gS"].to_numpy()[[0, 2]] == 0.27)
+    assert np.all(net.edges["gC"].to_numpy()[[1, 3]] == 0.18)
 
     net.TestSynapse([0, 1]).set("gC", 0.21)
-    assert np.all(net.syn_params["gS"] == 0.27)
-    assert np.all(net.syn_params["gC"] == 0.21)
+    assert np.all(net.edges["gS"].to_numpy()[[0, 2]] == 0.27)
+    assert np.all(net.edges["gC"].to_numpy()[[1, 3]] == 0.21)
 
 
 def test_shuffling_order_of_set():