diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index ee4bc998..12bb8f0d 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -2,6 +2,7 @@ # licensed under the Apache License Version 2.0, see from __future__ import annotations +import warnings from abc import ABC, abstractmethod from copy import deepcopy from itertools import chain @@ -41,6 +42,23 @@ from jaxley.utils.swc import build_radiuses_from_xyzr +def only_allow_module(func): + """Decorator to only allow the function to be called on Module instances. + + Decorates methods of Module that cannot be called on Views of Modules instances. + and have to be called on the Module itself.""" + + def wrapper(self, *args, **kwargs): + module_name = self.base.__class__.__name__ + method_name = func.__name__ + assert not isinstance( + self, View + ), f"{method_name} is currently not supported for Views. Call on the {module_name} base Module." + return func(self, *args, **kwargs) + + return wrapper + + class Module(ABC): """Module base class. @@ -49,7 +67,42 @@ class Module(ABC): Modules can be traversed and modified using the `at`, `cell`, `branch`, `comp`, `edge`, and `loc` methods. The `scope` method can be used to toggle between - global and local indices. + global and local indices. Traversal of Modules will return a `View` of itself, + that has a modified set of attributes, which only consider the part of the Module + that is in view. + + This has consequences for how to operate on Module and which changes take affect + where. The following guidelines should be followed (copied from `View`): + 1. We consider a Module to have everything in view. + 2. Views can display and keep track of how a module is traversed. But(!), + do not support making changes or setting variables. This still has to be + done in the base Module, i.e. `self.base`. 
In order to ensure that these + changes only affect whatever is currently in view `self._nodes_in_view`, + or `self._edges_in_view` among others have to be used. Operating on nodes + currently in view can for example be done with + `self.base.node.loc[self._nodes_in_view]` + 3. Every attribute of Module that changes based on what's in view, i.e. `xyzr`, + needs to be modified when View is instantiated. I.e. `xyzr` of `cell.branch(0)`, + should be `[self.base.xyzr[0]]` This could be achieved via: + `[self.base.xyzr[b] for b in self._branches_in_view]`. + + + Example to make methods of Module compatible with View: + ``` + # use data in view to return something + def count_small_branches(self): + # no need to use self.base.attr + viewed indices, + # since no change is made to the attr in question (nodes) + comp_lens = self.nodes["length"] + branch_lens = comp_lens.groupby("global_branch_index").sum() + return np.sum(branch_lens < 10) + + # change data in view + def change_attr_in_view(self): + # changes to attrs have to be made via self.base.attr + viewed indices + a = func1(self.base.attr1[self._cells_in_view]) + b = func2(self.base.attr2[self._edges_in_view]) + self.base.attr3[self._branches_in_view] = a + b This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks). 
@@ -68,19 +121,15 @@ def __init__(self): self._edges_in_view: np.ndarray = None self.edges = pd.DataFrame( - columns=["global_edge_index"] - + [ - f"global_{lvl}_index" - for lvl in [ - "pre_comp", - "pre_branch", - "pre_cell", - "post_comp", - "post_branch", - "post_cell", - ] + columns=[ + "global_edge_index", + "global_pre_comp_index", + "global_post_comp_index", + "pre_locs", + "post_locs", + "type", + "type_ind", ] - + ["pre_locs", "post_locs", "type", "type_ind"] ) self.cumsum_nbranches: Optional[np.ndarray] = None @@ -162,11 +211,19 @@ def __getattr__(self, key): # intercepts calls to synapse types if key in self.base.synapse_names: - syn_inds = self.edges.index[self.edges["type"] == key].to_numpy() + syn_inds = self.edges[self.edges["type"] == key][ + "global_edge_index" + ].to_numpy() + orig_scope = self._scope view = ( - self.edge(syn_inds) if key in self.synapse_names else self.select(None) + self.scope("global").edge(syn_inds).scope(orig_scope) + if key in self.synapse_names + else self.select(None) ) view._set_controlled_by_param(key) # overwrites param set by edge + # Ensure synapse param sharing works with `edge` + # `edge` will be removed as part of #463 + view.edges["local_edge_index"] = np.arange(len(view.edges)) return view def _childviews(self) -> List[str]: @@ -177,19 +234,21 @@ def _childviews(self) -> List[str]: children = levels[levels.index(self._current_view) + 1 :] return children + def _has_childview(self, key: str) -> bool: + child_views = self._childviews() + return key in child_views + def __getitem__(self, index): - supported_lvls = ["network", "cell", "branch"] # cannot index into comp + """Lazy indexing of the module.""" + supported_parents = ["network", "cell", "branch"] # cannot index into comp - # TODO: SHOULD WE ALLOW GROUPVIEW TO BE INDEXED? - # IF YES, UNDER WHICH CONDITIONS? 
- is_group_view = self._current_view in self.groups + not_group_view = self._current_view not in self.groups assert ( - self._current_view in supported_lvls or is_group_view - ), "Lazy indexing is not supported for this View/Module." + self._current_view in supported_parents or not_group_view + ), "Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof." index = index if isinstance(index, tuple) else (index,) - module_or_view = self.base if is_group_view else self - child_views = module_or_view._childviews() + child_views = self._childviews() assert len(index) <= len(child_views), "Too many indices." view = self for i, child in zip(index, child_views): @@ -227,39 +286,37 @@ def reindex_a_by_b( return df index_names = ["cell_index", "branch_index", "comp_index"] # order is important - for obj, prefix in zip( - [self.nodes, self.edges, self.edges], ["", "pre_", "post_"] - ): - global_idx_cols = [f"global_{prefix}{name}" for name in index_names] - local_idx_cols = [f"local_{prefix}{name}" for name in index_names] - idcs = obj[global_idx_cols] + global_idx_cols = [f"global_{name}" for name in index_names] + local_idx_cols = [f"local_{name}" for name in index_names] + idcs = self.nodes[global_idx_cols] - idcs = reindex_a_by_b(idcs, global_idx_cols[0]) - idcs = reindex_a_by_b(idcs, global_idx_cols[1], global_idx_cols[0]) - idcs = reindex_a_by_b(idcs, global_idx_cols[2], global_idx_cols[:2]) - idcs.columns = [col.replace("global", "local") for col in global_idx_cols] - obj[local_idx_cols] = idcs[local_idx_cols].astype(int) + # update local indices of nodes + idcs = reindex_a_by_b(idcs, global_idx_cols[0]) + idcs = reindex_a_by_b(idcs, global_idx_cols[1], global_idx_cols[0]) + idcs = reindex_a_by_b(idcs, global_idx_cols[2], global_idx_cols[:2]) + idcs.columns = [col.replace("global", "local") for col in global_idx_cols] + self.nodes[local_idx_cols] = idcs[local_idx_cols].astype(int) # move indices to the front of the dataframe; move 
controlled_by_param to the end + # move indices of current scope to the front and the others to the back + not_scope = "global" if self._scope == "local" else "local" + self.nodes = reorder_cols( + self.nodes, [f"{self._scope}_{name}" for name in index_names], first=True + ) self.nodes = reorder_cols( - self.nodes, - [ - f"{scope}_{name}" - for scope in ["global", "local"] - for name in index_names - ], + self.nodes, [f"{not_scope}_{name}" for name in index_names], first=False ) + + self.edges = reorder_cols(self.edges, ["global_edge_index"]) self.nodes = reorder_cols(self.nodes, ["controlled_by_param"], first=False) - self.edges["local_edge_index"] = rerank(self.edges["global_edge_index"]) - self.edges = reorder_cols(self.edges, ["global_edge_index", "local_edge_index"]) self.edges = reorder_cols(self.edges, ["controlled_by_param"], first=False) def _init_view(self): """Init attributes critical for View. Needs to be called at init of a Module.""" - lvl = self.__class__.__name__.lower() - self._current_view = "comp" if lvl == "compartment" else lvl + parent = self.__class__.__name__.lower() + self._current_view = "comp" if parent == "compartment" else parent self._nodes_in_view = self.nodes.index.to_numpy() self._edges_in_view = self.edges.index.to_numpy() self.nodes["controlled_by_param"] = 0 @@ -347,7 +404,7 @@ def _set_controlled_by_param(self, key: str): key: key specifying group / view that is in control of the params.""" if key in ["comp", "branch", "cell"]: self.nodes["controlled_by_param"] = self.nodes[f"global_{key}_index"] - self.edges["controlled_by_param"] = self.edges[f"global_pre_{key}_index"] + self.edges["controlled_by_param"] = 0 elif key == "edge": self.edges["controlled_by_param"] = np.arange(len(self.edges)) elif key == "filter": @@ -413,6 +470,8 @@ def _at_nodes(self, key: str, idx: Any) -> View: Keys can be `cell`, `branch`, `comp` and determine which index is used to filter. 
""" + base_name = self.base.__class__.__name__ + assert self.base._has_childview(key), f"{base_name} does not support {key}." idx = self._reformat_index(idx) idx = self.nodes[self._scope + f"_{key}_index"] if is_str_all(idx) else idx where = self.nodes[self._scope + f"_{key}_index"].isin(idx) @@ -476,28 +535,6 @@ def edge(self, idx: Any) -> View: View of the module at the specified edge index.""" return self._at_edges("edge", idx) - # TODO: pre and post could just modify scope - # -> self.scope=self.scope+"_pre" and then call edge? - # def pre(self, idx: Any) -> View: - # """Return a View of the module at the selected pre-synaptic compartments(s). - - # Args: - # idx: index of the edge to view. - - # Returns: - # View of the module filtered by the selected pre-comp index.""" - # return self._at_edges("edge", idx) - - # def post(self, idx: Any) -> View: - # """Return a View of the module at the selected post-synaptic compartments(s). - - # Args: - # idx: index of the edge to view. - - # Returns: - # View of the module filtered by the selected post-comp index.""" - # return self._at_edges("edge", idx) - def loc(self, at: Any) -> View: """Return a View of the module at the selected branch location(s). @@ -613,17 +650,18 @@ def copy( Returns: A part of the module or a copied view of it.""" view = deepcopy(self) - # TODO: add reset_index, i.e. for parents, nodes, edges etc. such that they + warnings.warn("This method is experimental, use at your own risk.") + # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. 
such that they # start from 0/-1 and are contiguous if as_module: raise NotImplementedError("Not yet implemented.") - # TODO: initialize a new module with the same attributes + # initialize a new module with the same attributes return view @property def view(self): """Return view of the module.""" - return View(self, self._nodes_in_view) + return View(self, self._nodes_in_view, self._edges_in_view) @property def _module_type(self): @@ -661,8 +699,9 @@ def _gather_channels_from_constituents(self, constituents: List): name = channel._name self.base.nodes.loc[self.nodes[name].isna(), name] = False - # TODO: Make this work for View? + @only_allow_module def to_jax(self): + # TODO FROM #447: Make this work for View? """Move `.nodes` to `.jaxnodes`. Before the actual simulation is run (via `jx.integrate`), all parameters of @@ -690,7 +729,7 @@ def to_jax(self): def show( self, - param_names: Optional[Union[str, List[str]]] = None, # TODO. + param_names: Optional[Union[str, List[str]]] = None, *, indices: bool = True, params: bool = True, @@ -735,6 +774,7 @@ def show( return nodes[cols] + @only_allow_module def _init_morph(self): """Initialize the morphology such that it can be processed by the solvers.""" self._init_morph_jaxley_spsolve() @@ -846,7 +886,6 @@ def set_ncomp( and len(self._branches_in_view) == len(self.base._branches_in_view) ), "This is not allowed for cells." - # TODO: MAKE THIS NICER # Update all attributes that are affected by compartment structure. view = self.nodes.copy() all_nodes = self.base.nodes @@ -1107,13 +1146,19 @@ def distance(self, endpoint: "View") -> float: end_xyz = endpoint.xyzr[0][0, :3] return np.sqrt(np.sum((start_xyz - end_xyz) ** 2)) - # TODO: MAKE THIS WORK FOR VIEW? def delete_trainables(self): """Removes all trainable parameters from the module.""" - assert isinstance(self, Module), "Only supports modules." 
- self.base.indices_set_by_trainables = [] - self.base.trainable_params = [] - self.base.num_trainable_params = 0 + + if isinstance(self, View): + trainables_and_inds = self._filter_trainables(is_viewed=False) + self.base.indices_set_by_trainables = trainables_and_inds[0] + self.base.trainable_params = trainables_and_inds[1] + self.base.num_trainable_params -= self.num_trainable_params + else: + self.base.indices_set_by_trainables = [] + self.base.trainable_params = [] + self.base.num_trainable_params = 0 + self._update_view() def add_to_group(self, group_name: str): """Add a view of the module to a group. @@ -1134,7 +1179,6 @@ def add_to_group(self, group_name: str): np.concatenate([self.base.groups[group_name], self._nodes_in_view]) ) - # TODO: MAKE THIS WORK FOR VIEW? def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """Get all trainable parameters. @@ -1144,12 +1188,13 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: A list of all trainable parameters in the form of [{"gNa": jnp.array([0.1, 0.2, 0.3])}, ...]. """ - return self.base.trainable_params + return self.trainable_params - # TODO: MAKE THIS WORK FOR VIEW? + @only_allow_module def get_all_parameters( self, pstate: List[Dict], voltage_solver: str ) -> Dict[str, jnp.ndarray]: + # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Return all parameters (and coupling conductances) needed to simulate. Runs `_compute_axial_conductances()` and return every parameter that is needed @@ -1200,14 +1245,13 @@ def get_all_parameters( # This is needed since SynapseViews worked differently before. # This mimics the old behaviour and tranformes the new indices # to the old indices. - # TODO: Longterm this should be gotten rid of. + # TODO FROM #447: Longterm this should be gotten rid of. # Instead edges should work similar to nodes (would also allow for # param sharing). 
+ synapse_inds = self.base.edges.groupby("type").rank()["global_edge_index"] + synapse_inds = (synapse_inds.astype(int) - 1).to_numpy() if key in self.base.synapse_param_names: - syn_name_from_param = key.split("_")[0] - syn_edges = self.__getattr__(syn_name_from_param).edges - inds = syn_edges.loc[inds.reshape(-1)]["local_edge_index"].values - inds = inds.reshape(-1, 1) + inds = synapse_inds[inds] if key in params: # Only parameters, not initial states. # `inds` is of shape `(num_params, num_comps_per_param)`. @@ -1222,8 +1266,9 @@ def get_all_parameters( ) return params - # TODO: MAKE THIS WORK FOR VIEW? + @only_allow_module def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: + # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Return states as they are set in the `.nodes` and `.edges` tables.""" self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. states = {"v": self.base.jaxnodes["v"]} @@ -1235,10 +1280,11 @@ def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]: states[synapse_states] = self.base.jaxedges[synapse_states] return states - # TODO: MAKE THIS WORK FOR VIEW? + @only_allow_module def get_all_states( self, pstate: List[Dict], all_params, delta_t: float ) -> Dict[str, jnp.ndarray]: + # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Get the full initial state of the module from jaxnodes and trainables. Args: @@ -1284,8 +1330,9 @@ def _initialize(self): self._init_morph() return self - # TODO: MAKE THIS WORK FOR VIEW? + @only_allow_module def init_states(self, delta_t: float = 0.025): + # TODO FROM #447: MAKE THIS WORK FOR VIEW? """Initialize all mechanisms in their steady state. This considers the voltages and parameters of each compartment. 
@@ -1400,9 +1447,11 @@ def _init_morph_for_debugging(self): self.base.debug_states["par_inds"] = self.base.par_inds def record(self, state: str = "v", verbose=True): - in_view = ( - self._edges_in_view if state in self.edges.columns else self._nodes_in_view - ) + in_view = None + in_view = self._edges_in_view if state in self.edges.columns else in_view + in_view = self._nodes_in_view if state in self.nodes.columns else in_view + assert in_view is not None, "State not found in nodes or edges." + new_recs = pd.DataFrame(in_view, columns=["rec_index"]) new_recs["state"] = state self.base.recordings = pd.concat([self.base.recordings, new_recs]) @@ -1413,11 +1462,31 @@ def record(self, state: str = "v", verbose=True): f"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details." ) - # TODO: MAKE THIS WORK FOR VIEW? + def _update_view(self): + """Update the attrs of the view after changes in the base module.""" + if isinstance(self, View): + scope = self._scope + current_view = self._current_view + # copy dict of new View. For some reason doing self = View(self) + # did not work. + self.__dict__ = View( + self.base, self._nodes_in_view, self._edges_in_view + ).__dict__ + + # retain the scope and current_view of the previous view + self._scope = scope + self._current_view = current_view + def delete_recordings(self): """Removes all recordings from the module.""" - assert isinstance(self, Module), "Only supports modules." - self.base.recordings = pd.DataFrame().from_dict({}) + if isinstance(self, View): + base_recs = self.base.recordings + self.base.recordings = base_recs[ + ~base_recs.isin(self.recordings).all(axis=1) + ] + self._update_view() + else: + self.base.recordings = pd.DataFrame().from_dict({}) def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True): """Insert a stimulus into the compartment. 
@@ -1565,19 +1634,27 @@ def _data_external_input( return (state_name, external_input, inds) - # TODO: MAKE THIS WORK FOR VIEW? def delete_stimuli(self): """Removes all stimuli from the module.""" - assert isinstance(self, Module), "Only supports modules." - self.base.externals.pop("i", None) - self.base.external_inds.pop("i", None) + self.delete_clamps("i") - # TODO: MAKE THIS WORK FOR VIEW? def delete_clamps(self, state_name: str): """Removes all clamps of the given state from the module.""" - assert isinstance(self, Module), "Only supports modules." - self.base.externals.pop(state_name, None) - self.base.external_inds.pop(state_name, None) + if state_name in self.externals: + keep_inds = ~np.isin( + self.base.external_inds[state_name], self._nodes_in_view + ) + base_exts = self.base.externals + base_exts_inds = self.base.external_inds + if np.all(~keep_inds): + base_exts.pop(state_name, None) + base_exts_inds.pop(state_name, None) + else: + base_exts[state_name] = base_exts[state_name][keep_inds] + base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds] + self._update_view() + else: + pass # does not have to be deleted if not in externals def insert(self, channel: Channel): """Insert a channel into the module. @@ -1607,6 +1684,7 @@ def insert(self, channel: Channel): for key in channel.channel_states: self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key] + @only_allow_module def step( self, u: Dict[str, jnp.ndarray], @@ -2090,14 +2168,19 @@ def move_to( "NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values." 
) - root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in self.cells]) + # can only iterate over cells for networks + # lambda makes sure that generator can be created multiple times + base_is_net = self.base._current_view == "network" + cells = lambda: (self.cells if base_is_net else [self]) + + root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()]) root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells move_by = np.array([x, y, z]).T - root_xyz if len(move_by.shape) == 1: move_by = np.tile(move_by, (len(self._cells_in_view), 1)) - for cell, offset in zip(self.cells, move_by): + for cell, offset in zip(cells(), move_by): for idx in cell._branches_in_view: self.base.xyzr[idx][:, :3] += offset if update_nodes: @@ -2141,6 +2224,13 @@ class View(Module): allow to target specific parts of a Module, i.e. setting parameters for parts of a cell. + Almost all methods in View are concerned with updating the attributes of the + base Module, i.e. `self.base`, based on the indices in view. For example, + `_channels_in_view` lists all channels, finds the subset set to `True` in + `self.nodes` (currently in view) and returns the updated list such that we can set + `self.channels = self._channels_in_view()`. + + To allow seamless operation on Views and Modules as if they were the same, the following needs to be ensured: 1. We consider a Module to have everything in view. 
@@ -2173,7 +2263,6 @@ def change_attr_in_view(self): a = func1(self.base.attr1[self._cells_in_view]) b = func2(self.base.attr2[self._edges_in_view]) self.base.attr3[self._branches_in_view] = a + b - ``` """ def __init__( @@ -2208,6 +2297,8 @@ def __init__( self.cumsum_nbranches = jnp.cumsum(np.asarray(self.nbranches_per_cell)) self.comb_branches_in_each_level = pointer.comb_branches_in_each_level self.branch_edges = pointer.branch_edges.loc[self._branch_edges_in_view] + self.nseg_per_branch = self.base.nseg_per_branch[self._branches_in_view] + self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) self.synapse_names = np.unique(self.edges["type"]).tolist() self._set_synapses_in_view(pointer) @@ -2242,7 +2333,7 @@ def __init__( self._current_view = "view" # if not instantiated via `comp`, `cell` etc. self._update_local_indices() - # TODO: + # TODO FROM #447: self.debug_states = pointer.debug_states if len(self.nodes) == 0: @@ -2251,7 +2342,7 @@ def __init__( def _set_inds_in_view( self, pointer: Union[Module, View], nodes: np.ndarray, edges: np.ndarray ): - """Set nodes and edge indices that are in view.""" + """Update node and edge indices to list only those currently in view.""" # set nodes and edge indices in view has_node_inds = nodes is not None has_edge_inds = edges is not None @@ -2287,6 +2378,7 @@ def _set_inds_in_view( self._edges_in_view = edges def _jax_arrays_in_view(self, pointer: Union[Module, View]): + """Update jaxnodes/jaxedges to show only those currently in view.""" a_intersects_b_at = lambda a, b: jnp.intersect1d(a, b, return_indices=True)[1] jaxnodes = {} if pointer.jaxnodes is not None else None if self.jaxnodes is not None: @@ -2311,6 +2403,7 @@ def _jax_arrays_in_view(self, pointer: Union[Module, View]): return jaxnodes, jaxedges def _set_externals_in_view(self): + """Update external inputs to show only those currently in view.""" self.externals = {} self.external_inds = {} for (name, inds), data in zip( @@ -2322,75 +2415,90 @@ def 
_set_externals_in_view(self): self.externals[name] = data[in_view] self.external_inds[name] = inds_in_view - def _set_trainables_in_view(self): - trainable_inds = self.base.indices_set_by_trainables - trainable_inds = ( - np.unique(np.hstack([inds.reshape(-1) for inds in trainable_inds])) - if len(trainable_inds) > 0 - else [] - ) - trainable_node_inds_in_view = np.intersect1d( - trainable_inds, self._nodes_in_view - ) + def _filter_trainables( + self, is_viewed: bool = True + ) -> Tuple[List[np.ndarray], List[Dict]]: + """Filters the trainables inside and outside of the view. + Trainables are split between `indices_set_by_trainables` and `trainable_params` + and can be shared between multiple compartments / branches etc, which makes it + difficult to filter them based on the current view w.o. destroying the + original structure. + + This method filters `indices_set_by_trainables` for the indices that are + currently in view (or not in view) and returns the corresponding trainable + parameters and indices such that the sharing behavior is preserved as much as + possible. 
+ + Args: + is_viewed: Toggles between returning the trainables and inds + currently inside or outside of the scope of View.""" índices_set_by_trainables_in_view = [] trainable_params_in_view = [] for inds, params in zip( self.base.indices_set_by_trainables, self.base.trainable_params ): - in_view = np.isin(inds, trainable_node_inds_in_view) - + pkey, pval = next(iter(params.items())) + trainable_inds_in_view = None + if pkey in sum( + [list(c.channel_params.keys()) for c in self.base.channels], [] + ): + trainable_inds_in_view = np.intersect1d(inds, self._nodes_in_view) + elif pkey in sum( + [list(s.synapse_params.keys()) for s in self.base.synapses], [] + ): + trainable_inds_in_view = np.intersect1d(inds, self._edges_in_view) + + in_view = is_viewed == np.isin(inds, trainable_inds_in_view) completely_in_view = in_view.all(axis=1) - índices_set_by_trainables_in_view.append(inds[completely_in_view]) + partially_in_view = in_view.any(axis=1) & ~completely_in_view + trainable_params_in_view.append( {k: v[completely_in_view] for k, v in params.items()} ) - - partially_in_view = in_view.any(axis=1) & ~completely_in_view - índices_set_by_trainables_in_view.append( - inds[partially_in_view][in_view[partially_in_view]] - ) trainable_params_in_view.append( {k: v[partially_in_view] for k, v in params.items()} ) - # TODO: working but ugly. 
maybe integrate into above loop - trainable_names = np.array([next(iter(d)) for d in self.base.trainable_params]) - is_syn_trainable_in_view = np.isin(trainable_names, self.synapse_param_names) - syn_trainable_names_in_view = trainable_names[is_syn_trainable_in_view] - syn_trainable_inds_in_view = np.intersect1d( - syn_trainable_names_in_view, trainable_names, return_indices=True - )[2] - for idx in syn_trainable_inds_in_view: - syn_name = trainable_names[idx].split("_")[0] - syn_edges = self.base.edges[self.base.edges["type"] == syn_name] - syn_inds = np.arange(len(syn_edges)) - syn_inds_in_view = syn_inds[np.isin(syn_edges.index, self._edges_in_view)] - - syn_trainable_params_in_view = { - k: v[syn_inds_in_view] - for k, v in self.base.trainable_params[idx].items() - } - trainable_params_in_view.append(syn_trainable_params_in_view) - syn_inds_set_by_trainables_in_view = self.base.indices_set_by_trainables[ - idx - ][syn_inds_in_view] - índices_set_by_trainables_in_view.append(syn_inds_set_by_trainables_in_view) + índices_set_by_trainables_in_view.append(inds[completely_in_view]) + partial_inds = inds[partially_in_view][in_view[partially_in_view]] + + # the indexing i.e. `inds[partially_in_view]` reshapes `inds`. Since the shape + # determines how parameters are shared, `inds` has to be returned to its + # original shape. 
+ if inds.shape[0] > 1 and partial_inds.shape != (0,): + partial_inds = partial_inds.reshape(-1, 1) + if inds.shape[1] > 1 and partial_inds.shape != (0,): + partial_inds = partial_inds.reshape(1, -1) - self.indices_set_by_trainables = [ + índices_set_by_trainables_in_view.append(partial_inds) + + indices_set_by_trainables = [ inds for inds in índices_set_by_trainables_in_view if len(inds) > 0 ] - self.trainable_params = [ + trainable_params = [ p for p in trainable_params_in_view if len(next(iter(p.values()))) > 0 ] + return indices_set_by_trainables, trainable_params + + def _set_trainables_in_view(self): + """Set `trainable_params` and `indices_set_by_trainables` to show only those in view.""" + trainables = self._filter_trainables() + + # note for `branch.comp(0).make_trainable("X"); branch.make_trainable("X")` + # `view = branch.comp(0)` will have duplicate training params. + self.indices_set_by_trainables = trainables[0] + self.trainable_params = trainables[1] def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: + """Set channels to show only those in view.""" names = [name._name for name in pointer.channels] channel_in_view = self.nodes[names].any(axis=0) channel_in_view = channel_in_view[channel_in_view].index return [c for c in pointer.channels if c._name in channel_in_view] def _set_synapses_in_view(self, pointer: Union[Module, View]): + """Set synapses to show only those in view.""" viewed_synapses = [] viewed_params = [] viewed_states = [] @@ -2412,6 +2520,9 @@ def _nbranches_per_cell_in_view(self) -> np.ndarray: return cell_nodes["global_branch_index"].nunique().to_list() def _xyzr_in_view(self) -> List[np.ndarray]: + """Return xyzr coordinates of every branch that is in `_branches_in_view`. 
+ + If a branch is not completely in view, the coordinates are interpolated.""" xyzr = [self.base.xyzr[i] for i in self._branches_in_view].copy() # Currently viewing with `.loc` will show the closest compartment @@ -2461,6 +2572,7 @@ def _comps_in_view(self) -> np.ndarray: @property def _branch_edges_in_view(self) -> np.ndarray: + """Lists the global branch edge indices which are currently part of the view.""" incl_branches = self.nodes["global_branch_index"].unique() pre = self.base.branch_edges["parent_branch_index"].isin(incl_branches) post = self.base.branch_edges["child_branch_index"].isin(incl_branches) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index b4de9771..c225545e 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -52,7 +52,7 @@ def __init__( for cell in cells: self.xyzr += deepcopy(cell.xyzr) - self.cells_list = cells # TODO: TEMPORARY FIX, REMOVE BY ADDING ATTRS TO VIEW (solve_indexer.children_in_level) + self._cells_list = cells self.nseg_per_branch = np.concatenate([cell.nseg_per_branch for cell in cells]) self.nseg = int(np.max(self.nseg_per_branch)) self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) @@ -106,6 +106,7 @@ def __init__( self._gather_channels_from_constituents(cells) self._initialize() + del self._cells_list def __repr__(self): return f"{type(self).__name__} with {len(self.channels)} different channels and {len(self.synapses)} synapses. Use `.nodes` or `.edges` for details." 
@@ -119,18 +120,18 @@ def _init_morph_jaxley_spsolve(self): children_in_level = merge_cells( self.cumsum_nbranches, self.cumsum_nbranchpoints_per_cell, - [cell.solve_indexer.children_in_level for cell in self.cells_list], + [cell.solve_indexer.children_in_level for cell in self._cells_list], exclude_first=False, ) parents_in_level = merge_cells( self.cumsum_nbranches, self.cumsum_nbranchpoints_per_cell, - [cell.solve_indexer.parents_in_level for cell in self.cells_list], + [cell.solve_indexer.parents_in_level for cell in self._cells_list], exclude_first=False, ) padded_cumsum_nseg = cumsum_leading_zero( np.concatenate( - [np.diff(cell.solve_indexer.cumsum_nseg) for cell in self.cells_list] + [np.diff(cell.solve_indexer.cumsum_nseg) for cell in self._cells_list] ) ) @@ -171,12 +172,12 @@ def _init_morph_jax_spsolve(self): `type == 4`: child-compartment --> branchpoint """ self._cumsum_nseg_per_cell = cumsum_leading_zero( - jnp.asarray([cell.cumsum_nseg[-1] for cell in self.cells_list]) + jnp.asarray([cell.cumsum_nseg[-1] for cell in self.cells]) ) self._comp_edges = pd.DataFrame() # Add all the internal nodes. 
- for offset, cell in zip(self._cumsum_nseg_per_cell, self.cells_list): + for offset, cell in zip(self._cumsum_nseg_per_cell, self._cells_list): condition = cell._comp_edges["type"].to_numpy() == 0 rows = cell._comp_edges[condition] self._comp_edges = pd.concat( @@ -188,7 +189,7 @@ def _init_morph_jax_spsolve(self): for offset, offset_branchpoints, cell in zip( self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, - self.cells_list, + self._cells_list, ): offset_within_cell = cell.cumsum_nseg[-1] condition = cell._comp_edges["type"].isin([1, 2]) @@ -210,7 +211,7 @@ def _init_morph_jax_spsolve(self): for offset, offset_branchpoints, cell in zip( self._cumsum_nseg_per_cell, self.cumsum_nbranchpoints_per_cell, - self.cells_list, + self._cells_list, ): offset_within_cell = cell.cumsum_nseg[-1] condition = cell._comp_edges["type"].isin([3, 4]) @@ -228,9 +229,6 @@ def _init_morph_jax_spsolve(self): ignore_index=True, ) - # Note that, unlike in `cell.py`, we cannot delete `self.cells_list` here because - # it is used in plotting. - # Convert comp_edges to the index format required for `jax.sparse` solvers. 
n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges) self._n_nodes = n_nodes @@ -473,8 +471,11 @@ def vis( pre_locs = self.edges["pre_locs"].to_numpy() post_locs = self.edges["post_locs"].to_numpy() - pre_branch = self.edges["global_pre_branch_index"].to_numpy() - post_branch = self.edges["global_post_branch_index"].to_numpy() + pre_comp = self.edges["global_pre_comp_index"].to_numpy() + nodes = self.nodes.set_index("global_comp_index") + pre_branch = nodes.loc[pre_comp, "global_branch_index"].to_numpy() + post_comp = self.edges["global_post_comp_index"].to_numpy() + post_branch = nodes.loc[post_comp, "global_branch_index"].to_numpy() dims_np = np.asarray(dims) @@ -533,10 +534,13 @@ def build_extents(*subset_sizes): for i, layer in enumerate(layers): graph.add_nodes_from(layer, layer=i) else: - graph.add_nodes_from(range(len(self.cells_list))) + graph.add_nodes_from(range(len(self._cells_in_view))) - pre_cell = self.edges["global_pre_cell_index"].to_numpy() - post_cell = self.edges["global_post_cell_index"].to_numpy() + pre_comp = self.edges["global_pre_comp_index"].to_numpy() + nodes = self.nodes.set_index("global_comp_index") + pre_cell = nodes.loc[pre_comp, "global_cell_index"].to_numpy() + post_comp = self.edges["global_post_comp_index"].to_numpy() + post_cell = nodes.loc[post_comp, "global_cell_index"].to_numpy() inds = np.stack([pre_cell, post_cell]).T graph.add_edges_from(inds) @@ -578,11 +582,10 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): ) # Define new synapses. Each row is one synapse. 
- cols = ["comp_index", "branch_index", "cell_index"] - pre_nodes = pre_nodes[[f"global_{col}" for col in cols]] - pre_nodes.columns = [f"global_pre_{col}" for col in cols] - post_nodes = post_nodes[[f"global_{col}" for col in cols]] - post_nodes.columns = [f"global_post_{col}" for col in cols] + pre_nodes = pre_nodes[["global_comp_index"]] + pre_nodes.columns = ["global_pre_comp_index"] + post_nodes = post_nodes[["global_comp_index"]] + post_nodes.columns = ["global_post_comp_index"] new_rows = pd.concat( [ global_edge_index, @@ -591,7 +594,6 @@ def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): ], axis=1, ) - new_rows["local_edge_index"] = new_rows["global_edge_index"] new_rows["type"] = synapse_name new_rows["type_ind"] = type_ind new_rows["pre_locs"] = pre_loc diff --git a/tests/test_connection.py b/tests/test_connection.py index 4d1bd37d..5178d24b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -57,17 +57,13 @@ def test_connect(): # check if all connections are made correctly first_set_edges = net2.edges.iloc[:8] - # TODO: VERIFY THAT THIS IS INTENDED BEHAVIOUR! 
@Michael - assert ( - ( - first_set_edges[["global_pre_branch_index", "global_post_branch_index"]] - == (4, 8) - ) - .all() - .all() - ) - assert (first_set_edges["global_pre_cell_index"] == 1).all() - assert (first_set_edges["global_post_cell_index"] == 2).all() + nodes = net2.nodes.set_index("global_comp_index") + cols = ["global_pre_comp_index", "global_post_comp_index"] + comp_inds = nodes.loc[first_set_edges[cols].to_numpy().flatten()] + branch_inds = comp_inds["global_branch_index"].to_numpy().reshape(-1, 2) + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + assert np.all(branch_inds == (4, 8)) + assert (cell_inds == (1, 2)).all() assert ( get_comps(first_set_edges["pre_locs"]) == get_comps(first_set_edges["post_locs"]) @@ -181,14 +177,11 @@ def test_connectivity_matrix_connect(): net[:4], net[4:8], TestSynapse(), n_by_n_adjacency_matrix ) assert len(net.edges.index) == 4 - assert ( - ( - net.edges[["global_pre_cell_index", "global_post_cell_index"]] - == incides_of_connected_cells - ) - .all() - .all() - ) + nodes = net.nodes.set_index("global_comp_index") + cols = ["global_pre_comp_index", "global_post_comp_index"] + comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + assert np.all(cell_inds == incides_of_connected_cells) m_by_n_adjacency_matrix = np.array( [[0, 1, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=bool @@ -205,11 +198,8 @@ def test_connectivity_matrix_connect(): net[:3], net[:4], TestSynapse(), m_by_n_adjacency_matrix ) assert len(net.edges.index) == 5 - assert ( - ( - net.edges[["global_pre_cell_index", "global_post_cell_index"]] - == incides_of_connected_cells - ) - .all() - .all() - ) + nodes = net.nodes.set_index("global_comp_index") + cols = ["global_pre_comp_index", "global_post_comp_index"] + comp_inds = nodes.loc[net.edges[cols].to_numpy().flatten()] + cell_inds = comp_inds["global_cell_index"].to_numpy().reshape(-1, 2) + assert 
np.all(cell_inds == incides_of_connected_cells) diff --git a/tests/test_groups.py b/tests/test_groups.py index a469a214..00e22ee5 100644 --- a/tests/test_groups.py +++ b/tests/test_groups.py @@ -29,12 +29,6 @@ def test_subclassing_groups_cell_api(): cell.subtree.branch(0).set("radius", 0.1) cell.subtree.branch(0).comp("all").make_trainable("length") - # TODO: REMOVE THIS IS NOW ALLOWED - # with pytest.raises(KeyError): - # cell.subtree.cell(0).branch("all").make_trainable("length") - # with pytest.raises(KeyError): - # cell.subtree.comp(0).make_trainable("length") - def test_subclassing_groups_net_api(): comp = jx.Compartment() @@ -48,12 +42,6 @@ def test_subclassing_groups_net_api(): net.excitatory.cell(0).set("radius", 0.1) net.excitatory.cell(0).branch("all").make_trainable("length") - # TODO: REMOVE THIS IS NOW ALLOWED - # with pytest.raises(KeyError): - # cell.excitatory.branch(0).comp("all").make_trainable("length") - # with pytest.raises(KeyError): - # cell.excitatory.comp("all").make_trainable("length") - def test_subclassing_groups_net_set_equivalence(): """Test whether calling `.set` on subclasses group is same as on view.""" @@ -89,7 +77,7 @@ def test_subclassing_groups_net_make_trainable_equivalence(): # The following lines are made possible by PR #324. 
# The new behaviour needs changing of the scope to still conform here - # TODO: Rewrite this test / reconsider what behaviour is desired + # TODO FROM #447: Rewrite this test / reconsider what behaviour is desired net1.excitatory.scope("global").cell([0, 3]).scope("local").branch( 0 ).make_trainable("radius") @@ -113,37 +101,6 @@ def test_subclassing_groups_net_make_trainable_equivalence(): assert jnp.array_equal(inds1, inds2) -def test_subclassing_groups_net_lazy_indexing_make_trainable_equivalence(): - """Test whether groups can be indexing in a lazy way.""" - comp = jx.Compartment() - branch = jx.Branch(comp, 4) - cell = jx.Cell(branch, [-1, 0]) - net1 = jx.Network([cell for _ in range(10)]) - net2 = jx.Network([cell for _ in range(10)]) - - net1.cell([0, 3, 5]).add_to_group("excitatory") - net2.cell([0, 3, 5]).add_to_group("excitatory") - - # The following lines are made possible by PR #324. - net1.excitatory.cell([0, 3]).branch(0).make_trainable("radius") - net1.excitatory.cell([0, 5]).branch(1).comp("all").make_trainable("length") - net1.excitatory.cell("all").branch(1).comp(2).make_trainable("axial_resistivity") - params1 = jnp.concatenate(jax.tree_util.tree_flatten(net1.get_parameters())[0]) - - # The following lines are made possible by PR #324. 
- net2.excitatory[[0, 3], 0].make_trainable("radius") - net2.excitatory[[0, 5], 1, :].make_trainable("length") - net2.excitatory[:, 1, 2].make_trainable("axial_resistivity") - params2 = jnp.concatenate(jax.tree_util.tree_flatten(net2.get_parameters())[0]) - - assert jnp.array_equal(params1, params2) - - for inds1, inds2 in zip( - net1.indices_set_by_trainables, net2.indices_set_by_trainables - ): - assert jnp.array_equal(inds1, inds2) - - def test_fully_connect_groups_equivalence(): """Test whether groups can be used with `fully_connect`.""" comp = jx.Compartment() diff --git a/tests/test_viewing.py b/tests/test_viewing.py index d869a29a..ba09757f 100644 --- a/tests/test_viewing.py +++ b/tests/test_viewing.py @@ -4,6 +4,7 @@ from copy import deepcopy import jax +import pandas as pd import pytest jax.config.update("jax_enable_x64", True) @@ -195,27 +196,16 @@ def test_local_indexing(): ["local_cell_index", "local_branch_index", "local_comp_index"] ] idx_cols = ["global_cell_index", "global_branch_index", "global_comp_index"] - # TODO: Write new and more comprehensive test for local indexing! global_index = 0 for cell_idx in range(2): for branch_idx in range(5): for comp_idx in range(4): - - # compview = net[cell_idx, branch_idx, comp_idx].show() - # assert np.all( - # compview[idx_cols].values == [cell_idx, branch_idx, comp_idx] - # ) assert np.all( local_idxs.iloc[global_index] == [cell_idx, branch_idx, comp_idx] ) global_index += 1 -def test_comp_indexing_exception_handling(): - # TODO: Add tests for indexing exceptions - pass - - def test_indexing_a_compartment_of_many_branches(): comp = jx.Compartment() branch1 = jx.Branch(comp, nseg=3) @@ -226,7 +216,7 @@ def test_indexing_a_compartment_of_many_branches(): net = jx.Network([cell1, cell2]) # Indexing a single compartment of multiple branches is not supported with `loc`. - # TODO: Reevaluate what kind of indexing is allowed and which is not! 
+ # TODO FROM #447: Reevaluate what kind of indexing is allowed and which is not! # with pytest.raises(NotImplementedError): # net.cell("all").branch("all").loc(0.0) # with pytest.raises(NotImplementedError): @@ -257,8 +247,6 @@ def test_solve_indexer(): assert np.all(idx.upper(branch_inds) == np.asarray([[0, 1, 2], [7, 8, 9]])) -# TODO: tests - comp = jx.Compartment() branch = jx.Branch(comp, nseg=3) cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) @@ -268,12 +256,21 @@ def test_solve_indexer(): # make sure all attrs in module also have a corresponding attr in view @pytest.mark.parametrize("module", [comp, branch, cell, net]) -def test_view_attrs(module): +def test_view_attrs(module: jx.Compartment | jx.Branch | jx.Cell | jx.Network): + """Check if all attributes of Module have a corresponding attribute in View. + + To ensure that View behaves like a Module as much as possible, View should support + all attributes of Module. This test checks if all attributes of Module have a + corresponding attribute in View. Also checks if the types of the attributes match. 
+ """ # attributes of Module that do not have to exist in View exceptions = ["view"] - # TODO: should be added to View in the future + + # TODO: Types are inconsistent between different Modules + exceptions += ["cumsum_nbranches"] + + # TODO FROM #447: should be added to View in the future exceptions += [ - "cumsum_nseg", "_internal_node_inds", "par_inds", "child_inds", @@ -291,7 +288,6 @@ def test_view_attrs(module): "cumsum_nbranchpoints_per_cell", "_cumsum_nseg_per_cell", ] # for network - exceptions += ["cumsum_nbranches"] # HOTFIX #TODO: take care of this for name, attr in module.__dict__.items(): if name not in exceptions: @@ -304,7 +300,331 @@ def test_view_attrs(module): ), f"Type mismatch: {name}, Module type: {type(getattr(module, name))}, View type: {type(getattr(view, name))}" -# TODO: test filter for modules and check for param sharing -# add test local_indexing and global_indexing -# add cell.comp (branch is skipped also for param sharing) -# add tests for new features i.e. 
iter, context, scope +comp = jx.Compartment() +branch = jx.Branch([comp] * 4) +cell = jx.Cell([branch] * 4, parents=[-1, 0, 0, 0]) +net = jx.Network([cell] * 4) + + +@pytest.mark.parametrize("module", [comp, branch, cell, net]) +def test_view_supported_index_types(module): + """Check if different ways to index into Modules/Views work correctly.""" + # test int, range, slice, list, np.array, pd.Index + index_types = [ + 0, + range(3), + slice(0, 3), + [0, 1, 2], + np.array([0, 1, 2]), + pd.Index([0, 1, 2]), + ] + + # comp.comp is not allowed + if not isinstance(module, jx.Compartment): + # `_reformat_index` should always return a np.ndarray + for index in index_types: + assert isinstance( + module._reformat_index(index), np.ndarray + ), f"Failed for {type(index)}" + assert module.comp(index), f"Failed for {type(index)}" + assert View(module).comp(index), f"Failed for {type(index)}" + + # for loc test float and list of floats + assert module.loc(0.0), "Failed for float" + assert module.loc([0.0, 0.5, 1.0]), "Failed for List[float]" + else: + with pytest.raises(AssertionError): + module.comp(0) + + +def test_select(): + """Ensure `select` works correctly and returns expected View of Modules.""" + comp = jx.Compartment() + branch = jx.Branch([comp] * 3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + net = jx.Network([cell] * 3) + connect(net[0, 0, :], net[1, 0, :], TestSynapse()) + + np.random.seed(0) + + # select only nodes + inds = np.random.choice(net.nodes.index, replace=False, size=5) + view = net.select(nodes=inds) + assert np.all(view.nodes.index == inds), "Selecting nodes by index failed" + + # select only edges + inds = np.random.choice(net.edges.index, replace=False, size=2) + view = net.select(edges=inds) + assert np.all(view.edges.index == inds), "Selecting edges by index failed" + + # check if pre and post comps of edges are in nodes + edge_node_inds = np.unique( + view.edges[["global_pre_comp_index", "global_post_comp_index"]] + .to_numpy() + 
.flatten() + ) + assert np.all( + view.nodes["global_comp_index"] == edge_node_inds + ), "Selecting edges did not yield the correct nodes." + + # select nodes and edges + node_inds = np.random.choice(net.nodes.index, replace=False, size=5) + edge_inds = np.random.choice(net.edges.index, replace=False, size=2) + view = net.select(nodes=node_inds, edges=edge_inds) + assert np.all( + view.nodes.index == node_inds + ), "Selecting nodes and edges by index failed for nodes." + assert np.all( + view.edges.index == edge_inds + ), "Selecting nodes and edges by index failed for edges." + + +def test_viewing(): + """Test that the View object is working correctly.""" + comp = jx.Compartment() + branch = jx.Branch([comp] * 3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + net = jx.Network([cell] * 3) + + # test parameter sharing works correctly + nodes1 = net.branch(0).comp("all").nodes + nodes2 = net.branch(0).nodes + nodes3 = net.cell(0).nodes + control_params1 = nodes1.pop("controlled_by_param") + control_params2 = nodes2.pop("controlled_by_param") + control_params3 = nodes3.pop("controlled_by_param") + assert np.all(nodes1 == nodes2), "Nodes are not the same" + assert np.all( + control_params1 == nodes1["global_comp_index"] + ), "Parameter sharing is not correct" + assert np.all( + control_params2 == nodes2["global_branch_index"] + ), "Parameter sharing is not correct" + assert np.all( + control_params3 == nodes3["global_cell_index"] + ), "Parameter sharing is not correct" + + # test local and global indexes match the expected targets + for view, local_targets, global_targets in zip( + [ + net.branch(0), # shows every comp on 0th branch of all cells + cell.branch("all"), # shows all branches and comps of cell + net.cell(0).comp(0), # shows every 0th comp for every branch on 0th cell + net.comp(0), # shows 0th comp of every branch of every cell + cell.comp(0), # shows 0th comp of every branch of cell + ], + [[0, 1, 2] * 3, [0, 1, 2] * 3, [0] * 3, [0] * 9, [0] * 3], + [ 
+ [0, 1, 2, 9, 10, 11, 18, 19, 20], + list(range(9)), + [0, 3, 6], + list(range(0, 27, 3)), + list(range(0, 9, 3)), + ], + ): + assert np.all( + view.nodes["local_comp_index"] == local_targets + ), "Indices do not match that of the target" + assert np.all( + view.nodes["global_comp_index"] == global_targets + ), "Indices do not match that of the target" + + with pytest.raises(ValueError): + net.scope("global").comp(999) # Nothing should be in View + + +def test_scope(): + """Ensure scope has the intended effect for Modules and Views.""" + comp = jx.Compartment() + branch = jx.Branch([comp] * 3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + + view = cell.scope("global").branch(1) + assert view._scope == "global" + view = view.scope("local").comp(0) + assert np.all( + view.nodes[["global_branch_index", "global_comp_index"]] == [1, 3] + ), "Expected [1,3] but got {}".format( + view.nodes[["global_branch_index", "global_comp_index"]] + ) + + cell.set_scope("global") + assert cell._scope == "global" + view = cell.branch(1).comp(3) + assert np.all( + view.nodes[["global_branch_index", "global_comp_index"]] == [1, 3] + ), "Expected [1,3] but got {}".format( + view.nodes[["global_branch_index", "global_comp_index"]] + ) + + cell.set_scope("local") + assert cell._scope == "local" + view = cell.branch(1).comp(0) + assert np.all( + view.nodes[["global_branch_index", "global_comp_index"]] == [1, 3] + ), "Expected [1,3] but got {}".format( + view.nodes[["global_branch_index", "global_comp_index"]] + ) + + +def test_context_manager(): + """Test that context manager works correctly for Module.""" + comp = jx.Compartment() + branch = jx.Branch([comp] * 3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + + with cell.branch(0).comp(0) as comp: + comp.set("v", -71) + comp.set("radius", 0.123) + + with cell.branch(1).comp([0, 1]) as comps: + comps.set("v", -71) + comps.set("radius", 0.123) + + assert np.all( + cell.branch(0).comp(1).nodes[["v", "radius"]] == [-70, 1.0] + ), 
"Set affected nodes not in context manager View." + assert np.all( + cell.branch(0).comp(0).nodes[["v", "radius"]] == [-71, 0.123] + ), "Context management of View not working." + assert np.all( + cell.branch(1).comp([0, 1]).nodes[["v", "radius"]] == [-71, 0.123] + ), "Context management of View not working." + + +def test_iter(): + """Test that __iter__ works correctly for all modules.""" + comp = jx.Compartment() + branch1 = jx.Branch([comp] * 2) + branch2 = jx.Branch([comp] * 3) + cell = jx.Cell([branch1, branch1, branch2], parents=[-1, 0, 0]) + net = jx.Network([cell] * 2) + + # test iterating over bracnhes with different numbers of compartments + assert np.all( + [ + len(branch.nodes) == expected_len + for branch, expected_len in zip(cell.branches, [2, 2, 3]) + ] + ), "__iter__ failed for branches with different numbers of compartments." + + # test iterating using cells, branches, and comps properties + nodes1 = [] + for cell in net.cells: + for branch in cell.branches: + for comp in branch.comps: + nodes1.append(comp.nodes) + assert len(nodes1) == len(net.nodes), "Some compartments were skipped in iteration." + + nodes2 = [] + for cell in net: + for branch in cell: + for comp in branch: + nodes2.append(comp.nodes) + assert len(nodes2) == len(net.nodes), "Some compartments were skipped in iteration." 
+ assert np.all( + [np.all(n1 == n2) for n1, n2 in zip(nodes1, nodes2)] + ), "__iter__ is not consistent with [comp.nodes for cell in net.cells for branches in cell.branches for comp in branches.comps]" + + assert np.all( + [len(comp.nodes) for comp in net[0, 0].comps] == [1, 1] + ), "Iterator yielded unexpected number of compartments" + + # 0th comp in every branch (3), 1st comp in every branch (3), 2nd comp in (every) branch (only 1 branch with > 2 comps) + assert np.all( + [len(comp.nodes) for comp in net[0].comps] == [3, 3, 1] + ), "Iterator yielded unexpected number of compartments" + + # 0th comp in every branch for every cell (6), 1st comp in every branch for every cell , 2nd comp in (every) branch for every cell + assert np.all( + [len(comp.nodes) for comp in net.comps] == [6, 6, 2] + ), "Iterator yielded unexpected number of compartments" + + for comp in branch1: + comp.set("v", -72) + assert np.all(branch1.nodes["v"] == -72), "Setting parameters with __iter__ failed." + + # needs to be redefined because cell was overwritten with View object + cell = jx.Cell([branch1, branch1, branch2], parents=[-1, 0, 0]) + for branch in cell: + for comp in branch: + comp.set("v", -73) + assert np.all(cell.nodes["v"] == -73), "Setting parameters with __iter__ failed." 
+ + +def test_synapse_and_channel_filtering(): + """Test that synapses and channels are filtered correctly by View.""" + comp = jx.Compartment() + branch = jx.Branch([comp] * 3) + cell = jx.Cell([branch] * 3, parents=[-1, 0, 0]) + net = jx.Network([cell] * 3) + net.insert(HH()) + connect(net[0, 0, :], net[1, 0, :], TestSynapse()) + + assert np.all(net.cell(0).HH.nodes == net.HH.cell(0).nodes) + view1 = net.cell([0, 1]).TestSynapse + nodes1 = view1.nodes + edges1 = view1.edges + view2 = net.TestSynapse.cell([0, 1]) + nodes2 = view2.nodes + edges2 = view2.edges + nodes_control_param1 = nodes1.pop("controlled_by_param") + nodes_control_param2 = nodes2.pop("controlled_by_param") + edges_control_param1 = edges1.pop("controlled_by_param") + edges_control_param2 = edges2.pop("controlled_by_param") + + # convert to dict so order of cols and index dont matter for __eq__ + assert nodes1.to_dict() == nodes2.to_dict() + assert np.all(nodes_control_param1 == 0) + assert np.all(nodes_control_param2 == nodes2["global_cell_index"]) + + assert np.all(edges1 == edges2) + + +def test_view_equals_module(): + """Test that View behaves the same as Module for important attrs and methods.""" + comp = jx.Compartment() + branch = jx.Branch([comp] * 3) + + comp.insert(HH()) + branch.comp([0, 1]).insert(HH()) + + comp.set("v", -71.2) + branch.comp(0).set("v", -71.2) + + comp.record("v") + branch.comp([0, 1]).record("v") + + comp.stimulate(np.zeros(100)) + branch.comp([0, 1]).stimulate(np.zeros(100)) + + comp.make_trainable("HH_gNa") + comp.make_trainable("HH_gK") + branch.comp([0, 1]).make_trainable("HH_gNa") + branch.make_trainable("HH_gK") + + # test deleting subset of attributes + branch.comp(1).delete_trainables() + branch.comp(1).delete_recordings() + branch.comp(1).delete_stimuli() + + assert ( + branch.comp(1).trainable_params == [] and branch.comp(0).trainable_params != [] + ) + assert branch.comp(1).recordings.empty and not branch.comp(0).recordings.empty + assert 
branch.comp(1).externals == {} and branch.comp(0).externals != {} + + # convert to dict so order of cols and index dont matter for __eq__ + assert comp.nodes.to_dict() == branch.comp(0).nodes.to_dict() + + assert comp.trainable_params == branch.comp(0).trainable_params + assert comp.indices_set_by_trainables == branch.comp(0).indices_set_by_trainables + assert np.all(comp.recordings == branch.comp(0).recordings) + assert np.all( + [ + np.all([np.all(v1 == v2), k1 == k2]) + for (k1, v1), (k2, v2) in zip( + comp.externals.items(), branch.comp(0).externals.items() + ) + ] + )