diff --git a/jaxley/modules/base.py b/jaxley/modules/base.py index 6d7a74f6..3a5d3d8a 100644 --- a/jaxley/modules/base.py +++ b/jaxley/modules/base.py @@ -185,7 +185,7 @@ def __str__(self): def __dir__(self): base_dir = object.__dir__(self) - synapses = [s._name for s in self.synapses] + synapses = [s.name for s in self.synapses] groups = [] if len(self.groups) == 0 else list(self.groups.keys()) return sorted(base_dir + synapses + groups) @@ -205,16 +205,16 @@ def __getattr__(self, key): return view # intercepts calls to channels - if key in [c._name for c in self.base.channels]: - channel_names = [c._name for c in self.channels] + if key in [c.name for c in self.base.channels]: + channel_names = [c.name for c in self.channels] inds = self.nodes.index[self.nodes[key]].to_numpy() view = self.select(inds) if key in channel_names else self.select(None) view._set_controlled_by_param(key) return view # intercepts calls to synapse types - base_syn_names = [s._name for s in self.base.synapses] - syn_names = [s._name for s in self.synapses] + base_syn_names = [s.name for s in self.base.synapses] + syn_names = [s.name for s in self.synapses] if key in base_syn_names: syn_inds = self.edges[self.edges["type"] == key][ "global_edge_index" @@ -714,15 +714,60 @@ def _gather_channels_from_constituents(self, constituents: List): """ for module in constituents: for channel in module.channels: - if channel._name not in [c._name for c in self.channels]: + if channel.name not in [c.name for c in self.channels]: self.base.channels.append(channel) if channel.current_name not in self.membrane_current_names: self.base.membrane_current_names.append(channel.current_name) # Setting columns of channel names to `False` instead of `NaN`. for channel in self.base.channels: - name = channel._name + name = channel.name self.base.nodes.loc[self.nodes[name].isna(), name] = False + def _prepare_for_jax(self): + # prepare lookup of indices of states, parameters and mechanisms + global_params = ["radius", "length", "axial_resistivity", "capacitance"] + global_states = ["v"] + + current_names = self.membrane_current_names + self.synapse_current_names + global_states_params = global_states + global_params + current_names + + channel_names = [c.name for c in self.channels] + syn_names = [s.name for s in self.synapses] + + node_attrs = self.nodes.columns.to_list() + current_names + channel_names + + def inds_of_key(key: str) -> np.ndarray: + """Return the indices for params, states, mechanisms and currents.""" + data = self.nodes if key in node_attrs else pd.DataFrame() + data = self.edges if key in self.edges.columns or key in syn_names else data + + if key in channel_names + syn_names: + where = data["type"] == key if key in syn_names else data[key] + elif key in data.columns: + where = ~data[key].isna() + elif key in global_states_params: + where = pd.Index([True] * len(data)) + else: + raise ValueError(f"Key '{key}' not found in nodes or edges") + return data.index[where].to_numpy() + + # expose the lookup function to the class with precomputed attrs in scope + self._inds_of_state_param = inds_of_key + + # add index attrs to mechansisms (i.e. where was it inserted) and also keep track + # of states / parameters that are also shared by other mechanisms. + for mech in self.channels + self.synapses: + mech.indices = self._inds_of_state_param(mech.name) + mech._jax_inds = {} + current = {mech.current_name: None} if isinstance(mech, Channel) else {} + + for param_state in {**mech.params, **mech.states, **current}: + is_global = not param_state.startswith(f"{mech.name}_") + if is_global: + global_inds = self._inds_of_state_param(param_state) + local_inds = np.where(np.isin(global_inds, mech.indices))[0] + mech._jax_inds[param_state] = local_inds + def to_jax(self): """Move `.nodes` to `.jaxnodes`. @@ -784,7 +829,7 @@ def show( scopes = ["local", "global"] inds = [f"{s}_{i}" for i in inds for s in scopes] if indices else [] cols += inds - cols += [ch._name for ch in self.channels] if channel_names else [] + cols += [ch.name for ch in self.channels] if channel_names else [] cols += sum([list(ch.params) for ch in self.channels], []) if params else [] cols += sum([list(ch.states) for ch in self.channels], []) if states else [] @@ -914,7 +959,7 @@ def set_ncomp( all_nodes = self.base.nodes start_idx = self.nodes["global_comp_index"].to_numpy()[0] ncomp_per_branch = self.base.ncomp_per_branch - channel_names = [c._name for c in self.base.channels] + channel_names = [c.name for c in self.base.channels] channel_param_names = list(chain(*[c.params for c in self.base.channels])) channel_state_names = list(chain(*[c.states for c in self.base.channels])) radius_generating_fns = self.base._radius_generating_fns @@ -1202,33 +1247,22 @@ def _get_state_names(self) -> Tuple[List, List]: synapse_states + self.synapse_current_names, ) - def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: - """Get all trainable parameters. - - The returned parameters should be passed to `jx.integrate(..., params=params). - - Returns: - A list of all trainable parameters in the form of - [{"gNa": jnp.array([0.1, 0.2, 0.3])}, ...]. - """ - return self.trainable_params - def _iter_states_params( self, params=False, states=False, currents=False - ) -> Tuple[str, np.ndarray]: # type: ignore + ) -> Tuple[str, np.ndarray]: # type: ignore # assert that either params or states is True assert params or states or currents, "Select either params / states / currents." all_mechs = self.channels + self.synapses - + if params: global_params = ["radius", "length", "axial_resistivity", "capacitance"] - all_params = sum([list(m.params) for m in all_mechs], []) + global_params + all_params = [p for m in all_mechs for p in m.params] + global_params for key in all_params: yield key, self._inds_of_state_param(key) if states: global_states = ["v"] - all_states = sum([list(m.states) for m in all_mechs], []) + global_states + all_states = [s for m in all_mechs for s in m.states] + global_states for key in all_states: yield key, self._inds_of_state_param(key) @@ -1237,49 +1271,16 @@ def _iter_states_params( for key in current_names: yield key, self._inds_of_state_param(key) - def _prepare_for_jax(self): - # prepare lookup of indices of states, parameters and mechanisms - global_params = ["radius", "length", "axial_resistivity", "capacitance"] - global_states = ["v"] - - current_names = self.membrane_current_names + self.synapse_current_names - global_states_params = global_states + global_params + current_names - - channel_names = [c._name for c in self.channels] - syn_names = [s._name for s in self.synapses] - - node_attrs = self.nodes.columns.to_list() + current_names + channel_names - def inds_of_key(key: str) -> np.ndarray: - """Return the indices for params, states, mechanisms and currents.""" - data = self.nodes if key in node_attrs else pd.DataFrame() - data = self.edges if key in self.edges.columns or key in syn_names else data - - if key in channel_names + syn_names: - where = data["type"] == key if key in syn_names else data[key] - elif key in data.columns: - where = ~data[key].isna() - elif key in global_states_params: - where = pd.Index([True] * len(data)) - else: - raise ValueError(f"Key '{key}' not found in nodes or edges") - return data.index[where].to_numpy() - - # expose the lookup function to the class with precomputed attrs in scope - self._inds_of_state_param = inds_of_key + def get_parameters(self) -> List[Dict[str, jnp.ndarray]]: + """Get all trainable parameters. - # add index attrs to mechansisms (i.e. where was it inserted) and also keep track - # of states / parameters that are also shared by other mechanisms. - for mech in self.channels + self.synapses: - mech.indices = self._inds_of_state_param(mech._name) - mech._jax_inds = {} - current = {mech.current_name: None} if isinstance(mech, Channel) else {} + The returned parameters should be passed to `jx.integrate(..., params=params). - for param_state in {**mech.params, **mech.states, **current}: - is_global = not param_state.startswith(f"{mech._name}_") - if is_global: - global_inds = self._inds_of_state_param(param_state) - local_inds = np.where(np.isin(global_inds, mech.indices))[0] - mech._jax_inds[param_state] = local_inds + Returns: + A list of all trainable parameters in the form of + [{"gNa": jnp.array([0.1, 0.2, 0.3])}, ...]. + """ + return self.trainable_params def _get_all_states_params( self, @@ -1737,10 +1738,10 @@ def insert(self, channel: Channel): Args: channel: The channel to insert.""" - name = channel._name + name = channel.name # Channel does not yet exist in the `jx.Module` at all. - if name not in [c._name for c in self.base.channels]: + if name not in [c.name for c in self.base.channels]: self.base.channels.append(channel) self.base.nodes[name] = ( False # Previous columns do not have the new channel. @@ -1765,9 +1766,9 @@ def delete_channel(self, channel: Channel): Args: channel: The channel to remove.""" - name = channel._name - channel_names = [c._name for c in self.channels] - all_channel_names = [c._name for c in self.base.channels] + name = channel.name + channel_names = [c.name for c in self.channels] + all_channel_names = [c.name for c in self.base.channels] if name in channel_names: channel_cols = list({**channel.params, **channel.states}.keys()) self.base.nodes.loc[self._nodes_in_view, channel_cols] = float("nan") @@ -2601,15 +2602,15 @@ def _set_trainables_in_view(self): def _channels_in_view(self, pointer: Union[Module, View]) -> List[Channel]: """Set channels to show only those in view.""" - names = [name._name for name in pointer.channels] + names = [c.name for c in pointer.channels] channel_in_view = self.nodes[names].any(axis=0) channel_in_view = channel_in_view[channel_in_view].index - return [deepcopy(c) for c in pointer.channels if c._name in channel_in_view] + return [deepcopy(c) for c in pointer.channels if c.name in channel_in_view] def _synapses_in_view(self, pointer: Union[Module, View]): """Set synapses to show only those in view.""" names = self.edges["type"].unique() - return [deepcopy(syn) for syn in pointer.synapses if syn._name in names] + return [deepcopy(syn) for syn in pointer.synapses if syn.name in names] def _nbranches_per_cell_in_view(self) -> np.ndarray: cell_nodes = self.nodes.groupby("global_cell_index") diff --git a/jaxley/modules/network.py b/jaxley/modules/network.py index 05f6fabe..4a4e4f9d 100644 --- a/jaxley/modules/network.py +++ b/jaxley/modules/network.py @@ -303,7 +303,7 @@ def _synapse_currents( diff = 1e-3 num_comp = len(voltages) - synapse_current_states = {f"i_{s._name}": zeros for s in syn_channels} + synapse_current_states = {f"i_{s.name}": zeros for s in syn_channels} for i, group in edges.groupby("type_ind"): synapse = syn_channels[i] pre_inds = group["pre_global_comp_index"].to_numpy() @@ -340,15 +340,15 @@ def _synapse_currents( syn_const_terms = syn_const_terms.at[:].add(-gathered_syn_currents[1]) # Save the current (for the unperturbed voltage) as a state that will # also be passed to the state update. - synapse_current_states[f"i_{synapse._name}"] = ( - synapse_current_states[f"i_{synapse._name}"] + synapse_current_states[f"i_{synapse.name}"] = ( + synapse_current_states[f"i_{synapse.name}"] .at[post_inds] .add(synapse_currents_dist[0]) ) # Copy the currents into the `state` dictionary such that they can be # recorded and used by `Channel.update_states()`. - for name in [s._name for s in self.synapses]: + for name in [s.name for s in self.synapses]: states[f"i_{name}"] = synapse_current_states[f"i_{name}"] return states, (syn_voltage_terms, syn_const_terms) @@ -474,14 +474,14 @@ def vis( return ax def _infer_synapse_type_ind(self, synapse_name): - syn_names = [s._name for s in self.base.synapses] + syn_names = [s.name for s in self.base.synapses] is_new_type = False if synapse_name in syn_names else True type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name) return type_ind, is_new_type def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type): # Add synapse types to the module and infer their unique identifier. - synapse_name = synapse_type._name + synapse_name = synapse_type.name synapse_current_name = f"i_{synapse_name}" type_ind, is_new = self._infer_synapse_type_ind(synapse_name) if is_new: # synapse is not known