From a559cafafc9b8957f44d58dfc332ffef1e97343a Mon Sep 17 00:00:00 2001 From: jnsbck-uni Date: Tue, 15 Oct 2024 16:58:27 +0200 Subject: [PATCH] fix: add latest changes from main --- jaxley/modules/base.py | 1560 +++++++++++++++++++++------------ jaxley/modules/branch.py | 122 ++- jaxley/modules/cell.py | 78 +- jaxley/modules/compartment.py | 131 ++- jaxley/modules/network.py | 304 ++++--- 5 files changed, 1442 insertions(+), 753 deletions(-) diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 3073a038..c432db2c 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -24,15 +24,12 @@ from jaxley.utils.cell_utils import ( _compute_index_of_child, _compute_num_children, - build_branchpoint_group_inds, - compute_children_in_level, + compute_axial_conductances, compute_levels, - compute_morphology_indices_in_levels, - compute_parents_in_level, convert_point_process_to_distributed, interpolate_xyz, loc_of_index, - remap_to_consecutive, + query_channel_states_and_params, v_interp, ) from jaxley.utils.debug_solver import compute_morphology_indices @@ -61,30 +58,27 @@ def __init__(self): self.total_nbranches: int = 0 self.nbranches_per_cell: List[int] = None - self.groups = {} + self.group_nodes = {} self.nodes: Optional[pd.DataFrame] = None self.edges = pd.DataFrame( columns=[ - f"{scope}_{lvl}_index" - for lvl in [ - "pre_comp", - "pre_branch", - "pre_cell", - "post_comp", - "post_branch", - "post_cell", - ] - for scope in ["global", "local"] + "pre_locs", + "pre_branch_index", + "pre_cell_index", + "post_locs", + "post_branch_index", + "post_cell_index", + "type", + "type_ind", + "global_pre_comp_index", + "global_post_comp_index", + "global_pre_branch_index", + "global_post_branch_index", ] - + ["pre_locs", "post_locs", "type", "type_ind"] ) - # Attributes for viewing - self._scope = "local" # defaults to local scope - self._in_view = None - self.cumsum_nbranches: Optional[jnp.ndarray] = None self.comb_parents: jnp.ndarray = jnp.asarray([-1]) @@ -126,9 +120,6 @@ def __init__(self): # `self._init_morph_for_debugging` is run. self.debug_states = {} - # needs to run at the end of __init__ - self.base = self - def _update_nodes_with_xyz(self): """Add xyz coordinates of compartment centers to nodes. @@ -145,11 +136,7 @@ def _update_nodes_with_xyz(self): avoid overlapping branch_lens i.e. norm_cum_branch_len = [0,1,1,2] for only incrementing. """ - nsegs = ( - self.nodes.groupby("global_branch_index")["global_comp_index"] - .nunique() - .to_numpy() - ) + nsegs = self.nodes.groupby("branch_index")["comp_index"].nunique().to_numpy() comp_ends = np.hstack( [np.linspace(0, 1, nseg + 1) + 2 * i for i, nseg in enumerate(nsegs)] @@ -172,7 +159,8 @@ def _update_nodes_with_xyz(self): # this means centers between comps have to be removed here between_comp_inds = (cum_nsegs + np.arange(len(cum_nsegs)))[:-1] centers = np.delete(centers, between_comp_inds, axis=0) - self.base.nodes.loc[self._in_view, ["x", "y", "z"]] = centers + idcs = self.nodes["comp_index"] + self.nodes.loc[idcs, ["x", "y", "z"]] = centers return centers, xyz def __repr__(self): @@ -185,7 +173,14 @@ def __dir__(self): base_dir = object.__dir__(self) return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys())) - # TODO: update with new functionality? + @property + def _module_type(self): + """Return type of the module (compartment, branch, cell, network) as string. + + This is used to perform asserts for some modules (e.g. network cannot use + `set_ncomp`) without having to import the module in `base.py`.""" + return self.__class__.__name__.lower() + def _append_params_and_states(self, param_dict: Dict, state_dict: Dict): """Insert the default params of the module (e.g. radius, length). @@ -196,7 +191,6 @@ def _append_params_and_states(self, param_dict: Dict, state_dict: Dict): for state_name, state_value in state_dict.items(): self.nodes[state_name] = state_value - # TODO: update with new functionality? def _gather_channels_from_constituents(self, constituents: List): """Modify `self.channels` and `self.nodes` with channel info from constituents. @@ -215,8 +209,6 @@ def _gather_channels_from_constituents(self, constituents: List): name = channel._name self.nodes.loc[self.nodes[name].isna(), name] = False - # TODO: update with new functionality? - # TODO: verify that this works for view def to_jax(self): """Move `.nodes` to `.jaxnodes`. @@ -243,113 +235,6 @@ def to_jax(self): for key in synapse.synapse_states: self.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition]) - def _update_local_indices(self) -> pd.DataFrame: - idx_cols = ["global_comp_index", "global_branch_index", "global_cell_index"] - self.nodes.rename( - columns={col.replace("global_", ""): col for col in idx_cols}, inplace=True - ) - idcs = self.nodes[idx_cols] - - def reindex_a_by_b(df, a, b): - df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1 - return df - - idcs = reindex_a_by_b(idcs, idx_cols[1], idx_cols[2]) - idcs = reindex_a_by_b(idcs, idx_cols[0], idx_cols[1:]) - idcs.columns = [col.replace("global", "local") for col in idx_cols] - self.nodes[["local_comp_index", "local_branch_index", "local_cell_index"]] = ( - idcs[["local_comp_index", "local_branch_index", "local_cell_index"]] - ) - # TODO: place indices at the front of the dataframe - - def _reformat_index(self, idx): - idx = np.array([], dtype=int) if idx is None else idx - idx = np.array([idx]) if isinstance(idx, (int, np.int64)) else idx - idx = np.array(idx) if isinstance(idx, (list, range)) else idx - idx = np.arange(len(self._in_view) + 1)[idx] if isinstance(idx, slice) else idx - if isinstance(idx, str): - assert idx == "all", "Only 'all' is allowed" - idx = np.arange(len(self._in_view) + 1) - assert isinstance(idx, np.ndarray), "Invalid type" - assert idx.dtype == np.int64, "Invalid dtype" - return idx.reshape(-1) - - def _set_controlled_by_param(self, key): - if key in ["comp", "branch", "cell"]: - self.nodes["controlled_by_param"] = self.nodes[f"global_{key}_index"] - self.edges["controlled_by_param"] = self.edges[f"global_pre_{key}_index"] - else: - self.nodes["controlled_by_param"] = 0 - self.edges["controlled_by_param"] = 0 - - def at(self, idx, sorted=False): - idx = self._reformat_index(idx) - new_indices = self._in_view[idx] - new_indices = np.sort(new_indices) if sorted else new_indices - return View(self, at=new_indices) - - def set_scope(self, scope): - self._scope = scope - - def scope(self, scope): - view = self.view - view.set_scope(scope) - return view - - def _at_level(self, level: str, idx): - idx = self._reformat_index(idx) - where = self.nodes[self._scope + f"_{level}_index"].isin(idx) - inds = np.where(where)[0] - view = self.at(inds) - view._set_controlled_by_param(level) - return view - - def cell(self, idx): - return self._at_level("cell", idx) - - def branch(self, idx): - return self._at_level("branch", idx) - - def comp(self, idx): - return self._at_level("comp", idx) - - def loc(self, at: float): - comp_edges = np.linspace(0, 1, self.base.nseg + 1) - idx = np.digitize(at, comp_edges) - view = self.comp(idx) - return view - - def _iter_level(self, level): - col = self._scope + f"_{level}_index" - idxs = self.nodes[col].unique() - for idx in idxs: - yield self._at_level(level, idx) - - @property - def cells(self): - yield from self._iter_level("cell") - - @property - def branches(self): - yield from self._iter_level("branch") - - @property - def comps(self): - yield from self._iter_level("comp") - - def copy(self, reset_index=False, as_module=False): - view = deepcopy(self) - # TODO: add reset_index, i.e. for parents, nodes, edges etc. such that they - # start from 0/-1 and are contiguous - if as_module: - # TODO: initialize a new module with the same attributes - pass - return view - - @property - def view(self): - return View(self, self._in_view) - def show( self, param_names: Optional[Union[str, List[str]]] = None, # TODO. @@ -373,28 +258,41 @@ def show( Returns: A `pd.DataFrame` with the requested information. """ - nodes = self.nodes.copy() # prevents this from being edited - - cols = [] - inds = ["comp_index", "branch_index", "cell_index"] - scopes = ["local", "global"] - cols += ( - [f"{scope}_{idx}" for idx in inds for scope in scopes] if indices else [] - ) - cols += [ch._name for ch in self.channels] if channel_names else [] - cols += ( - sum([list(ch.channel_params) for ch in self.channels], []) if params else [] - ) - cols += ( - sum([list(ch.channel_states) for ch in self.channels], []) if states else [] + return self._show( + self.nodes, param_names, indices, params, states, channel_names ) - if not param_names is None: - cols = ( - [c for c in cols if c in param_names] if params else list(param_names) - ) + def _show( + self, + view: pd.DataFrame, + param_names: Optional[Union[str, List[str]]] = None, + indices: bool = True, + params: bool = True, + states: bool = True, + channel_names: Optional[List[str]] = None, + ): + """Print detailed information about the entire Module.""" + printable_nodes = deepcopy(view) - return nodes[cols] + for channel in self.channels: + name = channel._name + param_names = list(channel.channel_params.keys()) + state_names = list(channel.channel_states.keys()) + if channel_names is not None and name not in channel_names: + printable_nodes = printable_nodes.drop(name, axis=1) + printable_nodes = printable_nodes.drop(param_names, axis=1) + printable_nodes = printable_nodes.drop(state_names, axis=1) + else: + if not params: + printable_nodes = printable_nodes.drop(param_names, axis=1) + if not states: + printable_nodes = printable_nodes.drop(state_names, axis=1) + + if not indices: + for name in ["comp_index", "branch_index", "cell_index"]: + printable_nodes = printable_nodes.drop(name, axis=1) + + return printable_nodes def init_morph(self): """Initialize the morphology such that it can be processed by the solvers.""" @@ -412,29 +310,32 @@ def _init_morph_jaxley_spsolve(self): """Initialize the morphology for the custom Jaxley solver.""" raise NotImplementedError - def insert(self, channel): + def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]): + """Given radius, length, r_a, compute the axial coupling conductances.""" + return compute_axial_conductances(self._comp_edges, params) + + def _append_channel_to_nodes(self, view: pd.DataFrame, channel: "jx.Channel"): + """Adds channel nodes from constituents to `self.channel_nodes`.""" name = channel._name # Channel does not yet exist in the `jx.Module` at all. - if name not in [c._name for c in self.base.channels]: - self.base.channels.append(channel) - self.base.nodes[name] = ( - False # Previous columns do not have the new channel. - ) + if name not in [c._name for c in self.channels]: + self.channels.append(channel) + self.nodes[name] = False # Previous columns do not have the new channel. - if channel.current_name not in self.base.membrane_current_names: - self.base.membrane_current_names.append(channel.current_name) + if channel.current_name not in self.membrane_current_names: + self.membrane_current_names.append(channel.current_name) # Add a binary column that indicates if a channel is present. - self.base.nodes.loc[self._in_view, name] = True + self.nodes.loc[view.index.values, name] = True # Loop over all new parameters, e.g. gNa, eNa. for key in channel.channel_params: - self.base.nodes.loc[self._in_view, key] = channel.channel_params[key] + self.nodes.loc[view.index.values, key] = channel.channel_params[key] # Loop over all new parameters, e.g. gNa, eNa. for key in channel.channel_states: - self.base.nodes.loc[self._in_view, key] = channel.channel_states[key] + self.nodes.loc[view.index.values, key] = channel.channel_states[key] def set(self, key: str, val: Union[float, jnp.ndarray]): """Set parameter of module (or its view) to a new value. @@ -449,14 +350,27 @@ def set(self, key: str, val: Union[float, jnp.ndarray]): val: The value to set the parameter to. If it is `jnp.ndarray` then it must be of shape `(len(num_compartments))`. """ - if key in self.nodes.columns: - not_nan = ~self.nodes[key].isna() - self.base.nodes.loc[self._in_view[not_nan], key] = val - elif key in self.edges.columns: - not_nan = ~self.edges[key].isna() - self.base.edges.loc[self._edges_in_view[not_nan], key] = val + # TODO(@michaeldeistler) should we allow `.set()` for synaptic parameters + # without using the `SynapseView`, purely for consistency with `make_trainable`? + view = ( + self.edges + if key in self.synapse_param_names or key in self.synapse_state_names + else self.nodes + ) + self._set(key, val, view, view) + + def _set( + self, + key: str, + val: Union[float, jnp.ndarray], + view: pd.DataFrame, + table_to_update: pd.DataFrame, + ): + if key in view.columns: + view = view[~np.isnan(view[key])] + table_to_update.loc[view.index.values, key] = val else: - raise KeyError(f"Key '{key}' not found in nodes or edges") + raise KeyError("Key not recognized.") def data_set( self, @@ -473,12 +387,26 @@ def data_set( param_state: State of the setted parameters, internally used such that this function does not modify global state. """ + view = ( + self.edges + if key in self.synapse_param_names or key in self.synapse_state_names + else self.nodes + ) + return self._data_set(key, val, view, param_state=param_state) + + def _data_set( + self, + key: str, + val: Tuple[float, jnp.ndarray], + view: pd.DataFrame, + param_state: Optional[List[Dict]] = None, + ): # Note: `data_set` does not support arrays for `val`. - if key in self.nodes.columns: - not_nan = ~self.nodes[key].isna() + if key in view.columns: + view = view[~np.isnan(view[key])] added_param_state = [ { - "indices": np.atleast_2d(self._in_view[not_nan]), + "indices": np.atleast_2d(view.index.values), "key": key, "val": jnp.atleast_1d(jnp.asarray(val)), } @@ -640,30 +568,50 @@ def make_trainable( verbose: Whether to print the number of parameters that are added and the total number of parameters. """ + assert ( + key not in self.synapse_param_names and key not in self.synapse_state_names + ), "Parameters of synapses can only be made trainable via the `SynapseView`." + view = self.nodes + view = deepcopy(view.assign(controlled_by_param=0)) + self._make_trainable(view, key, init_val, verbose=verbose) + + def _make_trainable( + self, + view: pd.DataFrame, + key: str, + init_val: Optional[Union[float, list]] = None, + verbose: bool = True, + ): assert ( self.allow_make_trainable ), "network.cell('all').make_trainable() is not supported. Use a for-loop over cells." - data = self.nodes if key in self.nodes.columns else None - data = self.edges if key in self.edges.columns else data - assert data is not None, f"Key '{key}' not found in nodes or edges" - not_nan = ~data[key].isna() - data = data.loc[not_nan] - assert ( - len(data) > 0 - ), "No settable parameters found in the selected compartments." - - grouped_view = data.groupby("controlled_by_param") - # Because of this `x.index.values` we cannot support `make_trainable()` on - # the module level for synapse parameters (but only for `SynapseView`). - inds_of_comps = list( - grouped_view.apply(lambda x: x.index.values, include_groups=False) - ) + if key in view.columns: + view = view[~np.isnan(view[key])] + grouped_view = view.groupby("controlled_by_param") + num_elements_being_set = grouped_view.apply(len).to_numpy() + assert np.all(num_elements_being_set == num_elements_being_set[0]), ( + "You are using `make_trainable()` with parameter sharing (e.g. same " + "parameter for an entire cell, or same parameter for entire branches). " + "This error is caused because you are trying to share a parameter " + "across an inhomogenous number of compartments. To overcome this " + "error, write a for-loop across cells (or branches). For example, " + "change `net.cell('all').make_trainable('HH_gNa')` to " + "`for i in range(num_cells): net.cell(i).make_trainable('HH_gNa')`" + ) + # Because of this `x.index.values` we cannot support `make_trainable()` on + # the module level for synapse parameters (but only for `SynapseView`). + inds_of_comps = list(grouped_view.apply(lambda x: x.index.values)) + + # Sorted inds are only used to infer the correct starting values. + param_vals = jnp.asarray( + [view.loc[inds, key].to_numpy() for inds in inds_of_comps] + ) + else: + raise KeyError(f"Parameter {key} not recognized.") + indices_per_param = jnp.stack(inds_of_comps) - # Sorted inds are only used to infer the correct starting values. - param_vals = jnp.asarray( - [data.loc[inds, key].to_numpy() for inds in inds_of_comps] - ) + self.indices_set_by_trainables.append(indices_per_param) # Set the value which the trainable parameter should take. num_created_parameters = len(indices_per_param) @@ -681,19 +629,19 @@ def make_trainable( ) else: new_params = jnp.mean(param_vals, axis=1) - self.base.trainable_params.append({key: new_params}) - self.base.indices_set_by_trainables.append(indices_per_param) + + self.trainable_params.append({key: new_params}) + self.num_trainable_params += num_created_parameters if verbose: print( f"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.num_trainable_params}" ) - # TODO: Make this work on view def delete_trainables(self): """Removes all trainable parameters from the module.""" - self.base.indices_set_by_trainables = [] - self.base.trainable_params = [] - self.base.num_trainable_params = 0 + self.indices_set_by_trainables: List[jnp.ndarray] = [] + self.trainable_params: List[Dict[str, jnp.ndarray]] = [] + self.num_trainable_params: int = 0 def add_to_group(self, group_name: str): """Add a view of the module to a group. @@ -707,7 +655,12 @@ def add_to_group(self, group_name: str): Args: group_name: The name of the group. """ - self.base.groups[group_name] = self._in_view + raise ValueError("`add_to_group()` makes no sense for an entire module.") + + def _add_to_group(self, group_name: str, view: pd.DataFrame): + if group_name in self.group_nodes: + view = pd.concat([self.group_nodes[group_name], view]) + self.group_nodes[group_name] = view def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """Get all trainable parameters. @@ -720,9 +673,9 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: """ return self.trainable_params - # TODO: Verify that this works (also for view) - # TODO: update with new functionality? - def get_all_parameters(self, pstate: List[Dict]) -> Dict[str, jnp.ndarray]: + def get_all_parameters( + self, pstate: List[Dict], voltage_solver: str + ) -> Dict[str, jnp.ndarray]: """Return all parameters (and coupling conductances) needed to simulate. Runs `_compute_axial_conductances()` and return every parameter that is needed @@ -780,7 +733,18 @@ def get_all_parameters(self, pstate: List[Dict]) -> Dict[str, jnp.ndarray]: params["axial_conductances"] = self._compute_axial_conductances(params=params) return params - # TODO: Verify that this works + def get_states_from_nodes_and_edges(self): + """Return states as they are set in the `.nodes` and `.edges` tables.""" + self.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. + states = {"v": self.jaxnodes["v"]} + # Join node and edge states into a single state dictionary. + for channel in self.channels: + for channel_states in channel.channel_states: + states[channel_states] = self.jaxnodes[channel_states] + for synapse_states in self.synapse_state_names: + states[synapse_states] = self.jaxedges[synapse_states] + return states + def get_all_states( self, pstate: List[Dict], all_params, delta_t: float ) -> Dict[str, jnp.ndarray]: @@ -829,8 +793,7 @@ def initialize(self): self.init_morph() return self - # TODO: Verify that this works - def init_states(self): + def init_states(self, delta_t: float = 0.025): """Initialize all mechanisms in their steady state. This considers the voltages and parameters of each compartment. @@ -850,8 +813,10 @@ def init_states(self): for channel in self.channels: name = channel._name - indices = channel_nodes.loc[channel_nodes[name]].index.to_numpy() - voltages = channel_nodes.loc[indices, "v"].to_numpy() + channel_indices = channel_nodes.loc[channel_nodes[name]][ + "comp_index" + ].to_numpy() + voltages = channel_nodes.loc[channel_indices, "v"].to_numpy() channel_param_names = list(channel.channel_params.keys()) channel_state_names = list(channel.channel_states.keys()) @@ -869,76 +834,78 @@ def init_states(self): # `init_state` might not return all channel states. Only the ones that are # returned are updated here. for key, val in init_state.items(): - self.base.nodes.loc[indices, key] = val - - # TODO: - # def _init_morph_for_debugging(self): - # """Instandiates row and column inds which can be used to solve the voltage eqs. - - # This is important only for expert users who try to modify the solver for the - # voltage equations. By default, this function is never run. - - # This is useful for debugging the solver because one can use - # `scipy.linalg.sparse.spsolve` after every step of the solve. - - # Here is the code snippet that can be used for debugging then (to be inserted in - # `solver_voltage`): - # ```python - # from scipy.sparse import csc_matrix - # from scipy.sparse.linalg import spsolve - # from jaxley.utils.debug_solver import build_voltage_matrix_elements - - # elements, solve, num_entries, start_ind_for_branchpoints = ( - # build_voltage_matrix_elements( - # uppers, - # lowers, - # diags, - # solves, - # branchpoint_conds_children[debug_states["child_inds"]], - # branchpoint_conds_parents[debug_states["par_inds"]], - # branchpoint_weights_children[debug_states["child_inds"]], - # branchpoint_weights_parents[debug_states["par_inds"]], - # branchpoint_diags, - # branchpoint_solves, - # debug_states["nseg"], - # nbranches, - # ) - # ) - # sparse_matrix = csc_matrix( - # (elements, (debug_states["row_inds"], debug_states["col_inds"])), - # shape=(num_entries, num_entries), - # ) - # solution = spsolve(sparse_matrix, solve) - # solution = solution[:start_ind_for_branchpoints] # Delete branchpoint voltages. - # solves = jnp.reshape(solution, (debug_states["nseg"], nbranches)) - # return solves - # ``` - # """ - # # For scipy and jax.scipy. - # row_and_col_inds = compute_morphology_indices( - # len(self.par_inds), - # self.child_belongs_to_branchpoint, - # self.par_inds, - # self.child_inds, - # self.nseg, - # self.total_nbranches, - # ) - - # num_elements = len(row_and_col_inds["row_inds"]) - # data_inds, indices, indptr = convert_to_csc( - # num_elements=num_elements, - # row_ind=row_and_col_inds["row_inds"], - # col_ind=row_and_col_inds["col_inds"], - # ) - # self.debug_states["row_inds"] = row_and_col_inds["row_inds"] - # self.debug_states["col_inds"] = row_and_col_inds["col_inds"] - # self.debug_states["data_inds"] = data_inds - # self.debug_states["indices"] = indices - # self.debug_states["indptr"] = indptr - - # self.debug_states["nseg"] = self.nseg - # self.debug_states["child_inds"] = self.child_inds - # self.debug_states["par_inds"] = self.par_inds + # Note that we are overriding `self.nodes` here, but `self.nodes` is + # not used above to actually compute the current states (so there are + # no issues with overriding states). + self.nodes.loc[channel_indices, key] = val + + def _init_morph_for_debugging(self): + """Instandiates row and column inds which can be used to solve the voltage eqs. + + This is important only for expert users who try to modify the solver for the + voltage equations. By default, this function is never run. + + This is useful for debugging the solver because one can use + `scipy.linalg.sparse.spsolve` after every step of the solve. + + Here is the code snippet that can be used for debugging then (to be inserted in + `solver_voltage`): + ```python + from scipy.sparse import csc_matrix + from scipy.sparse.linalg import spsolve + from jaxley.utils.debug_solver import build_voltage_matrix_elements + + elements, solve, num_entries, start_ind_for_branchpoints = ( + build_voltage_matrix_elements( + uppers, + lowers, + diags, + solves, + branchpoint_conds_children[debug_states["child_inds"]], + branchpoint_conds_parents[debug_states["par_inds"]], + branchpoint_weights_children[debug_states["child_inds"]], + branchpoint_weights_parents[debug_states["par_inds"]], + branchpoint_diags, + branchpoint_solves, + debug_states["nseg"], + nbranches, + ) + ) + sparse_matrix = csc_matrix( + (elements, (debug_states["row_inds"], debug_states["col_inds"])), + shape=(num_entries, num_entries), + ) + solution = spsolve(sparse_matrix, solve) + solution = solution[:start_ind_for_branchpoints] # Delete branchpoint voltages. + solves = jnp.reshape(solution, (debug_states["nseg"], nbranches)) + return solves + ``` + """ + # For scipy and jax.scipy. + row_and_col_inds = compute_morphology_indices( + len(self.par_inds), + self.child_belongs_to_branchpoint, + self.par_inds, + self.child_inds, + self.nseg, + self.total_nbranches, + ) + + num_elements = len(row_and_col_inds["row_inds"]) + data_inds, indices, indptr = convert_to_csc( + num_elements=num_elements, + row_ind=row_and_col_inds["row_inds"], + col_ind=row_and_col_inds["col_inds"], + ) + self.debug_states["row_inds"] = row_and_col_inds["row_inds"] + self.debug_states["col_inds"] = row_and_col_inds["col_inds"] + self.debug_states["data_inds"] = data_inds + self.debug_states["indices"] = indices + self.debug_states["indptr"] = indptr + + self.debug_states["nseg"] = self.nseg + self.debug_states["child_inds"] = self.child_inds + self.debug_states["par_inds"] = self.par_inds def record(self, state: str = "v", verbose: bool = True): """Insert a recording into the compartment. @@ -946,20 +913,20 @@ def record(self, state: str = "v", verbose: bool = True): Args: state: The name of the state to record. verbose: Whether to print number of inserted recordings.""" - new_recs = pd.DataFrame(self._in_view, columns=["rec_index"]) - new_recs["state"] = state - self.base.recordings = pd.concat([self.base.recordings, new_recs]) - has_duplicates = self.base.recordings.duplicated() - self.base.recordings = self.base.recordings.loc[~has_duplicates] + view = deepcopy(self.nodes) + view["state"] = state + recording_view = view[["comp_index", "state"]] + recording_view = recording_view.rename(columns={"comp_index": "rec_index"}) + self._record(recording_view, verbose=verbose) + + def _record(self, view: pd.DataFrame, verbose: bool = True): + self.recordings = pd.concat([self.recordings, view], ignore_index=True) if verbose: - print( - f"Added {len(self._in_view)-sum(has_duplicates)} recordings. See `.recordings` for details." - ) + print(f"Added {len(view)} recordings. See `.recordings` for details.") - # TODO: Make this work on view def delete_recordings(self): """Removes all recordings from the module.""" - self.base.recordings = pd.DataFrame().from_dict({}) + self.recordings = pd.DataFrame().from_dict({}) def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True): """Insert a stimulus into the compartment. @@ -975,7 +942,7 @@ def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True) Args: current: Current in `nA`. """ - self._external_input("i", current, verbose=verbose) + self._external_input("i", current, self.nodes, verbose=verbose) def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True): """Clamp a state to a given value across specified compartments. @@ -995,35 +962,26 @@ def _external_input( self, key: str, values: Optional[jnp.ndarray], + view: pd.DataFrame, verbose: bool = True, ): values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0) batch_size = values.shape[0] - num_inserted = len(self._in_view) - is_multiple = num_inserted == batch_size - values = ( - values if is_multiple else jnp.repeat(values, len(self._in_view), axis=0) - ) - assert batch_size in [ - 1, - num_inserted, - ], "Number of comps and stimuli do not match." - - if key in self.base.externals.keys(): - self.base.externals[key] = jnp.concatenate( - [self.base.externals[key], values] - ) - self.base.external_inds[key] = jnp.concatenate( - [self.base.external_inds[key], self._in_view] + is_multiple = len(view) == batch_size + values = values if is_multiple else jnp.repeat(values, len(view), axis=0) + assert batch_size in [1, len(view)], "Number of comps and stimuli do not match." + + if key in self.externals.keys(): + self.externals[key] = jnp.concatenate([self.externals[key], values]) + self.external_inds[key] = jnp.concatenate( + [self.external_inds[key], view.comp_index.to_numpy()] ) else: - self.base.externals[key] = values - self.base.external_inds[key] = self._in_view + self.externals[key] = values + self.external_inds[key] = view.comp_index.to_numpy() if verbose: - print( - f"Added {num_inserted} external_states. See `.externals` for details." - ) + print(f"Added {len(view)} external_states. See `.externals` for details.") def data_stimulate( self, @@ -1038,15 +996,50 @@ def data_stimulate( verbose: Whether or not to print the number of inserted stimuli. `False` by default because this method is meant to be jitted. """ - current = current if current.ndim == 2 else jnp.expand_dims(current, axis=0) - batch_size = current.shape[0] - num_inserted = len(self._in_view) - is_multiple = num_inserted == batch_size - current = current if is_multiple else jnp.repeat(current, num_inserted, axis=0) - assert batch_size in [ - 1, - num_inserted, - ], "Number of comps and stimuli do not match." + return self._data_external_input( + "i", current, data_stimuli, self.nodes, verbose=verbose + ) + + def data_clamp( + self, + state_name: str, + state_array: jnp.ndarray, + data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None, + verbose: bool = False, + ): + """Insert a clamp into the module within jit (or grad). + + Args: + state_name: Name of the state variable to set. + state_array: Time series of the state variable in the default Jaxley unit. + State array should be of shape (num_clamps, simulation_time) or + (simulation_time, ) for a single clamp. + verbose: Whether or not to print the number of inserted clamps. `False` + by default because this method is meant to be jitted. + """ + return self._data_external_input( + state_name, state_array, data_clamps, self.nodes, verbose=verbose + ) + + def _data_external_input( + self, + state_name: str, + state_array: jnp.ndarray, + data_external_input: Optional[Tuple[jnp.ndarray, pd.DataFrame]], + view: pd.DataFrame, + verbose: bool = False, + ): + state_array = ( + state_array + if state_array.ndim == 2 + else jnp.expand_dims(state_array, axis=0) + ) + batch_size = state_array.shape[0] + is_multiple = len(view) == batch_size + state_array = ( + state_array if is_multiple else jnp.repeat(state_array, len(view), axis=0) + ) + assert batch_size in [1, len(view)], "Number of comps and clamps do not match." if data_external_input is not None: external_input = data_external_input[1] @@ -1056,29 +1049,38 @@ def data_stimulate( external_input = state_array inds = pd.DataFrame().from_dict({}) - # Same as in `.stimulate()`. - if currents is not None: - currents = jnp.concatenate([currents, current]) - else: - currents = current - inds = pd.concat([inds, self._in_view]) + inds = pd.concat([inds, view]) if verbose: - print(f"Added {num_inserted} stimuli.") + if state_name == "i": + print(f"Added {len(view)} stimuli.") + else: + print(f"Added {len(view)} clamps.") return (state_name, external_input, inds) - # TODO: Make this work on view def delete_stimuli(self): """Removes all stimuli from the module.""" - self.base.externals.pop("i", None) - self.base.external_inds.pop("i", None) + self.externals.pop("i", None) + self.external_inds.pop("i", None) - def init_syns(self): - self.base.initialized_syns = True + def delete_clamps(self, state_name: str): + """Removes all clamps of the given state from the module.""" + self.externals.pop(state_name, None) + self.external_inds.pop(state_name, None) - def init_morph(self): - self.base.initialized_morph = True + def insert(self, channel: Channel): + """Insert a channel into the module. + + Args: + channel: The channel to insert.""" + self._insert(channel, self.nodes) + + def _insert(self, channel, view): + self._append_channel_to_nodes(view, channel) + + def init_syns(self): + self.initialized_syns = True def step( self, @@ -1248,7 +1250,7 @@ def _step_channels_state( voltages = states["v"] # Update states of the channels. - indices = channel_nodes["global_comp_index"].to_numpy() + indices = channel_nodes["comp_index"].to_numpy() for channel in channels: channel_param_names = list(channel.channel_params) channel_param_names += [ @@ -1307,9 +1309,7 @@ def _channel_currents( name = channel._name channel_param_names = list(channel.channel_params.keys()) channel_state_names = list(channel.channel_states.keys()) - indices = channel_nodes.loc[channel_nodes[name]][ - "global_comp_index" - ].to_numpy() + indices = channel_nodes.loc[channel_nodes[name]]["comp_index"].to_numpy() channel_params = {} for p in channel_param_names: @@ -1430,7 +1430,35 @@ def vis( type: The type of plot. One of ["line", "scatter", "comp", "morph"]. morph_plot_kwargs: Keyword arguments passed to the plotting function. """ - branches_inds = self.nodes["branch_index"].to_numpy() + return self._vis( + dims=dims, + col=col, + ax=ax, + view=self.nodes, + type=type, + morph_plot_kwargs=morph_plot_kwargs, + ) + + def _vis( + self, + ax: Axes, + col: str, + dims: Tuple[int], + view: pd.DataFrame, + type: str, + morph_plot_kwargs: Dict, + ) -> Axes: + branches_inds = view["branch_index"].to_numpy() + + if "comp" in type.lower(): + return plot_comps( + self, view, dims=dims, ax=ax, col=col, **morph_plot_kwargs + ) + if "morph" in type.lower(): + return plot_morph( + self, view, dims=dims, ax=ax, col=col, **morph_plot_kwargs + ) + coords = [] for branch_ind in branches_inds: assert not np.any( @@ -1449,6 +1477,34 @@ def vis( return ax + def _scatter(self, ax, col, dims, view, morph_plot_kwargs): + """Scatter visualization (used only for compartments).""" + assert len(view) == 1, "Scatter only deals with compartments." + branch_ind = view["branch_index"].to_numpy().item() + comp_ind = view["comp_index"].to_numpy().item() + assert not np.any( + np.isnan(self.xyzr[branch_ind][:, dims]) + ), "No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`." + + comp_fraction = loc_of_index( + comp_ind, + branch_ind, + self.nseg_per_branch, + ) + coords = self.xyzr[branch_ind] + interpolated_xyz = interpolate_xyz(comp_fraction, coords) + + ax = plot_graph( + np.asarray([[interpolated_xyz]]), + dims=dims, + col=col, + ax=ax, + type="scatter", + morph_plot_kwargs=morph_plot_kwargs, + ) + + return ax + def compute_xyz(self): """Return xyz coordinates of every branch, based on the branch length. @@ -1466,9 +1522,7 @@ def compute_xyz(self): levels = compute_levels(parents) # Extract branch. - inds_branch = self.nodes.groupby("global_branch_index")[ - "global_comp_index" - ].apply(list) + inds_branch = self.nodes.groupby("branch_index")["comp_index"].apply(list) branch_lens = [np.sum(self.nodes["length"][np.asarray(i)]) for i in inds_branch] endpoints = [] @@ -1526,9 +1580,16 @@ def move( `False` largely speeds up moving, especially for big networks, but `.nodes` or `.show` will not show the new xyz coordinates. """ - indizes = self.nodes["global_branch_index"].unique() + self._move(x, y, z, self.nodes, update_nodes) + + def _move(self, x: float, y: float, z: float, view, update_nodes: bool): + # Need to cast to set because this will return one columnn per compartment, + # not one column per branch. + indizes = set(view["branch_index"].to_numpy().tolist()) for i in indizes: - self.base.xyzr[i][:, :3] += np.array([x, x, y]) + self.xyzr[i][:, 0] += x + self.xyzr[i][:, 1] += y + self.xyzr[i][:, 2] += z if update_nodes: self._update_nodes_with_xyz() @@ -1555,25 +1616,63 @@ def move_to( `False` largely speeds up moving, especially for big networks, but `.nodes` or `.show` will not show the new xyz coordinates. """ + self._move_to(x, y, z, self.nodes, update_nodes) + + def _move_to( + self, + x: Union[float, np.ndarray], + y: Union[float, np.ndarray], + z: Union[float, np.ndarray], + view: pd.DataFrame, + update_nodes: bool, + ): # Test if any coordinate values are NaN which would greatly affect moving if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan): raise ValueError( "NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values." ) - indizes = self.nodes["global_branch_index"].unique() - move_by = ( - np.array([x, y, z]).T - self.xyzr[0][0, :3] - ) # move with respect to root idx + # Get the indices of the cells and branches to move + cell_inds = list(view.cell_index.unique()) + branch_inds = view.branch_index.unique() + + if ( + isinstance(x, np.ndarray) + and isinstance(y, np.ndarray) + and isinstance(z, np.ndarray) + ): + assert ( + x.shape == y.shape == z.shape == (len(cell_inds),) + ), "x, y, and z array shapes are not all equal to the number of cells to be moved." + + # Split the branches by cell id + tup_indices = np.array([view.cell_index, view.branch_index]) + view_cell_branch_inds = np.unique(tup_indices, axis=1)[0] + _, branch_split_inds = np.unique(view_cell_branch_inds, return_index=True) + branches_by_cell = np.split( + view.branch_index.unique(), branch_split_inds[1:] + ) + + # Calculate the amount to shift all of the branches of each cell + shift_amounts = ( + np.array([x, y, z]).T - np.stack(self[cell_inds, 0].xyzr)[:, 0, :3] + ) + + else: + # Treat as if all branches belong to the same cell to be moved + branches_by_cell = [branch_inds] + # Calculate the amount to shift all branches by the 1st branch of 1st cell + shift_amounts = [np.array([x, y, z]) - self[cell_inds].xyzr[0][0, :3]] + + # Move all of the branches + for i, branches in enumerate(branches_by_cell): + for b in branches: + self.xyzr[b][:, :3] += shift_amounts[i] - for idx in indizes: - self.base.xyzr[idx][:, :3] += move_by if update_nodes: self._update_nodes_with_xyz() - def rotate( - self, degrees: float, rotation_axis: str = "xy", update_nodes: bool = True - ): + def rotate(self, degrees: float, rotation_axis: str = "xy"): """Rotate jaxley modules clockwise. Used only for visualization. This function is used only for visualization. It does not affect the simulation. @@ -1582,6 +1681,9 @@ def rotate( degrees: How many degrees to rotate the module by. rotation_axis: Either of {`xy` | `xz` | `yz`}. """ + self._rotate(degrees=degrees, rotation_axis=rotation_axis, view=self.nodes) + + def _rotate(self, degrees: float, rotation_axis: str, view: pd.DataFrame): degrees = degrees / 180 * np.pi if rotation_axis == "xy": dims = [0, 1] @@ -1595,12 +1697,10 @@ def rotate( rotation_matrix = np.asarray( [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]] ) - indizes = self.nodes["global_branch_index"].unique() + indizes = set(view["branch_index"].to_numpy().tolist()) for i in indizes: - rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T - self.base.xyzr[i][:, dims] = rot - if update_nodes: - self._update_nodes_with_xyz() + rot = np.dot(rotation_matrix, self.xyzr[i][:, dims].T).T + self.xyzr[i][:, dims] = rot @property def shape(self) -> Tuple[int]: @@ -1611,246 +1711,564 @@ def shape(self) -> Tuple[int]: cell.shape = (num_branches, num_compartments) branch.shape = (num_compartments,) ```""" - cols = ["global_cell_index", "global_branch_index", "global_comp_index"] - raw_shape = self.nodes[cols].nunique().to_list() - - # ensure (net.shape -> dim=3, cell.shape -> dim=2, branch.shape -> dim=1, comp.shape -> dim=0) - levels = ["network", "cell", "branch", "comp"] - module = self.base.__class__.__name__.lower() - module = "comp" if module == "compartment" else module - shape = tuple(raw_shape[levels.index(module) :]) - return shape + mod_name = self.__class__.__name__.lower() + if "comp" in mod_name: + return (1,) + elif "branch" in mod_name: + return self[:].shape[1:] + return self[:].shape def __getitem__(self, index): - levels = ["network", "cell", "branch", "comp"] - module = self.base.__class__.__name__.lower() # - module = "comp" if module == "compartment" else module - - children = levels[levels.index(module) + 1 :] - index = index if isinstance(index, tuple) else (index,) - view = self - for i, child in enumerate(children): - view = view._at_level(child, index[i]) - return view + return self._getitem(self, index) + + def _getitem( + self, + module: Union["Module", "View"], + index: Union[Tuple, int], + child_name: Optional[str] = None, + ) -> "View": + """Return View which is created from indexing the module. + + Args: + module: The module to be indexed. Will be a `Module` if `._getitem` is + called from `__getitem__` in a `Module` and will be a `View` if it was + called from `__getitem__` in a `View`. + index: The index (or indices) to index the module. + child_name: If passed, this will be the key that is used to index the + `module`, e.g. if it is the string `branch` then we will try to call + `module.xyz(index)`. If `None` then we try to infer automatically what + the childview should be, given the name of the `module`. + + Returns: + An indexed `View`. + """ + if isinstance(index, tuple): + if len(index) > 1: + return childview(module, index[0], child_name)[index[1:]] + return childview(module, index[0], child_name) + return childview(module, index, child_name) def __iter__(self): for i in range(self.shape[0]): yield self[i] -class View(Module): - def __init__(self, pointer, at=None): - # attrs with a static view - self._scope = pointer._scope - self.base = pointer.base - self.initialized_morph = pointer.initialized_morph - self.initialized_syns = pointer.initialized_syns - self.allow_make_trainable = pointer.allow_make_trainable - - # attrs affected by view - self.nseg = pointer.nseg - self._in_view = pointer._in_view if at is None else at - - self.nodes = pointer.nodes.loc[self._in_view] - self.edges = pointer.edges.loc[self._edges_in_view] - self.nseg = 1 if len(self.nodes) == 1 else pointer.nseg - self.total_nbranches = len(self._branches_in_view) - self.nbranches_per_cell = self._nbranches_per_cell_in_view() - self.cumsum_nbranches = np.cumsum(self.nbranches_per_cell) - self.comb_branches_in_each_level = pointer.comb_branches_in_each_level - self.branch_edges = pointer.branch_edges.loc[self._branch_edges_in_view] - - self.synapse_names = np.unique(self.edges["type"]).tolist() - self.synapses, self.synapse_param_names, self.synapse_state_names = ( - self._synapses_in_view(pointer) +class View: + """View of a `Module`.""" + + def __init__(self, pointer: Module, view: pd.DataFrame): + self.pointer = pointer + self.view = view + self.allow_make_trainable = True + + def __repr__(self): + return f"{type(self).__name__}. Use `.show()` for details." + + def __str__(self): + return f"{type(self).__name__}" + + def show( + self, + param_names: Optional[Union[str, List[str]]] = None, # TODO. + *, + indices: bool = True, + params: bool = True, + states: bool = True, + channel_names: Optional[List[str]] = None, + ) -> pd.DataFrame: + """Print detailed information about the Module or a view of it. + + Args: + param_names: The names of the parameters to show. If `None`, all parameters + are shown. NOT YET IMPLEMENTED. + indices: Whether to show the indices of the compartments. + params: Whether to show the parameters of the compartments. + states: Whether to show the states of the compartments. + channel_names: The names of the channels to show. If `None`, all channels are + shown. + + Returns: + A `pd.DataFrame` with the requested information. + """ + view = self.pointer._show( + self.view, param_names, indices, params, states, channel_names ) + if not indices: + for name in [ + "global_comp_index", + "global_branch_index", + "global_cell_index", + "controlled_by_param", + ]: + if name in view.columns: + view = view.drop(name, axis=1) + return view - if pointer.recordings.empty: - self.recordings = pd.DataFrame() - else: - self.recordings = pointer.recordings.loc[ - pointer.recordings["rec_index"].isin(self._comps_in_view) - ] + def set_global_index_and_index(self, nodes: pd.DataFrame) -> pd.DataFrame: + """Use the global compartment, branch, and cell index as the index.""" + nodes = nodes.drop("controlled_by_param", axis=1) + nodes = nodes.drop("comp_index", axis=1) + nodes = nodes.drop("branch_index", axis=1) + nodes = nodes.drop("cell_index", axis=1) + nodes = nodes.rename( + columns={ + "global_comp_index": "comp_index", + "global_branch_index": "branch_index", + "global_cell_index": "cell_index", + } + ) + return nodes + + def insert(self, channel: Channel): + """Insert a channel into the module at the currently viewed location(s). + + Args: + channel: The channel to insert. + """ + assert not inspect.isclass( + channel + ), """ + Channel is a class, but it was not initialized. Use `.insert(Channel())` + instead of `.insert(Channel)`. + """ + nodes = self.set_global_index_and_index(self.view) + self.pointer._insert(channel, nodes) + + def record(self, state: str = "v", verbose: bool = True): + """Record a state variable of the compartment(s) at the currently view location(s). + + Args: + state: The name of the state to record. + verbose: Whether to print number of inserted recordings.""" + nodes = self.set_global_index_and_index(self.view) + view = deepcopy(nodes) + view["state"] = state + recording_view = view[["comp_index", "state"]] + recording_view = recording_view.rename(columns={"comp_index": "rec_index"}) + self.pointer._record(recording_view, verbose=verbose) + + def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True): + nodes = self.set_global_index_and_index(self.view) + self.pointer._external_input("i", current, nodes, verbose=verbose) - self.xyzr = self._xyzr_in_view(pointer) - self.channels = self._channels_in_view(pointer) - self.membrane_current_names = [c._name for c in self.channels] + def data_stimulate( + self, + current: jnp.ndarray, + data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]], + verbose: bool = False, + ): + """Insert a stimulus into the module within jit (or grad). - self.indices_set_by_trainables, self.trainable_params = ( - self._trainables_in_view() + Args: + current: Current in `nA`. + verbose: Whether or not to print the number of inserted stimuli. `False` + by default because this method is meant to be jitted. + """ + nodes = self.set_global_index_and_index(self.view) + return self.pointer._data_external_input( + "i", current, data_stimuli, nodes, verbose=verbose ) - self.num_trainable_params = np.sum( - [len(inds) for inds in self.indices_set_by_trainables] + + def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True): + """Clamp a state to a given value across specified compartments. + + Args: + state_name: The name of the state to clamp. + state_array: Array of values to clamp the state to. + verbose: If True, prints details about the clamping. + + This function sets external states for the compartments. + """ + nodes = self.set_global_index_and_index(self.view) + self.pointer._external_input(state_name, state_array, nodes, verbose=verbose) + + def data_clamp( + self, + state_name: str, + state_array: jnp.ndarray, + data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]], + verbose: bool = False, + ): + """Insert a clamp into the module within jit (or grad).""" + nodes = self.set_global_index_and_index(self.view) + return self.pointer._data_external_input( + state_name, state_array, data_clamps, nodes, verbose=verbose ) - self.comb_parents = self.base.comb_parents[self._branches_in_view] - self.externals, self.external_inds = self._externals_in_view() - self.groups = { - k: np.intersect1d(v, self._in_view) for k, v in pointer.groups.items() - } + def set(self, key: str, val: float): + """Set parameters of the pointer.""" + self.pointer._set(key, val, self.view, self.pointer.nodes) - self.children_in_level, self.parents_in_level = self._levels_in_view() + def data_set( + self, + key: str, + val: Union[float, jnp.ndarray], + param_state: Optional[List[Dict]] = None, + ): + """Set parameter of module (or its view) to a new value within `jit`.""" + return self.pointer._data_set(key, val, self.view, param_state) - # TODO: - # self.debug_states + def make_trainable( + self, + key: str, + init_val: Optional[Union[float, list]] = None, + verbose: bool = True, + ): + """Make a parameter trainable.""" + self.pointer._make_trainable(self.view, key, init_val, verbose=verbose) - if len(self.nodes) == 0: - raise ValueError("Nothing in view. Check your indices.") + def add_to_group(self, group_name: str): + self.pointer._add_to_group(group_name, self.view) - def _externals_in_view(self): - externals_in_view = {} - external_inds_in_view = [] - for (name, inds), data in zip( - self.base.external_inds.items(), self.base.externals.values() - ): - in_view = np.isin(inds, self._in_view) - inds_in_view = inds[in_view] - if len(inds_in_view) > 0: - externals_in_view[name] = data[in_view] - external_inds_in_view.append(inds_in_view) - return externals_in_view, external_inds_in_view - - def _trainables_in_view(self): - trainable_inds = self.base.indices_set_by_trainables - trainable_inds = ( - np.unique(np.hstack([inds.reshape(-1) for inds in trainable_inds])) - if len(trainable_inds) > 0 - else [] + def vis( + self, + ax: Optional[Axes] = None, + col: str = "k", + dims: Tuple[int] = (0, 1), + type: str = "line", + morph_plot_kwargs: Dict = {}, + ) -> Axes: + """Visualize the module. + + Modules can be visualized on one of the cardinal planes (xy, xz, yz) or + even in 3D. + + Several options are available: + - `line`: All points from the traced morphology (`xyzr`), are connected + with a line plot. + - `scatter`: All traced points, are plotted as scatter points. + - `comp`: Plots the compartmentalized morphology, including radius + and shape. (shows the true compartment lengths per default, but this can + be changed via the `morph_plot_kwargs`, for details see + `jaxley.utils.plot_utils.plot_comps`). + - `morph`: Reconstructs the 3D shape of the traced morphology. For details see + `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies + with many traced points this can be very slow. + + Args: + ax: An axis into which to plot. + col: The color for all branches. + type: The type of plot. One of ["line", "scatter", "comp", "morph"]. + dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of + two of them. + morph_plot_kwargs: Keyword arguments passed to the plotting function. + """ + nodes = self.set_global_index_and_index(self.view) + return self.pointer._vis( + ax=ax, + col=col, + dims=dims, + view=nodes, + type=type, + morph_plot_kwargs=morph_plot_kwargs, ) - trainable_inds_in_view = np.intersect1d(trainable_inds, self._in_view) - índices_set_by_trainables_in_view = [] - trainable_params_in_view = [] - for inds, params in zip( - self.base.indices_set_by_trainables, self.base.trainable_params - ): - in_view = np.isin(inds, trainable_inds_in_view) + def move( + self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = True + ): + """Move cells or networks by adding to their (x, y, z) coordinates. - completely_in_view = in_view.all(axis=1) - índices_set_by_trainables_in_view.append(inds[completely_in_view]) - trainable_params_in_view.append( - {k: v[completely_in_view] for k, v in params.items()} - ) + This function is used only for visualization. It does not affect the simulation. - partially_in_view = in_view.any(axis=1) & ~completely_in_view - índices_set_by_trainables_in_view.append( - inds[partially_in_view][in_view[partially_in_view]] - ) - trainable_params_in_view.append( - {k: v[partially_in_view] for k, v in params.items()} - ) + Args: + x: The amount to move in the x direction in um. + y: The amount to move in the y direction in um. + z: The amount to move in the z direction in um. + """ + nodes = self.set_global_index_and_index(self.view) + self.pointer._move(x, y, z, nodes, update_nodes=update_nodes) - índices_set_by_trainables_in_view = [ - inds for inds in índices_set_by_trainables_in_view if len(inds) > 0 - ] - trainable_params_in_view = [ - p for p in trainable_params_in_view if len(next(iter(p.values()))) > 0 - ] - return índices_set_by_trainables_in_view, trainable_params_in_view - - def _channels_in_view(self, pointer): - names = [name._name for name in pointer.channels] - channel_in_view = self.nodes[names].any(axis=0) - channel_in_view = channel_in_view[channel_in_view].index - return [c for c in pointer.channels if c._name in channel_in_view] - - def _synapses_in_view(self, pointer): - viewed_synapses = [] - viewed_params = [] - viewed_states = [] - if not pointer.synapses is None: - for syn in pointer.synapses: - if syn is not None: # needed for recurive viewing - in_view = syn._name in self.synapse_names - viewed_synapses += ( - [syn] if in_view else [None] - ) # padded with None to keep indices consistent - viewed_params += list(syn.synapse_params.keys()) if in_view else [] - viewed_states += list(syn.synapse_states.keys()) if in_view else [] - - return viewed_synapses, viewed_params, viewed_states - - def _nbranches_per_cell_in_view(self): - cell_nodes = self.nodes.groupby("global_cell_index") - return cell_nodes["global_branch_index"].nunique().to_numpy() - - def _xyzr_in_view(self, pointer): - viewed_branch_inds = self._branches_in_view - if hasattr(pointer, "_branches_in_view"): - prev_branch_inds = pointer._branches_in_view - else: - prev_branch_inds = pointer.nodes["global_branch_index"].unique() - if prev_branch_inds is None: - xyzr = pointer.xyzr.copy() # copy to prevent editing original + def move_to( + self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = True + ): + """Move cells or networks to a location (x, y, z). + + If x, y, and z are floats, then the first compartment of the first branch + of the first cell is moved to that float coordinate, and everything else is + shifted by the difference between that compartment's previous coordinate and + the new float location. + + If x, y, and z are arrays, then they must each have a length equal to the number + of cells being moved. Then the first compartment of the first branch of each + cell is moved to the specified location. + """ + # Ensuring here that the branch indices in the view passed are global + nodes = self.set_global_index_and_index(self.view) + self.pointer._move_to(x, y, z, nodes, update_nodes=update_nodes) + + def adjust_view( + self, key: str, index: Union[int, str, list, range, slice] + ) -> "View": + """Update view. + + Select a subset, range, slice etc. of the self.view based on the index key, + i.e. (cell_index, [1,2]). returns a view of all compartments of cell 1 and 2. + + Args: + key: The key to adjust the view by. + index: The index to adjust the view by. + + Returns: + A new view. + """ + if isinstance(index, int) or isinstance(index, np.int64): + self.view = self.view[self.view[key] == index] + elif isinstance(index, list) or isinstance(index, range): + self.view = self.view[self.view[key].isin(index)] + elif isinstance(index, slice): + index = list(range(self.view[key].max() + 1))[index] + return self.adjust_view(key, index) else: - branches2keep = np.isin(prev_branch_inds, viewed_branch_inds) - branch_inds2keep = np.where(branches2keep)[0] - xyzr = [pointer.xyzr[i] for i in branch_inds2keep].copy() - - # Currently viewing with `.loc` will show the closest compartment - # rather than the actual loc along the branch! - viewed_nseg_for_branch = self.nodes.groupby("global_branch_index").size() - incomplete_inds = np.where(viewed_nseg_for_branch != self.base.nseg)[0] - incomplete_branch_inds = viewed_branch_inds[incomplete_inds] - - cond = self.nodes["global_branch_index"].isin(incomplete_branch_inds) - interp_inds = self.nodes.loc[cond] - local_inds_per_branch = interp_inds.groupby("global_branch_index")["local_comp_index"] - locs = [loc_of_index(inds.to_numpy(), self.base.nseg) for _, inds in local_inds_per_branch] - - for i, loc in zip(incomplete_inds, locs): - xyzr[i] = interpolate_xyz(loc, xyzr[i]).T - return xyzr - - #TODO: Implement this! - def _levels_in_view(self): - children_in_level = [] - parents_in_level = [] - return children_in_level, parents_in_level + assert index == "all" + self.view["controlled_by_param"] -= self.view["controlled_by_param"].iloc[0] + return self - @property - def _nodes_in_view(self): - return self._in_view + def _get_local_indices(self) -> pd.DataFrame: + """Computes local from global indices. + + #cell_index, branch_index, comp_index + 0, 0, 0 --> 0, 0, 0 # 1st compartment of 1st branch of 1st cell + 0, 0, 1 --> 0, 0, 1 # 2nd compartment of 1st branch of 1st cell + 0, 1, 2 --> 0, 1, 0 # 1st compartment of 2nd branch of 1st cell + 0, 1, 3 --> 0, 1, 1 # 2nd compartment of 2nd branch of 1st cell + 1, 2, 4 --> 1, 0, 0 # 1st compartment of 1st branch of 2nd cell + 1, 2, 5 --> 1, 0, 1 # 2nd compartment of 1st branch of 2nd cell + 1, 3, 6 --> 1, 1, 0 # 1st compartment of 2nd branch of 2nd cell + 1, 3, 7 --> 1, 1, 1 # 2nd compartment of 2nd branch of 2nd cell + """ - @property - def _branches_in_view(self): - return self.nodes["global_branch_index"].unique() + def reindex_a_by_b(df, a, b): + df.loc[:, a] = df.groupby(b)[a].rank(method="dense").astype(int) - 1 + return df - @property - def _cells_in_view(self): - return self.nodes["global_cell_index"].unique() + idcs = self.view[["cell_index", "branch_index", "comp_index"]] + idcs = reindex_a_by_b(idcs, "branch_index", "cell_index") + idcs = reindex_a_by_b(idcs, "comp_index", ["cell_index", "branch_index"]) + return idcs - @property - def _comps_in_view(self): - return self.nodes["global_comp_index"].unique() + def __getitem__(self, index): + return self.pointer._getitem(self, index) + + def __iter__(self): + for i in range(self.shape[0]): + yield self[i] + + def rotate(self, degrees: float, rotation_axis: str = "xy"): + """Rotate jaxley modules clockwise. Used only for visualization. + + Args: + degrees: How many degrees to rotate the module by. + rotation_axis: Either of {`xy` | `xz` | `yz`}. + """ + raise NotImplementedError( + "Only entire `jx.Module`s or entire cells within a network can be rotated." + ) @property - def _branch_edges_in_view(self): - incl_branches = self.nodes["global_branch_index"].unique() - pre = self.base.branch_edges["parent_branch_index"].isin(incl_branches) - post = self.base.branch_edges["child_branch_index"].isin(incl_branches) - viewed_branch_inds = self.base.branch_edges.index.to_numpy()[pre & post] - return viewed_branch_inds + def shape(self) -> Tuple[int]: + """Returns the number of elements currently in view. + + ``` + network.shape = (num_cells, num_branches, num_compartments) + cell.shape = (num_branches, num_compartments) + branch.shape = (num_compartments,) + ```""" + local_idcs = self._get_local_indices() + return tuple(local_idcs.nunique()) @property - def _edges_in_view(self): - incl_comps = self.nodes["global_comp_index"].unique() - pre = self.base.edges["global_pre_comp_index"].isin(incl_comps).to_numpy() - post = self.base.edges["global_post_comp_index"].isin(incl_comps).to_numpy() - viewed_edge_inds = self.base.edges.index.to_numpy()[(pre & post).flatten()] - return viewed_edge_inds - - # point abstract methods to base - def init_conds(self, params: Dict): - return self.base.init_conds(params) - - def __enter__(self): - return self + def xyzr(self) -> List[np.ndarray]: + """Returns the xyzr entries of a branch, cell, or network. - def __exit__(self, exc_type, exc_value, exc_traceback): - pass + If called on a compartment or location, it will return the (x, y, z) of the + center of the compartment. + """ + idxs = self.view.global_branch_index.unique() + if self.__class__.__name__ == "CompartmentView": + loc = loc_of_index( + self.view["global_comp_index"].to_numpy(), + self.view["global_branch_index"].to_numpy(), + self.pointer.nseg_per_branch, + ) + return list(interpolate_xyz(loc, self.pointer.xyzr[idxs[0]])) + else: + return [self.pointer.xyzr[i] for i in idxs] + + def _append_multiple_synapses( + self, pre_rows: pd.DataFrame, post_rows: pd.DataFrame, synapse_type: Synapse + ): + """Append multiple rows to the `self.edges` table. + + This is used, e.g. by `fully_connect` and `connect`. + + Args: + pre_rows: The pre-synaptic compartments. + post_rows: The post-synaptic compartments. + synapse_type: The synapse to append. + + both `pre_rows` and `post_rows` can be obtained from self.view. + """ + # Add synapse types to the module and infer their unique identifier. + synapse_name = synapse_type._name + index = len(self.pointer.edges) + type_ind, is_new = self._infer_synapse_type_ind(synapse_name) + if is_new: # synapse is not known + self._update_synapse_state_names(synapse_type) + + post_loc = loc_of_index( + post_rows["global_comp_index"].to_numpy(), + post_rows["global_branch_index"].to_numpy(), + self.pointer.nseg_per_branch, + ) + pre_loc = loc_of_index( + pre_rows["global_comp_index"].to_numpy(), + pre_rows["global_branch_index"].to_numpy(), + self.pointer.nseg_per_branch, + ) + + # Define new synapses. Each row is one synapse. + new_rows = dict( + pre_locs=pre_loc, + post_locs=post_loc, + pre_branch_index=pre_rows["branch_index"].to_numpy(), + post_branch_index=post_rows["branch_index"].to_numpy(), + pre_cell_index=pre_rows["cell_index"].to_numpy(), + post_cell_index=post_rows["cell_index"].to_numpy(), + type=synapse_name, + type_ind=type_ind, + global_pre_comp_index=pre_rows["global_comp_index"].to_numpy(), + global_post_comp_index=post_rows["global_comp_index"].to_numpy(), + global_pre_branch_index=pre_rows["global_branch_index"].to_numpy(), + global_post_branch_index=post_rows["global_branch_index"].to_numpy(), + ) + + # Update edges. + self.pointer.edges = concat_and_ignore_empty( + [self.pointer.edges, pd.DataFrame(new_rows)], + ignore_index=True, + ) + + indices = [idx for idx in range(index, index + len(pre_loc))] + self._add_params_to_edges(synapse_type, indices) + + def _infer_synapse_type_ind(self, synapse_name: str) -> Tuple[int, bool]: + """Return the unique identifier for every synapse type. + + Also returns a boolean indicating whether the synapse is already in the + `module`. + + Used during `self._append_multiple_synapses`. + + Args: + synapse_name: The name of the synapse. + + Returns: + type_ind: Index referencing the synapse type in self.synapses. + is_new_type: Whether the synapse is new to the module. + """ + syn_names = self.pointer.synapse_names + is_new_type = False if synapse_name in syn_names else True + type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name) + return type_ind, is_new_type + + def _add_params_to_edges(self, synapse_type: Synapse, indices: list): + """Fills parameter and state columns of new synapses in the `edges` table. + + This method does not create new rows in the `.edges` table. It only fills + columns of already existing rows. + + Used during `self._append_multiple_synapses`. + + Args: + synapse_type: The synapse to append. + indices: The indices of the synapses according to self.synapses. + """ + # Add parameters and states to the `.edges` table. + for key, param_val in synapse_type.synapse_params.items(): + self.pointer.edges.loc[indices, key] = param_val + + # Update synaptic state array. + for key, state_val in synapse_type.synapse_states.items(): + self.pointer.edges.loc[indices, key] = state_val + + def _update_synapse_state_names(self, synapse_type: Synapse): + """Update attributes with information about the synapses. + + Used during `self._append_multiple_synapses`. + + Args: + synapse_type: The synapse to append. + """ + # (Potentially) update variables that track meta information about synapses. + self.pointer.synapse_names.append(synapse_type._name) + self.pointer.synapse_param_names += list(synapse_type.synapse_params.keys()) + self.pointer.synapse_state_names += list(synapse_type.synapse_states.keys()) + self.pointer.synapses.append(synapse_type) -class GroupView: - # TEMPORARY HOTFIX TO ALLOW IMPORT - pass +class GroupView(View): + """GroupView (aka sectionlist). + + Unlike the standard `View` it sets `controlled_by_param` to + 0 for all compartments. This means that a group will be controlled by a single + parameter (unless it is subclassed). + """ + + def __init__( + self, + pointer: Module, + view: pd.DataFrame, + childview: type, + childview_keys: List[str], + ): + """Initialize group. + + Args: + pointer: The module from which the group was created. + view: The dataframe which defines the compartments, branches, and cells in + the group. + childview: An uninitialized view (e.g. `CellView`). Depending on the module, + subclassing groups will return a different `View`. E.g., `net.group[0]` + will return a `CellView`, whereas `cell.group[0]` will return a + `BranchView`. The childview argument defines which view is created. We + do not automatically infer this because that would force us to import + `CellView`, `BranchView`, and `CompartmentView` in the `base.py` file. + childview_keys: The names by which the group can be subclassed. Used to + raise `KeyError` if one does, e.g. `net.group.branch(0)` (i.e. `.cell` + is skipped). + """ + self.childview_of_group = childview + self.names_of_childview = childview_keys + view["controlled_by_param"] = 0 + super().__init__(pointer, view) + + def __getattr__(self, key: str) -> View: + """Subclass the group. + + This first checks whether the key that is used to subclass the view is allowed. + For example, one cannot `net.group.branch(0)` but instead must use + `net.group.cell("all").branch(0).` If this is valid, then it instantiates the + correct `View` which had been passed to `__init__()`. + + Args: + key: The key which is used to subclass the group. + + Return: + View of the subclassed group. + """ + # Ensure that hidden methods such as `__deepcopy__` still work. + if key.startswith("__"): + return super().__getattribute__(key) + + if key in self.names_of_childview: + view = deepcopy(self.view) + view["global_comp_index"] = view["comp_index"] + view["global_branch_index"] = view["branch_index"] + view["global_cell_index"] = view["cell_index"] + return self.childview_of_group(self.pointer, view) + else: + raise KeyError(f"Key {key} not recognized.") + + def __getitem__(self, index): + """Subclass the group with lazy indexing.""" + return self.pointer._getitem(self, index, self.names_of_childview[0]) diff --git a/jaxley/modules/branch.py b/jaxley/modules/branch.py index fa433e68..7fe35e2c 100644 --- a/jaxley/modules/branch.py +++ b/jaxley/modules/branch.py @@ -65,15 +65,11 @@ def __init__( self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) # Indexing. - # TODO: Might have to be self.base.nodes self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True) self._append_params_and_states(self.branch_params, self.branch_states) - self.nodes["global_comp_index"] = np.arange(self.nseg).tolist() - self.nodes["global_branch_index"] = [0] * self.nseg - self.nodes["global_cell_index"] = [0] * self.nseg - self.nodes["controlled_by_param"] = 0 - self._in_view = self.nodes.index.to_numpy() - self._update_local_indices() + self.nodes["comp_index"] = np.arange(self.nseg).tolist() + self.nodes["branch_index"] = [0] * self.nseg + self.nodes["cell_index"] = [0] * self.nseg # Channels. self._gather_channels_from_constituents(compartment_list) @@ -98,9 +94,36 @@ def __init__( # Coordinates. self.xyzr = [float("NaN") * np.zeros((2, 4))] - def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]: - conds = self.init_branch_conds( - params["axial_resistivity"], params["radius"], params["length"], self.nseg + def __getattr__(self, key: str): + # Ensure that hidden methods such as `__deepcopy__` still work. + if key.startswith("__"): + return super().__getattribute__(key) + + if key in ["comp", "loc"]: + view = deepcopy(self.nodes) + view["global_comp_index"] = view["comp_index"] + view["global_branch_index"] = view["branch_index"] + view["global_cell_index"] = view["cell_index"] + compview = CompartmentView(self, view) + return compview if key == "comp" else compview.loc + elif key in self.group_nodes: + inds = self.group_nodes[key].index.values + view = self.nodes.loc[inds] + view["global_comp_index"] = view["comp_index"] + view["global_branch_index"] = view["branch_index"] + view["global_cell_index"] = view["cell_index"] + return GroupView(self, view, CompartmentView, ["comp", "loc"]) + else: + raise KeyError(f"Key {key} not recognized.") + + def _init_morph_jaxley_spsolve(self): + self.solve_indexer = JaxleySolveIndexer( + cumsum_nseg=self.cumsum_nseg, + branchpoint_group_inds=np.asarray([]).astype(int), + remapped_node_indices=self._internal_node_inds, + children_in_level=[], + parents_in_level=[], + root_inds=np.asarray([0]), ) def _init_morph_jax_spsolve(self): @@ -175,5 +198,80 @@ def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None): self.initialize() -class BranchView: - pass +class BranchView(View): + """BranchView.""" + + def __init__(self, pointer: Module, view: pd.DataFrame): + view = view.assign(controlled_by_param=view.global_branch_index) + super().__init__(pointer, view) + + def __call__(self, index: float): + local_idcs = self._get_local_indices() + self.view[local_idcs.columns] = ( + local_idcs # set indexes locally. enables net[0:2,0:2] + ) + self.allow_make_trainable = True + new_view = super().adjust_view("branch_index", index) + return new_view + + def __getattr__(self, key): + assert key in ["comp", "loc"] + compview = CompartmentView(self.pointer, self.view) + return compview if key == "comp" else compview.loc + + def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None): + """Set the number of compartments with which the branch is discretized. + + Args: + ncomp: The number of compartments that the branch should be discretized + into. + min_radius: Only used if the morphology was read from an SWC file. If passed + the radius is capped to be at least this value. + + Raises: + - When there are stimuli in any compartment in the module. + - When there are recordings in any compartment in the module. + - When the channels of the compartments are not the same within the branch + that is modified. + - When the lengths of the compartments are not the same within the branch + that is modified. + - Unless the morphology was read from an SWC file, when the radiuses of the + compartments are not the same within the branch that is modified. + """ + if self.pointer._module_type == "network": + raise NotImplementedError( + "`.set_ncomp` is not yet supported for a `Network`. To overcome this, " + "first build individual cells with the desired `ncomp` and then " + "assemble them into a network." + ) + + error_msg = lambda name: ( + f"Your module contains a {name}. This is not allowed. First build the " + "morphology with `set_ncomp()` and then insert stimuli, recordings, and " + "define trainables." + ) + assert len(self.pointer.externals) == 0, error_msg("stimulus") + assert len(self.pointer.recordings) == 0, error_msg("recording") + assert len(self.pointer.trainable_params) == 0, error_msg("trainable parameter") + # Update all attributes that are affected by compartment structure. + ( + self.pointer.nodes, + self.pointer.nseg_per_branch, + self.pointer.nseg, + self.pointer.cumsum_nseg, + self.pointer._internal_node_inds, + ) = self.pointer._set_ncomp( + ncomp, + self.view, + self.pointer.nodes, + self.view["global_comp_index"].to_numpy()[0], + self.pointer.nseg_per_branch, + [c._name for c in self.pointer.channels], + list(chain(*[c.channel_params for c in self.pointer.channels])), + list(chain(*[c.channel_states for c in self.pointer.channels])), + self.pointer._radius_generating_fns, + min_radius, + ) + + # Update the morphology indexing (e.g., `.comp_edges`). + self.pointer.initialize() diff --git a/jaxley/modules/cell.py b/jaxley/modules/cell.py index 66da9e15..961928ff 100644 --- a/jaxley/modules/cell.py +++ b/jaxley/modules/cell.py @@ -105,14 +105,15 @@ def __init__( # Build nodes. Has to be changed when `.set_ncomp()` is run. self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True) - self._append_params_and_states(self.cell_params, self.cell_states) - self.nodes["global_comp_index"] = np.arange( - self.nseg * self.total_nbranches - ).tolist() - self.nodes["global_branch_index"] = ( - np.arange(self.nseg * self.total_nbranches) // self.nseg + self.nodes["comp_index"] = np.arange(self.cumsum_nseg[-1]) + self.nodes["branch_index"] = np.repeat( + np.arange(self.total_nbranches), self.nseg_per_branch ).tolist() - self.nodes["global_cell_index"] = [0] * (self.nseg * self.total_nbranches) + self.nodes["cell_index"] = np.repeat(0, self.cumsum_nseg[-1]).tolist() + + # Appending general parameters (radius, length, r_a, cm) and channel parameters, + # as well as the states (v, and channel states). + self._append_params_and_states(self.cell_params, self.cell_states) # Channels. self._gather_channels_from_constituents(branch_list) @@ -128,10 +129,6 @@ def __init__( ) ) - self.nodes["controlled_by_param"] = 0 - self._in_view = self.nodes.index.to_numpy() - self._update_local_indices() - # For morphology indexing. self.par_inds, self.child_inds, self.child_belongs_to_branchpoint = ( compute_children_and_parents(self.branch_edges) @@ -140,10 +137,29 @@ def __init__( self.initialize() self.init_syns() - # TODO: update with new functionality? - # TODO: Verify that this works - def init_morph(self): - """Initialize morphology.""" + def __getattr__(self, key: str): + # Ensure that hidden methods such as `__deepcopy__` still work. + if key.startswith("__"): + return super().__getattribute__(key) + + if key == "branch": + view = deepcopy(self.nodes) + view["global_comp_index"] = view["comp_index"] + view["global_branch_index"] = view["branch_index"] + view["global_cell_index"] = view["cell_index"] + return BranchView(self, view) + elif key in self.group_nodes: + inds = self.group_nodes[key].index.values + view = self.nodes.loc[inds] + view["global_comp_index"] = view["comp_index"] + view["global_branch_index"] = view["branch_index"] + view["global_cell_index"] = view["cell_index"] + return GroupView(self, view, BranchView, ["branch"]) + else: + raise KeyError(f"Key {key} not recognized.") + + def _init_morph_jaxley_spsolve(self): + """Initialize morphology for the custom sparse solver. Running this function is only required for custom Jaxley solvers, i.e., for `voltage_solver={'jaxley.stone', 'jaxley.thomas'}`. However, because at @@ -311,8 +327,36 @@ def set_ncomp(self, ncomp: int, min_radius: Optional[float] = None): ) -class CellView: - pass +class CellView(View): + """CellView.""" + + def __init__(self, pointer: Module, view: pd.DataFrame): + view = view.assign(controlled_by_param=view.global_cell_index) + super().__init__(pointer, view) + + def __call__(self, index: float): + local_idcs = self._get_local_indices() + self.view[local_idcs.columns] = ( + local_idcs # set indexes locally. enables net[0:2,0:2] + ) + if index == "all": + self.allow_make_trainable = False + new_view = super().adjust_view("cell_index", index) + return new_view + + def __getattr__(self, key: str): + assert key == "branch" + return BranchView(self.pointer, self.view) + + def rotate(self, degrees: float, rotation_axis: str = "xy"): + """Rotate jaxley modules clockwise. Used only for visualization. + + Args: + degrees: How many degrees to rotate the module by. + rotation_axis: Either of {`xy` | `xz` | `yz`}. + """ + nodes = self.set_global_index_and_index(self.view) + self.pointer._rotate(degrees=degrees, rotation_axis=rotation_axis, view=nodes) def read_swc( diff --git a/jaxley/modules/compartment.py b/jaxley/modules/compartment.py index cf380bae..0a00cbea 100644 --- a/jaxley/modules/compartment.py +++ b/jaxley/modules/compartment.py @@ -49,11 +49,8 @@ def __init__(self): dict(comp_index=[0], branch_index=[0], cell_index=[0]) ) self._append_params_and_states(self.compartment_params, self.compartment_states) - self.nodes["controlled_by_param"] = 0 - self._in_view = self.nodes.index.to_numpy() - self._update_local_indices() - # Synapses. TODO: <- THIS IS NOT FOR SYNAPSES OR IS IT? + # Synapses. self.branch_edges = pd.DataFrame( dict(parent_branch_index=[], child_branch_index=[]) ) @@ -106,31 +103,113 @@ def init_conds(self, params: Dict[str, jnp.ndarray]): This is because compartments do not have any axial conductances.""" return {"axial_conductances": jnp.asarray([])} - # TODO: - # def distance(self, endpoint: "CompartmentView") -> float: - # """Return the direct distance between two compartments. - # This does not compute the pathwise distance (which is currently not - # implemented). +class CompartmentView(View): + """CompartmentView.""" - # Args: - # endpoint: The compartment to which to compute the distance to. - # """ - # start_branch = self.view["global_branch_index"].item() - # start_comp = self.view["comp_index"].item() - # start_xyz = interpolate_xyz( - # loc_of_index(start_comp, self.pointer.nseg), self.pointer.xyzr[start_branch] - # ) + def __init__(self, pointer: Module, view: pd.DataFrame): + view = view.assign(controlled_by_param=view.global_comp_index) + super().__init__(pointer, view) - # end_branch = endpoint.view["global_branch_index"].item() - # end_comp = endpoint.view["comp_index"].item() - # end_xyz = interpolate_xyz( - # loc_of_index(end_comp, self.pointer.nseg), self.pointer.xyzr[end_branch] - # ) + def __call__(self, index: int): + if not hasattr(self, "_has_been_called"): + view = super().adjust_view("comp_index", index) + view._has_been_called = True + return view + raise AttributeError( + "'CompartmentView' object has no attribute 'comp' or 'loc'." + ) - # return np.sqrt(np.sum((start_xyz - end_xyz) ** 2)) + def loc(self, loc: float) -> "CompartmentView": + if loc != "all": + assert ( + loc >= 0.0 and loc <= 1.0 + ), "Compartments must be indexed by a continuous value between 0 and 1." + + branch_ind = np.unique(self.view["global_branch_index"].to_numpy()) + if loc != "all" and len(branch_ind) != 1: + raise NotImplementedError( + "Using `.loc()` to index a single compartment of multiple branches is " + "not supported. Use a for loop or use `.comp` to index." + ) + branch_ind = np.squeeze(branch_ind) # shape == (1,) --> shape == () + + # Cast nseg to numpy because in `local_index_of_loc` we instatiate an array + # of length `nseg`. However, if we use `.data_set()` or `.data_stimulate()`, + # the `local_index_of_loc()` method must be compatible with `jit`. Therefore, + # we have to stop this from being traced here and cast to numpy. + nsegs = np.asarray(self.pointer.nseg_per_branch) + index = local_index_of_loc(loc, branch_ind, nsegs) if loc != "all" else "all" + view = self(index) + view._has_been_called = True + return view + + def distance(self, endpoint: "CompartmentView") -> float: + """Return the direct distance between two compartments. + + This does not compute the pathwise distance (which is currently not + implemented). + + Args: + endpoint: The compartment to which to compute the distance to. + """ + start_branch = self.view["global_branch_index"].item() + start_comp = self.view["global_comp_index"].item() + start_xyz = interpolate_xyz( + loc_of_index( + start_comp, + start_branch, + self.pointer.nseg_per_branch, + ), + self.pointer.xyzr[start_branch], + ) + end_branch = endpoint.view["global_branch_index"].item() + end_comp = endpoint.view["global_comp_index"].item() + end_xyz = interpolate_xyz( + loc_of_index( + end_comp, + end_branch, + self.pointer.nseg_per_branch, + ), + self.pointer.xyzr[end_branch], + ) -class CompartmentView: - # TEMPORARY HOTFIX TO ALLOW IMPORT - pass + return np.sqrt(np.sum((start_xyz - end_xyz) ** 2)) + + def vis( + self, + ax: Optional[Axes] = None, + col: str = "k", + type: str = "scatter", + dims: Tuple[int] = (0, 1), + morph_plot_kwargs: Dict = {}, + ) -> Axes: + """Visualize the compartment. + + Args: + ax: An axis into which to plot. + col: The color for all branches. + type: Whether to plot as point ("scatter") or the projected volume ("volume"). + dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of + two of them. + morph_plot_kwargs: Keyword arguments passed to the plotting function. + """ + nodes = self.set_global_index_and_index(self.view) + if type == "volume": + return self.pointer._vis( + ax=ax, + col=col, + dims=dims, + view=nodes, + type="volume", + morph_plot_kwargs=morph_plot_kwargs, + ) + + return self.pointer._scatter( + ax=ax, + col=col, + dims=dims, + view=nodes, + morph_plot_kwargs=morph_plot_kwargs, + ) diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 90afb488..615d08fa 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -51,29 +51,31 @@ def __init__( for cell in cells: self.xyzr += deepcopy(cell.xyzr) - self.cell_list = cells # TODO: THIS IS A TEMPORARY HOTFIX - self.nseg = cells[0].nseg + self.cells = cells + self.nseg_per_branch = np.concatenate( + [cell.nseg_per_branch for cell in self.cells] + ) + self.nseg = int(np.max(self.nseg_per_branch)) + self.cumsum_nseg = cumsum_leading_zero(self.nseg_per_branch) + self._internal_node_inds = np.arange(self.cumsum_nseg[-1]) self._append_params_and_states(self.network_params, self.network_states) - self.nbranches_per_cell = [cell.total_nbranches for cell in cells] - self.nbranchpoints_per_cell = [cell.total_nbranchpoints for cell in cells] + self.nbranches_per_cell = [cell.total_nbranches for cell in self.cells] self.total_nbranches = sum(self.nbranches_per_cell) self.cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell) self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True) - self.nodes["global_comp_index"] = np.arange( - self.nseg * self.total_nbranches - ).tolist() - self.nodes["global_branch_index"] = ( - np.arange(self.nseg * self.total_nbranches) // self.nseg + self.nodes["comp_index"] = np.arange(self.cumsum_nseg[-1]) + self.nodes["branch_index"] = np.repeat( + np.arange(self.total_nbranches), self.nseg_per_branch ).tolist() - self.nodes["global_cell_index"] = list( + self.nodes["cell_index"] = list( itertools.chain( *[[i] * int(cell.cumsum_nseg[-1]) for i, cell in enumerate(self.cells)] ) ) - parents = [cell.comb_parents for cell in cells] + parents = [cell.comb_parents for cell in self.cells] self.comb_parents = jnp.concatenate( [p.at[1:].add(self.cumsum_nbranches[i]) for i, p in enumerate(parents)] ) @@ -99,69 +101,58 @@ def __init__( nbranchpoints = jnp.asarray([len(cell.par_inds) for cell in self.cells]) self.cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints) - self.nodes["controlled_by_param"] = 0 - self._in_view = self.nodes.index.to_numpy() - self._update_local_indices() - # Channels. self._gather_channels_from_constituents(cells) self.initialize() self.init_syns() - def init_morph(self): - self.branchpoint_group_inds = build_branchpoint_group_inds( + def __getattr__(self, key: str): + # Ensure that hidden methods such as `__deepcopy__` still work. + if key.startswith("__"): + return super().__getattribute__(key) + + if key == "cell": + view = deepcopy(self.nodes) + view["global_comp_index"] = view["comp_index"] + view["global_branch_index"] = view["branch_index"] + view["global_cell_index"] = view["cell_index"] + return CellView(self, view) + elif key in self.synapse_names: + type_index = self.synapse_names.index(key) + return SynapseView(self, self.edges, key, self.synapses[type_index]) + elif key in self.group_nodes: + inds = self.group_nodes[key].index.values + view = self.nodes.loc[inds] + view["global_comp_index"] = view["comp_index"] + view["global_branch_index"] = view["branch_index"] + view["global_cell_index"] = view["cell_index"] + return GroupView(self, view, CellView, ["cell"]) + else: + raise KeyError(f"Key {key} not recognized.") + + def _init_morph_jaxley_spsolve(self): + branchpoint_group_inds = build_branchpoint_group_inds( len(self.par_inds), self.child_belongs_to_branchpoint, self.cumsum_nseg[-1], ) children_in_level = merge_cells( self.cumsum_nbranches, - self.cumsum_nbranchpoints, - [cell.children_in_level for cell in self.cell_list], + self.cumsum_nbranchpoints_per_cell, + [cell.solve_indexer.children_in_level for cell in self.cells], exclude_first=False, ) parents_in_level = merge_cells( self.cumsum_nbranches, - self.cumsum_nbranchpoints, - [cell.parents_in_level for cell in self.cell_list], + self.cumsum_nbranchpoints_per_cell, + [cell.solve_indexer.parents_in_level for cell in self.cells], exclude_first=False, ) - del self.cell_list - self.initialized_morph = True - - def init_conds(self, params: Dict) -> Dict[str, jnp.ndarray]: - """Given an axial resisitivity, set the coupling conductances.""" - nbranches = self.total_nbranches - nseg = self.nseg - parents = self.comb_parents - - axial_resistivity = jnp.reshape(params["axial_resistivity"], (nbranches, nseg)) - radiuses = jnp.reshape(params["radius"], (nbranches, nseg)) - lengths = jnp.reshape(params["length"], (nbranches, nseg)) - - conds = vmap(Branch.init_branch_conds, in_axes=(0, 0, 0, None))( - axial_resistivity, radiuses, lengths, self.nseg - ) - coupling_conds_fwd = conds[0] - coupling_conds_bwd = conds[1] - summed_coupling_conds = conds[2] - - # The conductance from the children to the branch point. - branchpoint_conds_children = vmap( - compute_coupling_cond_branchpoint, in_axes=(0, 0, 0) - )( - radiuses[self.child_inds, 0], - axial_resistivity[self.child_inds, 0], - lengths[self.child_inds, 0], - ) - # The conductance from the parents to the branch point. - branchpoint_conds_parents = vmap( - compute_coupling_cond_branchpoint, in_axes=(0, 0, 0) - )( - radiuses[self.par_inds, -1], - axial_resistivity[self.par_inds, -1], - lengths[self.par_inds, -1], + padded_cumsum_nseg = cumsum_leading_zero( + np.concatenate( + [np.diff(cell.solve_indexer.cumsum_nseg) for cell in self.cells] + ) ) # Generate mapping to dealing with the masking which allows using the custom @@ -420,58 +411,6 @@ def _synapse_currents( return states, (syn_voltage_terms, syn_constant_terms) - # def _infer_synapse_type_ind(self, synapse_name): - # syn_names = self.base.synapse_names - # is_new_type = False if synapse_name in syn_names else True - # type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name) - # return type_ind, is_new_type - - # def _update_synapse_state_names(self, synapse_type): - # # (Potentially) update variables that track meta information about synapses. - # self.base.synapse_names.append(synapse_type._name) - # self.base.synapse_param_names += list(synapse_type.synapse_params.keys()) - # self.base.synapse_state_names += list(synapse_type.synapse_states.keys()) - # self.base.synapses.append(synapse_type) - - # def _append_multiple_synapses(self, pre, post, synapse_type): - # # Add synapse types to the module and infer their unique identifier. - # synapse_name = synapse_type._name - # type_ind, is_new = self._infer_synapse_type_ind(synapse_name) - # if is_new: # synapse is not known - # self._update_synapse_state_names(synapse_type) - - # index = len(self.base.edges) - # post_loc = loc_of_index(post._comps_in_view, self.nseg) - # pre_loc = loc_of_index(pre._comps_in_view, self.nseg) - - # # Define new synapses. Each row is one synapse. - # cols = ["comp_index", "branch_index", "cell_index"] - # pre_nodes = pre.nodes[[f"{scope}_{col}" for col in cols for scope in ["local", "global"]]] - # pre_nodes.columns = [f"{scope}_pre_{col}" for col in cols for scope in ["local", "global"]] - # post_nodes = post.nodes[[f"{scope}_{col}" for col in cols for scope in ["local", "global"]]] - # post_nodes.columns = [f"{scope}_post_{col}" for col in cols for scope in ["local", "global"]] - # new_rows = pd.concat([pre_nodes.reset_index(drop=True), post_nodes.reset_index(drop=True)], axis=1) - # new_rows["type"] = synapse_name - # new_rows["type_ind"] = type_ind - # new_rows["pre_loc"] = pre_loc - # new_rows["post_loc"] = post_loc - # self.base.edges = concat_and_ignore_empty( - # [self.base.edges, new_rows], - # ignore_index=True, axis=0 - # ) - - # indices = [idx for idx in range(index, index + len(pre_loc))] - # self._add_params_to_edges(synapse_type, indices) - - # def _add_params_to_edges(self, synapse_type, indices): - # # Add parameters and states to the `.edges` table. - # for key, param_val in synapse_type.synapse_params.items(): - # self.base.edges.loc[indices, key] = param_val - - # # Update synaptic state array. - # for key, state_val in synapse_type.synapse_states.items(): - # self.base.edges.loc[indices, key] = state_val - def vis( self, detail: str = "full", @@ -611,29 +550,140 @@ def vis( return ax - # TODO: CHECK IF THIS WORKS - # def _build_graph(self, layers: Optional[List] = None, **options): - # graph = nx.DiGraph() + def _build_graph(self, layers: Optional[List] = None, **options): + graph = nx.DiGraph() - # def build_extents(*subset_sizes): - # return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes)) + def build_extents(*subset_sizes): + return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes)) - # if layers is not None: - # extents = build_extents(*layers) - # layers = [range(start, end) for start, end in extents] - # for i, layer in enumerate(layers): - # graph.add_nodes_from(layer, layer=i) - # else: - # graph.add_nodes_from(range(len(cell.view._cells_in_view))) + if layers is not None: + extents = build_extents(*layers) + layers = [range(start, end) for start, end in extents] + for i, layer in enumerate(layers): + graph.add_nodes_from(layer, layer=i) + else: + graph.add_nodes_from(range(len(self.cells))) + + pre_cell = self.edges["pre_cell_index"].to_numpy() + post_cell = self.edges["post_cell_index"].to_numpy() + + inds = np.stack([pre_cell, post_cell]).T + graph.add_edges_from(inds) + + return graph + + +class SynapseView(View): + """SynapseView.""" - # pre_cell = self.edges["pre_cell_index"].to_numpy() - # post_cell = self.edges["post_cell_index"].to_numpy() + def __init__(self, pointer, view, key, synapse: "jx.Synapse"): + self.synapse = synapse + view = deepcopy(view[view["type"] == key]) + view = view.assign(controlled_by_param=0) - # inds = np.stack([pre_cell, post_cell]).T - # graph.add_edges_from(inds) + # Used for `.set()`. + view["global_index"] = view.index.values + # Used for `__call__()`. + view["index"] = list(range(len(view))) + # Because `make_trainable` needs to access the rows of `jaxedges` (which does + # not contain `NaNa` rows) we need to reset the index here. We undo this for + # `.set()`. `.index.values` is used for `make_trainable`. + view = view.reset_index(drop=True) - # return graph + super().__init__(pointer, view) + def __call__(self, index: int): + self.view["controlled_by_param"] = self.view.index.values + return self.adjust_view("index", index) -class SynapseView: - pass + def show( + self, + *, + indices: bool = True, + params: bool = True, + states: bool = True, + ) -> pd.DataFrame: + """Show synapses.""" + printable_nodes = deepcopy(self.view[["type", "type_ind"]]) + + if indices: + names = [ + "pre_locs", + "pre_branch_index", + "pre_cell_index", + "post_locs", + "post_branch_index", + "post_cell_index", + ] + printable_nodes[names] = self.view[names] + + if params: + for key in self.synapse.synapse_params.keys(): + printable_nodes[key] = self.view[key] + + if states: + for key in self.synapse.synapse_states.keys(): + printable_nodes[key] = self.view[key] + + printable_nodes["controlled_by_param"] = self.view["controlled_by_param"] + return printable_nodes + + def set(self, key: str, val: float): + """Set parameters of the pointer.""" + synapse_index = self.view["type_ind"].values[0] + synapse_type = self.pointer.synapses[synapse_index] + synapse_param_names = list(synapse_type.synapse_params.keys()) + synapse_state_names = list(synapse_type.synapse_states.keys()) + + assert ( + key in synapse_param_names or key in synapse_state_names + ), f"{key} does not exist in synapse of type {synapse_type._name}." + + # Reset index to global index because we are writing to `self.edges`. + self.view = self.view.set_index("global_index", drop=False) + self.pointer._set(key, val, self.view, self.pointer.edges) + + def _assert_key_in_params_or_states(self, key: str): + synapse_index = self.view["type_ind"].values[0] + synapse_type = self.pointer.synapses[synapse_index] + synapse_param_names = list(synapse_type.synapse_params.keys()) + synapse_state_names = list(synapse_type.synapse_states.keys()) + + assert ( + key in synapse_param_names or key in synapse_state_names + ), f"{key} does not exist in synapse of type {synapse_type._name}." + + def make_trainable( + self, + key: str, + init_val: Optional[Union[float, list]] = None, + verbose: bool = True, + ): + """Make a parameter trainable.""" + self._assert_key_in_params_or_states(key) + # Use `.index.values` for indexing because we are memorizing the indices for + # `jaxedges`. + self.pointer._make_trainable(self.view, key, init_val, verbose=verbose) + + def data_set( + self, + key: str, + val: Union[float, jnp.ndarray], + param_state: Optional[List[Dict]] = None, + ): + """Set parameter of module (or its view) to a new value within `jit`.""" + self._assert_key_in_params_or_states(key) + return self.pointer._data_set(key, val, self.view, param_state=param_state) + + def record(self, state: str = "v"): + """Record a state.""" + assert ( + state in self.pointer.synapse_state_names[self.view["type_ind"].values[0]] + ), f"State {state} does not exist in synapse of type {self.view['type'].values[0]}." + + view = deepcopy(self.view) + view["state"] = state + + recording_view = view[["state"]] + recording_view = recording_view.assign(rec_index=view.index) + self.pointer._record(recording_view)