Skip to content

Commit

Permalink
doc: add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
jnsbck committed Dec 16, 2024
1 parent 171246e commit b890044
Showing 1 changed file with 41 additions and 7 deletions.
48 changes: 41 additions & 7 deletions jaxley/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,16 @@ def _gather_channels_from_constituents(self, constituents: List):
self.base.nodes.loc[self.nodes[name].isna(), name] = False

def _prepare_for_jax(self):
"""Prepare the module for simulation with JAX.
This function has to be run inside or before `to_jax`. It's main purpose is to;
1. Prepare the lookup of indices of states, parameters and mechanisms.
2. Add index attributes to mechanisms (i.e. where was it inserted) and also keep
track of states / parameters that are also shared by other mechanisms.
Adds `_inds_of_state_param(key: str)` to the module and also adds `indices` and
`_jax_inds` to the mechanisms.
"""
# prepare lookup of indices of states, parameters and mechanisms
global_params = ["radius", "length", "axial_resistivity", "capacitance"]
global_states = ["v"]
Expand Down Expand Up @@ -1248,9 +1258,18 @@ def _get_state_names(self) -> Tuple[List, List]:
)

def _iter_states_params(
self, params=False, states=False, currents=False
self, params: bool = False, states: bool = False, currents: bool = False
) -> Tuple[str, np.ndarray]: # type: ignore
# assert that either params or states is True
"""Iterate over all states and parameters.
Args:
params: Whether to iterate over parameters.
states: Whether to iterate over states.
currents: Whether to iterate over currents.
Yields:
The key and the indices of the states / parameters.
"""
assert params or states or currents, "Select either params / states / currents."
all_mechs = self.channels + self.synapses

Expand Down Expand Up @@ -1285,12 +1304,27 @@ def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:
def _get_all_states_params(
self,
pstate: List[Dict],
voltage_solver=None,
delta_t=None,
all_params=None,
params=False,
states=False,
voltage_solver: str = None,
delta_t: float = None,
all_params: Dict[str, jnp.ndarray] = None,
params: bool = False,
states: bool = False,
) -> Dict[str, jnp.ndarray]:
"""Get all parameters and/or states of the module.
Common backbone of both `get_all_parameters()` and `get_all_states()`.
Args:
pstate: The state of the trainable parameters.
voltage_solver: The voltage solver that is used.
delta_t: The stepsize.
all_params: All parameters of the module.
params: Whether to get the parameters.
states: Whether to get the states.
Returns:
A dictionary of all parameters and/or states of the module.
"""
states_params = {}
pkeys = {}
for i, p in enumerate(pstate):
Expand Down

0 comments on commit b890044

Please sign in to comment.