diff --git a/dev/index.html b/dev/index.html index 44b47cc6..0ea1051c 100644 --- a/dev/index.html +++ b/dev/index.html @@ -936,6 +936,10 @@
+The official documentation for Jaxley has moved to jaxley.readthedocs.io. +The website you are currently on will be taken down in the future.
+
Jaxley
is a differentiable simulator for biophysical neuron models in JAX. Its key features are:
Jaxley
is a differentiable simulator for biophysical neuron models in JAX. Its key features are:
jit
-compilation, making it as fast as other packages while being fully written in python Jaxley
allows to simulate biophysical neuron models on CPU, GPU, or TPU:
import matplotlib.pyplot as plt\nfrom jax import config\n\nimport jaxley as jx\nfrom jaxley.channels import HH\n\nconfig.update(\"jax_platform_name\", \"cpu\") # Or \"gpu\" / \"tpu\".\n\ncell = jx.Cell() # Define cell.\ncell.insert(HH()) # Insert channels.\n\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.stimulate(current) # Stimulate with step current.\ncell.record(\"v\") # Record voltage.\n\nv = jx.integrate(cell) # Run simulation.\nplt.plot(v.T) # Plot voltage trace.\n
If you want to learn more, we have tutorials on how to:
Jaxley
is available on pypi
:
pip install jaxley\n
This will install Jaxley
with CPU support. If you want GPU support, follow the instructions on the JAX
github repository to install JAX
with GPU support (in addition to installing Jaxley
). For example, for NVIDIA GPUs, run pip install -U \"jax[cuda12]\"\n
"},{"location":"#feedback-and-contributions","title":"Feedback and Contributions","text":"We welcome any feedback on how Jaxley
is working for your neuron models and are happy to receive bug reports, pull requests and other feedback (see contribute). We wish to maintain a positive community, please read our Code of Conduct.
Apache License Version 2.0 (Apache-2.0)
"},{"location":"#citation","title":"Citation","text":"If you use Jaxley
, consider citing the corresponding paper:
@article{deistler2024differentiable,\n doi = {10.1101/2024.08.21.608979},\n year = {2024},\n publisher = {Cold Spring Harbor Laboratory},\n author = {Deistler, Michael and Kadhim, Kyra L. and Pals, Matthijs and Beck, Jonas and Huang, Ziwei and Gloeckler, Manuel and Lappalainen, Janne K. and Schr{\\\"o}der, Cornelius and Berens, Philipp and Gon{\\c c}alves, Pedro J. and Macke, Jakob H.},\n title = {Differentiable simulation enables large-scale training of detailed biophysical models of neural dynamics},\n journal = {bioRxiv}\n}\n
"},{"location":"code_of_conduct/","title":"Contributor Covenant Code of Conduct","text":""},{"location":"code_of_conduct/#our-pledge","title":"Our Pledge","text":"We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
"},{"location":"code_of_conduct/#our-standards","title":"Our Standards","text":"Examples of behavior that contributes to a positive environment for our community include:
Examples of unacceptable behavior include:
Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.
"},{"location":"code_of_conduct/#scope","title":"Scope","text":"This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
"},{"location":"code_of_conduct/#enforcement","title":"Enforcement","text":"Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting jaxley
developer Michael Deistler via email (michael.deistler@uni-tuebingen.de). All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the reporter of any incident.
"},{"location":"code_of_conduct/#enforcement-guidelines","title":"Enforcement Guidelines","text":"Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
"},{"location":"code_of_conduct/#1-correction","title":"1. Correction","text":"Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
Consequence: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.
"},{"location":"code_of_conduct/#2-warning","title":"2. Warning","text":"Community Impact: A violation through a single incident or series of actions.
Consequence: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.
"},{"location":"code_of_conduct/#3-temporary-ban","title":"3. Temporary Ban","text":"Community Impact: A serious violation of community standards, including sustained inappropriate behavior.
Consequence: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
"},{"location":"code_of_conduct/#4-permanent-ban","title":"4. Permanent Ban","text":"Community Impact: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
Consequence: A permanent ban from any sort of public interaction within the community.
"},{"location":"code_of_conduct/#attribution","title":"Attribution","text":"This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.
Community Impact Guidelines were inspired by Mozilla\u2019s code of conduct enforcement ladder.
For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.
"},{"location":"contribute/","title":"Guide","text":""},{"location":"contribute/#user-experiences-bugs-and-feature-requests","title":"User experiences, bugs, and feature requests","text":"To report bugs and suggest features (including better documentation), please head over to issues on GitHub.
"},{"location":"contribute/#code-contributions","title":"Code contributions","text":"In general, we use pull requests to make changes to Jaxley
. So, if you are planning to make a contribution, please fork, create a feature branch and then make a PR from your feature branch to the upstream Jaxley
(details).
Clone the repo and install via setup.py
using pip install -e \".[dev]\"
(the dev flag installs development and testing dependencies).
For docstrings and comments, we use Google Style.
Code needs to pass through the following tools, which are installed alongside Jaxley
:
black: Automatic code formatting for Python. You can run black manually from the console using black .
in the top directory of the repository, which will format all files.
isort: Used to consistently order imports. You can run isort manually from the console using isort
in the top directory.
black
and isort
are checked as part of our CI actions. If these checks fail please make sure you have installed the latest versions for each of them and run them locally.
Most of the documentation is written in markdown (basic markdown guide).
You can directly fix mistakes and suggest clearer formulations in markdown files simply by initiating a PR on through GitHub. Click on documentation file and look for the little pencil at top right.
"},{"location":"credits/","title":"Credits","text":"Jaxley
is a collaborative project between the groups of Jakob Macke (Uni T\u00fcbingen), Pedro Gon\u00e7alves (KU Leuven / NERF), and Philipp Berens (Uni T\u00fcbingen).
Jaxley
is licensed under the Apache License Version 2.0 (Apache-2.0) and
Copyright (C) 2024 Michael Deistler, Jakob H. Macke, Pedro J. Goncalves, Philipp Berens.
"},{"location":"credits/#important-dependencies-and-prior-art","title":"Important dependencies and prior art","text":"This work was supported by the German Research Foundation (DFG) through Germany\u2019s Excellence Strategy (EXC 2064 \u2013 Project number 390727645) and the CRC 1233 \u201cRobust Vision\u201d, the German Federal Ministry of Education and Research (Tu\u0308bingen AI Center, FKZ: 01IS18039A), the \u2018Certification and Foundations of Safe Machine Learning Systems in Healthcare\u2019 project funded by the Carl Zeiss Foundation, and the European Union (ERC, \u201cDeepCoMechTome\u201d, ref. 101089288, \u201cNextMechMod\u201d, ref. 101039115).
"},{"location":"faq/","title":"Frequently asked questions","text":"Jaxley
? Jaxley
use? See also the discussion page and the issue tracker on the Jaxley
GitHub repository for recent questions and problems.
Jaxley
is available on PyPI
:
pip install jaxley\n
This will install Jaxley
with CPU support. If you want GPU support, follow the instructions on the JAX
github repository to install JAX
with GPU support (in addition to installing Jaxley
). For example, for NVIDIA GPUs, run pip install -U \"jax[cuda12]\"\n
"},{"location":"install/#install-from-source","title":"Install from source","text":"You can also install Jaxley
from source:
git clone https://github.com/jaxleyverse/jaxley.git\ncd jaxley\npip install -e .\n
Note that pip>=21.3
is required to install the editable version with pyproject.toml
see pip docs.
Jaxley
use?","text":"Jaxley
uses the same units as the NEURON
simulator, which are listed here.
All module
s (i.e., compartments, branches, cells, and networks) in Jaxley
can be saved and loaded with pickle:
import jaxley as jx\nimport pickle\n\n# ... define network, cell, etc.\nnetwork = jx.Network([cell1, cell2])\n\n# Save.\nwith open(\"path/to/file.pkl\", \"wb\") as handle:\n pickle.dump(network, handle)\n\n# Load.\nwith open(\"path/to/file.pkl\", \"rb\") as handle:\n network = pickle.load(handle)\n
"},{"location":"faq/question_03/","title":"What kinds of models can be implemented in Jaxley
?","text":"Jaxley
focuses on biophysical, Hodgkin-Huxley-type models. You can think of Jaxley
like the NEURON
simulator written in JAX
.
Jaxley
allows to simulate the following types of models, as well as networks thereof:
For all of these models, Jaxley
is flexible and accurate. For example, it can flexibly add new channel models, use different kinds of synapses (conductance-based, tanh, \u2026), and it can insert different kinds of channels in different branches (or compartments) within single cells. Like NEURON
, Jaxley
implements a backward-Euler solver for stable numerical solution of multi-compartment neurons.
However, Jaxley
does not implement the following types of models:
connect(pre, post, synapse_type)
","text":"Connect two compartments with a chemical synapse.
The pre- and postsynaptic compartments must be different compartments of the same network.
Parameters:
Name Type Description Defaultpre
View
View of the presynaptic compartment.
requiredpost
View
View of the postsynaptic compartment.
requiredsynapse_type
Synapse
The synapse to append
required Source code injaxley/connect.py
def connect(\n pre: \"View\",\n post: \"View\",\n synapse_type: \"Synapse\",\n):\n \"\"\"Connect two compartments with a chemical synapse.\n\n The pre- and postsynaptic compartments must be different compartments of the\n same network.\n\n Args:\n pre: View of the presynaptic compartment.\n post: View of the postsynaptic compartment.\n synapse_type: The synapse to append\n \"\"\"\n assert is_same_network(\n pre, post\n ), \"Pre and post compartments must be part of the same network.\"\n\n pre.base._append_multiple_synapses(pre.nodes, post.nodes, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.connectivity_matrix_connect","title":"connectivity_matrix_connect(pre_cell_view, post_cell_view, synapse_type, connectivity_matrix)
","text":"Appends multiple connections which build a custom connected network.
Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Parameters:
Name Type Description Defaultpre_cell_view
View
View of the presynaptic cell.
requiredpost_cell_view
View
View of the postsynaptic cell.
requiredsynapse_type
Synapse
The synapse to append.
requiredconnectivity_matrix
ndarray[bool]
A boolean matrix indicating the connections between cells.
required Source code injaxley/connect.py
def connectivity_matrix_connect(\n pre_cell_view: \"View\",\n post_cell_view: \"View\",\n synapse_type: \"Synapse\",\n connectivity_matrix: np.ndarray[bool],\n):\n \"\"\"Appends multiple connections which build a custom connected network.\n\n Connects pre- and postsynaptic cells according to a custom connectivity matrix.\n Entries > 0 in the matrix indicate a connection between the corresponding cells.\n Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n Args:\n pre_cell_view: View of the presynaptic cell.\n post_cell_view: View of the postsynaptic cell.\n synapse_type: The synapse to append.\n connectivity_matrix: A boolean matrix indicating the connections between cells.\n \"\"\"\n # Get pre- and postsynaptic cell indices.\n pre_cell_inds = pre_cell_view._cells_in_view\n post_cell_inds = post_cell_view._cells_in_view\n # setting scope ensure that this works indep of current scope\n pre_nodes = pre_cell_view.scope(\"local\").branch(0).comp(0).nodes\n pre_nodes[\"index\"] = pre_nodes.index\n pre_cell_nodes = pre_nodes.set_index(\"global_cell_index\")\n\n assert connectivity_matrix.shape == (\n len(pre_cell_inds),\n len(post_cell_inds),\n ), \"Connectivity matrix must have shape (num_pre, num_post).\"\n assert connectivity_matrix.dtype == bool, \"Connectivity matrix must be boolean.\"\n\n # get connection pairs from connectivity matrix\n from_idx, to_idx = np.where(connectivity_matrix)\n pre_cell_inds = pre_cell_inds[from_idx]\n post_cell_inds = post_cell_inds[to_idx]\n\n # Sample random postsynaptic compartments (global comp indices).\n global_post_indices = np.hstack(\n [\n sample_comp(post_cell_view.scope(\"global\").cell(cell_idx))\n for cell_idx in post_cell_inds\n ]\n )\n post_rows = post_cell_view.nodes.loc[global_post_indices]\n\n # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, \"index\"].to_numpy()\n pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes\n\n pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.fully_connect","title":"fully_connect(pre_cell_view, post_cell_view, synapse_type)
","text":"Appends multiple connections which build a fully connected layer.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Parameters:
Name Type Description Defaultpre_cell_view
View
View of the presynaptic cell.
requiredpost_cell_view
View
View of the postsynaptic cell.
requiredsynapse_type
Synapse
The synapse to append.
required Source code injaxley/connect.py
def fully_connect(\n pre_cell_view: \"View\",\n post_cell_view: \"View\",\n synapse_type: \"Synapse\",\n):\n \"\"\"Appends multiple connections which build a fully connected layer.\n\n Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n Args:\n pre_cell_view: View of the presynaptic cell.\n post_cell_view: View of the postsynaptic cell.\n synapse_type: The synapse to append.\n \"\"\"\n # Get pre- and postsynaptic cell indices.\n num_pre = len(pre_cell_view._cells_in_view)\n num_post = len(post_cell_view._cells_in_view)\n\n # Infer indices of (random) postsynaptic compartments.\n global_post_indices = (\n post_cell_view.nodes.groupby(\"global_cell_index\")\n .sample(num_pre, replace=True)\n .index.to_numpy()\n )\n global_post_indices = global_post_indices.reshape((-1, num_pre), order=\"F\").ravel()\n post_rows = post_cell_view.nodes.loc[global_post_indices]\n\n # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n pre_rows = pre_cell_view.scope(\"local\").branch(0).comp(0).nodes.copy()\n # Repeat rows `num_post` times. See SO 50788508.\n pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)\n\n pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.is_same_network","title":"is_same_network(pre, post)
","text":"Check if views are from the same network.
Source code injaxley/connect.py
def is_same_network(pre: \"View\", post: \"View\") -> bool:\n \"\"\"Check if views are from the same network.\"\"\"\n is_in_net = \"network\" in pre.base.__class__.__name__.lower()\n is_in_same_net = pre.base is post.base\n return is_in_net and is_in_same_net\n
"},{"location":"reference/connect/#jaxley.connect.sample_comp","title":"sample_comp(cell_view, num=1, replace=True)
","text":"Sample a compartment from a cell.
Returns View with shape (num, num_cols).
Source code injaxley/connect.py
def sample_comp(cell_view: \"View\", num: int = 1, replace=True) -> \"CompartmentView\":\n \"\"\"Sample a compartment from a cell.\n\n Returns View with shape (num, num_cols).\"\"\"\n return np.random.choice(cell_view._comps_in_view, num, replace=replace)\n
"},{"location":"reference/connect/#jaxley.connect.sparse_connect","title":"sparse_connect(pre_cell_view, post_cell_view, synapse_type, p)
","text":"Appends multiple connections which build a sparse, randomly connected layer.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Parameters:
Name Type Description Defaultpre_cell_view
View
View of the presynaptic cell.
requiredpost_cell_view
View
View of the postsynaptic cell.
requiredsynapse_type
Synapse
The synapse to append.
requiredp
float
Probability of connection.
required Source code injaxley/connect.py
def sparse_connect(\n pre_cell_view: \"View\",\n post_cell_view: \"View\",\n synapse_type: \"Synapse\",\n p: float,\n):\n \"\"\"Appends multiple connections which build a sparse, randomly connected layer.\n\n Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n Args:\n pre_cell_view: View of the presynaptic cell.\n post_cell_view: View of the postsynaptic cell.\n synapse_type: The synapse to append.\n p: Probability of connection.\n \"\"\"\n # Get pre- and postsynaptic cell indices.\n pre_cell_inds = pre_cell_view._cells_in_view\n post_cell_inds = post_cell_view._cells_in_view\n num_pre = len(pre_cell_inds)\n num_post = len(post_cell_inds)\n\n num_connections = np.random.binomial(num_pre * num_post, p)\n pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)\n post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)\n\n # Sort the synapses only for convenience of inspecting `.edges`.\n sorting = np.argsort(pre_syn_neurons)\n pre_syn_neurons = pre_syn_neurons[sorting]\n post_syn_neurons = post_syn_neurons[sorting]\n\n # Post-synapse is a randomly chosen branch and compartment.\n global_post_indices = [\n sample_comp(post_cell_view.scope(\"global\").cell(cell_idx))\n for cell_idx in post_syn_neurons\n ]\n global_post_indices = (\n np.hstack(global_post_indices) if len(global_post_indices) > 1 else []\n )\n post_rows = post_cell_view.base.nodes.loc[global_post_indices]\n\n # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n global_pre_indices = pre_cell_view.base._cumsum_ncomp_per_cell[pre_syn_neurons]\n pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]\n\n if len(pre_rows) > 0:\n pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/integration/","title":"Simulation","text":""},{"location":"reference/integration/#jaxley.integrate.add_clamps","title":"add_clamps(externals, external_inds, data_clamps=None)
","text":"Adds clamps to the external inputs.
Parameters:
Name Type Description Defaultexternals
Dict
Current external inputs.
requiredexternal_inds
Dict
Current external indices.
requireddata_clamps
Optional[Tuple[str, ndarray, DataFrame]]
Additional data clamps. Defaults to None.
None
Returns:
Type DescriptionTuple[Dict, Dict]
Tuple[Dict, Dict]: Updated external inputs and indices.
Source code injaxley/integrate.py
def add_clamps(\n externals: Dict,\n external_inds: Dict,\n data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,\n) -> Tuple[Dict, Dict]:\n \"\"\"Adds clamps to the external inputs.\n\n Args:\n externals (Dict): Current external inputs.\n external_inds (Dict): Current external indices.\n data_clamps (Optional[Tuple[str, jnp.ndarray, pd.DataFrame]], optional): Additional data clamps. Defaults to None.\n\n Returns:\n Tuple[Dict, Dict]: Updated external inputs and indices.\n \"\"\"\n # If a clamp is inserted, add it to the external inputs.\n if data_clamps is not None:\n state_name, clamps, inds = data_clamps\n if state_name in externals.keys():\n externals[state_name] = jnp.concatenate([externals[state_name], clamps])\n external_inds[state_name] = jnp.concatenate(\n [external_inds[state_name], inds.index.to_numpy()]\n )\n else:\n externals[state_name] = clamps\n external_inds[state_name] = inds.index.to_numpy()\n\n return externals, external_inds\n
"},{"location":"reference/integration/#jaxley.integrate.add_stimuli","title":"add_stimuli(externals, external_inds, data_stimuli=None)
","text":"Extends the external inputs with the stimuli.
Parameters:
Name Type Description Defaultexternals
Dict
Current external inputs.
requiredexternal_inds
Dict
Current external indices.
requireddata_stimuli
Optional[Tuple[ndarray, DataFrame]]
Additional data stimuli. Defaults to None.
None
Returns:
Type DescriptionTuple[Dict, Dict]
Tuple[Dict, Dict]: Updated external inputs and indices.
Source code injaxley/integrate.py
def add_stimuli(\n externals: Dict,\n external_inds: Dict,\n data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n) -> Tuple[Dict, Dict]:\n \"\"\"Extends the external inputs with the stimuli.\n\n Args:\n externals (Dict): Current external inputs.\n external_inds (Dict): Current external indices.\n data_stimuli (Optional[Tuple[jnp.ndarray, pd.DataFrame]], optional): Additional data stimuli. Defaults to None.\n\n Returns:\n Tuple[Dict, Dict]: Updated external inputs and indices.\n \"\"\"\n # If stimulus is inserted, add it to the external inputs.\n if \"i\" in externals.keys() or data_stimuli is not None:\n if \"i\" in externals.keys():\n if data_stimuli is not None:\n externals[\"i\"] = jnp.concatenate([externals[\"i\"], data_stimuli[1]])\n external_inds[\"i\"] = jnp.concatenate(\n [external_inds[\"i\"], data_stimuli[2].index.to_numpy()]\n )\n else:\n externals[\"i\"] = data_stimuli[1]\n external_inds[\"i\"] = data_stimuli[2].index.to_numpy()\n\n return externals, external_inds\n
"},{"location":"reference/integration/#jaxley.integrate.build_init_and_step_fn","title":"build_init_and_step_fn(module, voltage_solver='jaxley.stone', solver='bwd_euler')
","text":"This function returns the init_fn
and step_fn
which initialize the parameters and states of the neuron model and then step through the model
Parameters:
Name Type Description Defaultmodule
Module
A Module
object that e.g. a cell.
voltage_solver
str
Voltage solver used in step. Defaults to \u201cjaxley.stone\u201d.
'jaxley.stone'
solver
str
ODE solver. Defaults to \u201cbwd_euler\u201d.
'bwd_euler'
Returns:
Type DescriptionTuple[Callable, Callable]
init_fn, step_fn: Functions that initialize the state and parameters, and perform a single integration step, respectively.
Source code injaxley/integrate.py
def build_init_and_step_fn(\n module: Module,\n voltage_solver: str = \"jaxley.stone\",\n solver: str = \"bwd_euler\",\n) -> Tuple[Callable, Callable]:\n \"\"\"This function returns the `init_fn` and `step_fn` which initialize the\n parameters and states of the neuron model and then step through the model\n\n Args:\n module (Module): A `Module` object that e.g. a cell.\n voltage_solver (str, optional): Voltage solver used in step. Defaults to \"jaxley.stone\".\n solver (str, optional): ODE solver. Defaults to \"bwd_euler\".\n\n Returns:\n init_fn, step_fn: Functions that initialize the state and parameters, and perform\n a single integration step, respectively.\n \"\"\"\n # Initialize the external inputs and their indices.\n external_inds = module.external_inds.copy()\n\n def init_fn(\n params: List[Dict[str, jnp.ndarray]],\n all_states: Optional[Dict] = None,\n param_state: Optional[List[Dict]] = None,\n delta_t: float = 0.025,\n ) -> Tuple[Dict, Dict]:\n \"\"\"Initializes the parameters and states of the neuron model.\n\n Args:\n params (List[Dict[str, jnp.ndarray]]): List of trainable parameters.\n all_states (Optional[Dict], optional): State if alread initialized. Defaults to None.\n param_state (Optional[List[Dict]], optional): Parameters returned by `data_set`.. Defaults to None.\n delta_t (float, optional): Step size. Defaults to 0.025.\n\n Returns:\n Tuple[Dict, Dict]: All states and parameters.\n \"\"\"\n # Make the `trainable_params` of the same shape as the `param_state`, such that\n # they can be processed together by `get_all_parameters`.\n pstate = params_to_pstate(params, module.indices_set_by_trainables)\n if param_state is not None:\n pstate += param_state\n\n all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)\n all_states = (\n module.get_all_states(pstate, all_params, delta_t)\n if all_states is None\n else all_states\n )\n return all_states, all_params\n\n def step_fn(\n all_states: Dict,\n all_params: Dict,\n externals: Dict,\n external_inds: Dict = external_inds,\n delta_t: float = 0.025,\n ) -> Dict:\n \"\"\"Performs a single integration step with step size delta_t.\n\n Args:\n all_states (Dict): Current state of the neuron model.\n all_params (Dict): Current parameters of the neuron model.\n externals (Dict): External inputs.\n external_inds (Dict, optional): External indices. Defaults to `module.external_inds`.\n delta_t (float, optional): Time step. Defaults to 0.025.\n\n Returns:\n Dict: Updated states.\n \"\"\"\n state = all_states\n state = module.step(\n state,\n delta_t,\n external_inds,\n externals,\n params=all_params,\n solver=solver,\n voltage_solver=voltage_solver,\n )\n return state\n\n return init_fn, step_fn\n
"},{"location":"reference/integration/#jaxley.integrate.integrate","title":"integrate(module, params=[], *, param_state=None, data_stimuli=None, data_clamps=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.stone', checkpoint_lengths=None, all_states=None, return_states=False)
","text":"Solves ODE and simulates neuron model.
Parameters:
Name Type Description Defaultparams
List[Dict[str, ndarray]]
Trainable parameters returned by get_parameters()
.
[]
param_state
Optional[List[Dict]]
Parameters returned by data_set
.
None
data_stimuli
Optional[Tuple[ndarray, DataFrame]]
Outputs of .data_stimulate()
, only needed if stimuli change across function calls.
None
data_clamps
Optional[Tuple[str, ndarray, DataFrame]]
Outputs of .data_clamp()
, only needed if clamps change across function calls.
None
t_max
Optional[float]
Duration of the simulation in milliseconds. If t_max
is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If t_max
is smaller, then the stimulus with be truncated.
None
delta_t
float
Time step of the solver in milliseconds.
0.025
solver
str
Which ODE solver to use. Either of [\u201cfwd_euler\u201d, \u201cbwd_euler\u201d, \u201ccrank_nicolson\u201d].
'bwd_euler'
tridiag_solver
Algorithm to solve tridiagonal systems. The different options only affect bwd_euler
and crank_nicolson
solvers. Either of [\u201cstone\u201d, \u201cthomas\u201d], where stone
is much faster on GPU for long branches with many compartments and thomas
is slightly faster on CPU (thomas
is used in NEURON).
checkpoint_lengths
Optional[List[int]]
Number of timesteps at every level of checkpointing. The prod(checkpoint_lengths)
must be larger or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths)
timesteps, and the result is posthoc truncated to the desired simulation length. Therefore, a poor choice of checkpoint_lengths
can lead to longer simulation time. If None
, no checkpointing is applied.
None
all_states
Optional[Dict]
An optional initial state that was returned by a previous jx.integrate(..., return_states=True)
run. Overrides potentially trainable initial states.
None
return_states
bool
If True, it returns all states such that the current state of the Module
can be set with set_states
.
False
Source code in jaxley/integrate.py
def integrate(\n module: Module,\n params: List[Dict[str, jnp.ndarray]] = [],\n *,\n param_state: Optional[List[Dict]] = None,\n data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,\n t_max: Optional[float] = None,\n delta_t: float = 0.025,\n solver: str = \"bwd_euler\",\n voltage_solver: str = \"jaxley.stone\",\n checkpoint_lengths: Optional[List[int]] = None,\n all_states: Optional[Dict] = None,\n return_states: bool = False,\n) -> jnp.ndarray:\n \"\"\"\n Solves ODE and simulates neuron model.\n\n Args:\n params: Trainable parameters returned by `get_parameters()`.\n param_state: Parameters returned by `data_set`.\n data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change\n across function calls.\n data_clamps: Outputs of `.data_clamp()`, only needed if clamps change across\n function calls.\n t_max: Duration of the simulation in milliseconds. If `t_max` is greater than\n the length of the stimulus input, the stimulus will be padded at the end\n with zeros. If `t_max` is smaller, then the stimulus with be truncated.\n delta_t: Time step of the solver in milliseconds.\n solver: Which ODE solver to use. Either of [\"fwd_euler\", \"bwd_euler\",\n \"crank_nicolson\"].\n tridiag_solver: Algorithm to solve tridiagonal systems. The different options\n only affect `bwd_euler` and `crank_nicolson` solvers. Either of [\"stone\",\n \"thomas\"], where `stone` is much faster on GPU for long branches\n with many compartments and `thomas` is slightly faster on CPU (`thomas` is\n used in NEURON).\n checkpoint_lengths: Number of timesteps at every level of checkpointing. The\n `prod(checkpoint_lengths)` must be larger or equal to the desired number of\n simulated timesteps. Warning: the simulation is run for\n `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated\n to the desired simulation length. Therefore, a poor choice of\n `checkpoint_lengths` can lead to longer simulation time. If `None`, no\n checkpointing is applied.\n all_states: An optional initial state that was returned by a previous\n `jx.integrate(..., return_states=True)` run. Overrides potentially\n trainable initial states.\n return_states: If True, it returns all states such that the current state of\n the `Module` can be set with `set_states`.\n \"\"\"\n\n assert module.initialized, \"Module is not initialized, run `._initialize()`.\"\n module.to_jax() # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n\n # Initialize the external inputs and their indices.\n externals = module.externals.copy()\n external_inds = module.external_inds.copy()\n\n # If stimulus is inserted, add it to the external inputs.\n externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)\n\n # If a clamp is inserted, add it to the external inputs.\n externals, external_inds = add_clamps(externals, external_inds, data_clamps)\n\n if not externals.keys():\n # No stimulus was inserted and no clamp was set.\n assert (\n t_max is not None\n ), \"If no stimulus or clamp are inserted you have to specify the simulation duration at `jx.integrate(..., t_max=)`.\"\n\n for key in externals.keys():\n externals[key] = externals[key].T # Shape `(time, num_stimuli)`.\n\n if module.recordings.empty:\n raise ValueError(\"No recordings are set. Please set them.\")\n rec_inds = module.recordings.rec_index.to_numpy()\n rec_states = module.recordings.state.to_numpy()\n\n # Shorten or pad stimulus depending on `t_max`.\n if t_max is not None:\n t_max_steps = int(t_max // delta_t + 1)\n\n # Pad or truncate the stimulus.\n for key in externals.keys():\n if t_max_steps > externals[key].shape[0]:\n if key == \"i\":\n pad = jnp.zeros(\n (t_max_steps - externals[\"i\"].shape[0], externals[\"i\"].shape[1])\n )\n externals[\"i\"] = jnp.concatenate((externals[\"i\"], pad))\n else:\n raise NotImplementedError(\n \"clamp must be at least as long as simulation.\"\n )\n else:\n externals[key] = externals[key][:t_max_steps, :]\n\n init_fn, step_fn = build_init_and_step_fn(\n module, voltage_solver=voltage_solver, solver=solver\n )\n all_states, all_params = init_fn(params, all_states, param_state, delta_t)\n\n def _body_fun(state, externals):\n state = step_fn(state, all_params, externals, external_inds, delta_t)\n recs = jnp.asarray(\n [\n state[rec_state][rec_ind]\n for rec_state, rec_ind in zip(rec_states, rec_inds)\n ]\n )\n return state, recs\n\n # If necessary, pad the stimulus with zeros in order to simulate sufficiently long.\n # The total simulation length will be `prod(checkpoint_lengths)`. At the end, we\n # return only the first `nsteps_to_return` elements (plus the initial state).\n if externals:\n example_key = list(externals.keys())[0]\n nsteps_to_return = len(externals[example_key])\n else:\n nsteps_to_return = t_max_steps\n\n if checkpoint_lengths is None:\n checkpoint_lengths = [nsteps_to_return]\n length = nsteps_to_return\n else:\n length = prod(checkpoint_lengths)\n size_difference = length - nsteps_to_return\n assert (\n nsteps_to_return <= length\n ), \"The desired simulation duration is longer than `prod(nested_length)`.\"\n if externals:\n dummy_external = jnp.zeros(\n (size_difference, externals[example_key].shape[1])\n )\n for key in externals.keys():\n externals[key] = jnp.concatenate([externals[key], dummy_external])\n\n # Record the initial state.\n init_recs = jnp.asarray(\n [\n all_states[rec_state][rec_ind]\n for rec_state, rec_ind in zip(rec_states, rec_inds)\n ]\n )\n init_recording = jnp.expand_dims(init_recs, axis=0)\n\n # Run simulation.\n all_states, recordings = nested_checkpoint_scan(\n _body_fun,\n all_states,\n externals,\n length=length,\n nested_lengths=checkpoint_lengths,\n )\n recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T\n return (recs, all_states) if return_states else recs\n
"},{"location":"reference/integration/#jaxley.solver_gate.exponential_euler","title":"exponential_euler(x, dt, x_inf, x_tau)
","text":"An exact solver for the linear dynamical system dx = -(x - x_inf) / x_tau
.
jaxley/solver_gate.py
def exponential_euler(\n x: jnp.ndarray,\n dt: float,\n x_inf: jnp.ndarray,\n x_tau: jnp.ndarray,\n):\n \"\"\"An exact solver for the linear dynamical system `dx = -(x - x_inf) / x_tau`.\"\"\"\n exp_term = save_exp(-dt / x_tau)\n return x * exp_term + x_inf * (1.0 - exp_term)\n
"},{"location":"reference/integration/#jaxley.solver_gate.save_exp","title":"save_exp(x, max_value=20.0)
","text":"Clip the input to a maximum value and return its exponential.
Source code injaxley/solver_gate.py
def save_exp(x, max_value: float = 20.0):\n \"\"\"Clip the input to a maximum value and return its exponential.\"\"\"\n x = jnp.clip(x, a_max=max_value)\n return jnp.exp(x)\n
"},{"location":"reference/integration/#jaxley.solver_gate.solve_inf_gate_exponential","title":"solve_inf_gate_exponential(x, dt, s_inf, tau_s)
","text":"solves dx/dt = (s_inf - x) / tau_s via exponential Euler
Parameters:
Name Type Description Defaultx
ndarray
gate variable
requireddt
float
time_delta
requireds_inf
ndarray
description
requiredtau_s
ndarray
description
requiredReturns:
Name Type Description_type_
updated gate
Source code injaxley/solver_gate.py
def solve_inf_gate_exponential(\n x: jnp.ndarray,\n dt: float,\n s_inf: jnp.ndarray,\n tau_s: jnp.ndarray,\n):\n \"\"\"solves dx/dt = (s_inf - x) / tau_s\n via exponential Euler\n\n Args:\n x (jnp.ndarray): gate variable\n dt (float): time_delta\n s_inf (jnp.ndarray): _description_\n tau_s (jnp.ndarray): _description_\n\n Returns:\n _type_: updated gate\n \"\"\"\n slope = -1.0 / tau_s\n exp_term = save_exp(slope * dt)\n return x * exp_term + s_inf * (1.0 - exp_term)\n
"},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_explicit","title":"step_voltage_explicit(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)
","text":"Solve one timestep of branched nerve equations with explicit (forward) Euler.
Source code injaxley/solver_voltage.py
def step_voltage_explicit(\n voltages: jnp.ndarray,\n voltage_terms: jnp.ndarray,\n constant_terms: jnp.ndarray,\n axial_conductances: jnp.ndarray,\n internal_node_inds: jnp.ndarray,\n sinks: jnp.ndarray,\n sources: jnp.ndarray,\n types: jnp.ndarray,\n ncomp_per_branch: jnp.ndarray,\n par_inds: jnp.ndarray,\n child_inds: jnp.ndarray,\n nbranches: int,\n solver: str,\n delta_t: float,\n idx: JaxleySolveIndexer,\n debug_states,\n) -> jnp.ndarray:\n \"\"\"Solve one timestep of branched nerve equations with explicit (forward) Euler.\"\"\"\n voltages = jnp.reshape(voltages, (nbranches, -1))\n voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))\n constant_terms = jnp.reshape(constant_terms, (nbranches, -1))\n\n update = _voltage_vectorfield(\n voltages,\n voltage_terms,\n constant_terms,\n types,\n sources,\n sinks,\n axial_conductances,\n par_inds,\n child_inds,\n nbranches,\n solver,\n delta_t,\n idx,\n debug_states,\n )\n new_voltates = voltages + delta_t * update\n return new_voltates.ravel(order=\"C\")\n
"},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_implicit_with_jaxley_spsolve","title":"step_voltage_implicit_with_jaxley_spsolve(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)
","text":"Solve one timestep of branched nerve equations with implicit (backward) Euler.
Source code injaxley/solver_voltage.py
def step_voltage_implicit_with_jaxley_spsolve(\n voltages: jnp.ndarray,\n voltage_terms: jnp.ndarray,\n constant_terms: jnp.ndarray,\n axial_conductances: jnp.ndarray,\n internal_node_inds: jnp.ndarray,\n sinks: jnp.ndarray,\n sources: jnp.ndarray,\n types: jnp.ndarray,\n ncomp_per_branch: jnp.ndarray,\n par_inds: jnp.ndarray,\n child_inds: jnp.ndarray,\n nbranches: int,\n solver: str,\n delta_t: float,\n idx: JaxleySolveIndexer,\n debug_states,\n):\n \"\"\"Solve one timestep of branched nerve equations with implicit (backward) Euler.\"\"\"\n # Build diagonals.\n c2c = np.isin(types, [0, 1, 2])\n total_ncomp = idx.cumsum_ncomp[-1]\n diags = jnp.ones(total_ncomp)\n\n # if-case needed because `.at` does not allow empty inputs, but the input is\n # empty for compartments.\n if len(sinks[c2c]) > 0:\n diags = diags.at[idx.mask(sinks[c2c])].add(delta_t * axial_conductances[c2c])\n\n diags = diags.at[idx.mask(internal_node_inds)].add(delta_t * voltage_terms)\n\n # Build solves.\n solves = jnp.zeros(total_ncomp)\n solves = solves.at[idx.mask(internal_node_inds)].add(\n voltages + delta_t * constant_terms\n )\n\n # Build upper and lower within the branch.\n c2c = types == 0 # c2c = compartment-to-compartment.\n\n # Build uppers.\n uppers = jnp.zeros(total_ncomp)\n upper_inds = sources[c2c] > sinks[c2c]\n sinks_upper = sinks[c2c][upper_inds]\n if len(sinks_upper) > 0:\n uppers = uppers.at[idx.mask(sinks_upper)].add(\n -delta_t * axial_conductances[c2c][upper_inds]\n )\n\n # Build lowers.\n lowers = jnp.zeros(total_ncomp)\n lower_inds = sources[c2c] < sinks[c2c]\n sinks_lower = sinks[c2c][lower_inds]\n if len(sinks_lower) > 0:\n lowers = lowers.at[idx.mask(sinks_lower)].add(\n -delta_t * axial_conductances[c2c][lower_inds]\n )\n\n # Build branchpoint conductances.\n branchpoint_conds_parents = axial_conductances[types == 1]\n branchpoint_conds_children = axial_conductances[types == 2]\n branchpoint_weights_parents = axial_conductances[types == 3]\n branchpoint_weights_children = axial_conductances[types == 4]\n all_branchpoint_vals = jnp.concatenate(\n [branchpoint_weights_parents, branchpoint_weights_children]\n )\n # Find unique group identifiers\n num_branchpoints = len(branchpoint_conds_parents)\n branchpoint_diags = -group_and_sum(\n all_branchpoint_vals, idx.branchpoint_group_inds, num_branchpoints\n )\n branchpoint_solves = jnp.zeros((num_branchpoints,))\n\n branchpoint_conds_children = -delta_t * branchpoint_conds_children\n branchpoint_conds_parents = -delta_t * branchpoint_conds_parents\n\n # Here, I move all child and parent indices towards a branchpoint into a larger\n # vector. This is wasteful, but it makes indexing much easier. JIT compiling\n # makes the speed difference negligible.\n # Children.\n bp_conds_children = jnp.zeros(nbranches)\n bp_weights_children = jnp.zeros(nbranches)\n # Parents.\n bp_conds_parents = jnp.zeros(nbranches)\n bp_weights_parents = jnp.zeros(nbranches)\n\n # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n # `len(inds) == 0` is the case for branches and compartments.\n if num_branchpoints > 0:\n bp_conds_children = bp_conds_children.at[child_inds].set(\n branchpoint_conds_children\n )\n bp_weights_children = bp_weights_children.at[child_inds].set(\n branchpoint_weights_children\n )\n bp_conds_parents = bp_conds_parents.at[par_inds].set(branchpoint_conds_parents)\n bp_weights_parents = bp_weights_parents.at[par_inds].set(\n branchpoint_weights_parents\n )\n\n # Triangulate the linear system of equations.\n (\n diags,\n lowers,\n solves,\n uppers,\n branchpoint_diags,\n branchpoint_solves,\n bp_weights_children,\n bp_conds_parents,\n ) = _triang_branched(\n lowers,\n diags,\n uppers,\n solves,\n bp_conds_children,\n bp_conds_parents,\n bp_weights_children,\n bp_weights_parents,\n branchpoint_diags,\n branchpoint_solves,\n solver,\n ncomp_per_branch,\n idx,\n debug_states,\n )\n\n # Backsubstitute the linear system of equations.\n (\n solves,\n lowers,\n diags,\n bp_weights_parents,\n branchpoint_solves,\n bp_conds_children,\n ) = _backsub_branched(\n lowers,\n diags,\n uppers,\n solves,\n bp_conds_children,\n bp_conds_parents,\n bp_weights_children,\n bp_weights_parents,\n branchpoint_diags,\n branchpoint_solves,\n solver,\n ncomp_per_branch,\n idx,\n debug_states,\n )\n return solves.ravel(order=\"C\")[idx.mask(internal_node_inds)]\n
"},{"location":"reference/mechanisms/","title":"Channels","text":""},{"location":"reference/mechanisms/#channel","title":"Channel","text":"Channel base class. All channels inherit from this class.
As in NEURON, a Channel
is considered a distributed process, which means that its conductances are to be specified in S/cm2
and its currents are to be specified in uA/cm2
.
jaxley/channels/channel.py
class Channel:\n \"\"\"Channel base class. All channels inherit from this class.\n\n As in NEURON, a `Channel` is considered a distributed process, which means that its\n conductances are to be specified in `S/cm2` and its currents are to be specified in\n `uA/cm2`.\"\"\"\n\n _name = None\n channel_params = None\n channel_states = None\n current_name = None\n\n def __init__(self, name: Optional[str] = None):\n contact = (\n \"If you have any questions, please reach out via email to \"\n \"michael.deistler@uni-tuebingen.de or create an issue on Github: \"\n \"https://github.com/jaxleyverse/jaxley/issues. Thank you!\"\n )\n if (\n not hasattr(self, \"current_is_in_mA_per_cm2\")\n or not self.current_is_in_mA_per_cm2\n ):\n raise ValueError(\n \"The channel you are using is deprecated. \"\n \"In Jaxley version 0.5.0, we changed the unit of the current returned \"\n \"by `compute_current` of channels from `uA/cm^2` to `mA/cm^2`. Please \"\n \"update your channel model (by dividing the resulting current by 1000) \"\n \"and set `self.current_is_in_mA_per_cm2=True` as the first line \"\n f\"in the `__init__()` method of your channel. {contact}\"\n )\n\n self._name = name if name else self.__class__.__name__\n\n @property\n def name(self) -> Optional[str]:\n \"\"\"The name of the channel (by default, this is the class name).\"\"\"\n return self._name\n\n def change_name(self, new_name: str):\n \"\"\"Change the channel name.\n\n Args:\n new_name: The new name of the channel.\n\n Returns:\n Renamed channel, such that this function is chainable.\n \"\"\"\n old_prefix = self._name + \"_\"\n new_prefix = new_name + \"_\"\n\n self._name = new_name\n self.channel_params = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.channel_params.items()\n }\n\n self.channel_states = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.channel_states.items()\n }\n return self\n\n def update_states(\n self, states, dt, v, params\n ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"Return the updated states.\"\"\"\n raise NotImplementedError\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Given channel states and voltage, return the current through the channel.\n\n Args:\n states: All states of the compartment.\n v: Voltage of the compartment in mV.\n params: Parameters of the channel (conductances in `S/cm2`).\n\n Returns:\n Current in `uA/cm2`.\n \"\"\"\n raise NotImplementedError\n\n def init_state(\n self,\n states: Dict[str, jnp.ndarray],\n v: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n delta_t: float,\n ):\n \"\"\"Initialize states of channel.\"\"\"\n return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.name","title":"name: Optional[str]
property
","text":"The name of the channel (by default, this is the class name).
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.change_name","title":"change_name(new_name)
","text":"Change the channel name.
Parameters:
Name Type Description Defaultnew_name
str
The new name of the channel.
requiredReturns:
Type DescriptionRenamed channel, such that this function is chainable.
Source code injaxley/channels/channel.py
def change_name(self, new_name: str):\n \"\"\"Change the channel name.\n\n Args:\n new_name: The new name of the channel.\n\n Returns:\n Renamed channel, such that this function is chainable.\n \"\"\"\n old_prefix = self._name + \"_\"\n new_prefix = new_name + \"_\"\n\n self._name = new_name\n self.channel_params = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.channel_params.items()\n }\n\n self.channel_states = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.channel_states.items()\n }\n return self\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.compute_current","title":"compute_current(states, v, params)
","text":"Given channel states and voltage, return the current through the channel.
Parameters:
Name Type Description Defaultstates
Dict[str, ndarray]
All states of the compartment.
requiredv
Voltage of the compartment in mV.
requiredparams
Dict[str, ndarray]
Parameters of the channel (conductances in S/cm2
).
Returns:
Type DescriptionCurrent in uA/cm2
.
jaxley/channels/channel.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Given channel states and voltage, return the current through the channel.\n\n Args:\n states: All states of the compartment.\n v: Voltage of the compartment in mV.\n params: Parameters of the channel (conductances in `S/cm2`).\n\n Returns:\n Current in `uA/cm2`.\n \"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize states of channel.
Source code injaxley/channels/channel.py
def init_state(\n self,\n states: Dict[str, jnp.ndarray],\n v: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n delta_t: float,\n):\n \"\"\"Initialize states of channel.\"\"\"\n return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.update_states","title":"update_states(states, dt, v, params)
","text":"Return the updated states.
Source code injaxley/channels/channel.py
def update_states(\n self, states, dt, v, params\n) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"Return the updated states.\"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#hh","title":"HH","text":" Bases: Channel
Hodgkin-Huxley channel.
Source code injaxley/channels/hh.py
class HH(Channel):\n \"\"\"Hodgkin-Huxley channel.\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gNa\": 0.12,\n f\"{prefix}_gK\": 0.036,\n f\"{prefix}_gLeak\": 0.0003,\n f\"{prefix}_eNa\": 50.0,\n f\"{prefix}_eK\": -77.0,\n f\"{prefix}_eLeak\": -54.3,\n }\n self.channel_states = {\n f\"{prefix}_m\": 0.2,\n f\"{prefix}_h\": 0.2,\n f\"{prefix}_n\": 0.2,\n }\n self.current_name = f\"i_HH\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Return updated HH channel state.\"\"\"\n prefix = self._name\n m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current through HH channels.\"\"\"\n prefix = self._name\n m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n gNa = params[f\"{prefix}_gNa\"] * (m**3) * h # S/cm^2\n gK = params[f\"{prefix}_gK\"] * n**4 # S/cm^2\n gLeak = params[f\"{prefix}_gLeak\"] # S/cm^2\n\n return (\n gNa * (v - params[f\"{prefix}_eNa\"])\n + gK * (v - params[f\"{prefix}_eK\"])\n + gLeak * (v - params[f\"{prefix}_eLeak\"])\n )\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_m, beta_m = self.m_gate(v)\n alpha_h, beta_h = self.h_gate(v)\n alpha_n, beta_n = self.n_gate(v)\n return {\n f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n }\n\n @staticmethod\n def m_gate(v):\n alpha = 0.1 * _vtrap(-(v + 40), 10)\n beta = 4.0 * save_exp(-(v + 65) / 18)\n return alpha, beta\n\n @staticmethod\n def h_gate(v):\n alpha = 0.07 * save_exp(-(v + 65) / 20)\n beta = 1.0 / (save_exp(-(v + 35) / 10) + 1)\n return alpha, beta\n\n @staticmethod\n def n_gate(v):\n alpha = 0.01 * _vtrap(-(v + 55), 10)\n beta = 0.125 * save_exp(-(v + 65) / 80)\n return alpha, beta\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.compute_current","title":"compute_current(states, v, params)
","text":"Return current through HH channels.
Source code injaxley/channels/hh.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current through HH channels.\"\"\"\n prefix = self._name\n m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n gNa = params[f\"{prefix}_gNa\"] * (m**3) * h # S/cm^2\n gK = params[f\"{prefix}_gK\"] * n**4 # S/cm^2\n gLeak = params[f\"{prefix}_gLeak\"] # S/cm^2\n\n return (\n gNa * (v - params[f\"{prefix}_eNa\"])\n + gK * (v - params[f\"{prefix}_eK\"])\n + gLeak * (v - params[f\"{prefix}_eLeak\"])\n )\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/hh.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_m, beta_m = self.m_gate(v)\n alpha_h, beta_h = self.h_gate(v)\n alpha_n, beta_n = self.n_gate(v)\n return {\n f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n }\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.update_states","title":"update_states(states, dt, v, params)
","text":"Return updated HH channel state.
Source code injaxley/channels/hh.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Return updated HH channel state.\"\"\"\n prefix = self._name\n m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n
"},{"location":"reference/mechanisms/#pospischil","title":"Pospischil","text":" Bases: Channel
Leak current
Source code injaxley/channels/pospischil.py
class Leak(Channel):\n \"\"\"Leak current\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gLeak\": 1e-4,\n f\"{prefix}_eLeak\": -70.0,\n }\n self.channel_states = {}\n self.current_name = f\"i_{prefix}\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"No state to update.\"\"\"\n return {}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n gLeak = params[f\"{prefix}_gLeak\"] # S/cm^2\n return gLeak * (v - params[f\"{prefix}_eLeak\"])\n\n def init_state(self, states, v, params, delta_t):\n return {}\n
Bases: Channel
Sodium channel
Source code injaxley/channels/pospischil.py
class Na(Channel):\n \"\"\"Sodium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gNa\": 50e-3,\n \"eNa\": 50.0,\n \"vt\": -60.0, # Global parameter, not prefixed with `Na`.\n }\n self.channel_states = {f\"{prefix}_m\": 0.2, f\"{prefix}_h\": 0.2}\n self.current_name = f\"i_Na\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n gNa = params[f\"{prefix}_gNa\"] * (m**3) * h # S/cm^2\n\n current = gNa * (v - params[\"eNa\"])\n return current\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n return {\n f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n }\n\n @staticmethod\n def m_gate(v, vt):\n v_alpha = v - vt - 13.0\n alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25\n\n v_beta = v - vt - 40.0\n beta = 0.28 * efun(0.2 * v_beta) / 0.2\n return alpha, beta\n\n @staticmethod\n def h_gate(v, vt):\n v_alpha = v - vt - 17.0\n alpha = 0.128 * save_exp(-v_alpha / 18.0)\n\n v_beta = v - vt - 40.0\n beta = 4.0 / (save_exp(-v_beta / 5.0) + 1.0)\n return alpha, beta\n
Bases: Channel
Potassium channel
Source code injaxley/channels/pospischil.py
class K(Channel):\n \"\"\"Potassium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gK\": 5e-3,\n \"eK\": -90.0,\n \"vt\": -60.0, # Global parameter, not prefixed with `Na`.\n }\n self.channel_states = {f\"{prefix}_n\": 0.2}\n self.current_name = f\"i_K\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n n = states[f\"{prefix}_n\"]\n new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n return {f\"{prefix}_n\": new_n}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n n = states[f\"{prefix}_n\"]\n\n gK = params[f\"{prefix}_gK\"] * (n**4) # S/cm^2\n\n return gK * (v - params[\"eK\"])\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n\n @staticmethod\n def n_gate(v, vt):\n v_alpha = v - vt - 15.0\n alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2\n\n v_beta = v - vt - 10.0\n beta = 0.5 * save_exp(-v_beta / 40.0)\n return alpha, beta\n
Bases: Channel
Slow M Potassium channel
Source code injaxley/channels/pospischil.py
class Km(Channel):\n \"\"\"Slow M Potassium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gKm\": 0.004e-3,\n f\"{prefix}_taumax\": 4000.0,\n f\"eK\": -90.0,\n }\n self.channel_states = {f\"{prefix}_p\": 0.2}\n self.current_name = f\"i_K\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n p = states[f\"{prefix}_p\"]\n new_p = solve_inf_gate_exponential(\n p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n )\n return {f\"{prefix}_p\": new_p}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n p = states[f\"{prefix}_p\"]\n\n gKm = params[f\"{prefix}_gKm\"] * p # S/cm^2\n return gKm * (v - params[\"eK\"])\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n\n @staticmethod\n def p_gate(v, taumax):\n v_p = v + 35.0\n p_inf = 1.0 / (1.0 + save_exp(-0.1 * v_p))\n\n tau_p = taumax / (3.3 * save_exp(0.05 * v_p) + save_exp(-0.05 * v_p))\n\n return p_inf, tau_p\n
Bases: Channel
L-type Calcium channel
Source code injaxley/channels/pospischil.py
class CaL(Channel):\n \"\"\"L-type Calcium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gCaL\": 0.1e-3,\n \"eCa\": 120.0,\n }\n self.channel_states = {f\"{prefix}_q\": 0.2, f\"{prefix}_r\": 0.2}\n self.current_name = f\"i_Ca\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r # S/cm^2\n\n return gCaL * (v - params[\"eCa\"])\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_q, beta_q = self.q_gate(v)\n alpha_r, beta_r = self.r_gate(v)\n return {\n f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n }\n\n @staticmethod\n def q_gate(v):\n v_alpha = -v - 27.0\n alpha = 0.055 * efun(v_alpha / 3.8) * 3.8\n\n v_beta = -v - 75.0\n beta = 0.94 * save_exp(v_beta / 17.0)\n return alpha, beta\n\n @staticmethod\n def r_gate(v):\n v_alpha = -v - 13.0\n alpha = 0.000457 * save_exp(v_alpha / 50)\n\n v_beta = -v - 15.0\n beta = 0.0065 / (save_exp(v_beta / 28.0) + 1)\n return alpha, beta\n
Bases: Channel
T-type Calcium channel
Source code injaxley/channels/pospischil.py
class CaT(Channel):\n \"\"\"T-type Calcium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gCaT\": 0.4e-4,\n f\"{prefix}_vx\": 2.0,\n \"eCa\": 120.0, # Global parameter, not prefixed with `CaT`.\n }\n self.channel_states = {f\"{prefix}_u\": 0.2}\n self.current_name = f\"i_Ca\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n u = states[f\"{prefix}_u\"]\n new_u = solve_inf_gate_exponential(\n u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n )\n return {f\"{prefix}_u\": new_u}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n u = states[f\"{prefix}_u\"]\n s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u # S/cm^2\n\n return gCaT * (v - params[\"eCa\"])\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n\n @staticmethod\n def u_gate(v, vx):\n v_u1 = v + vx + 81.0\n u_inf = 1.0 / (1.0 + save_exp(v_u1 / 4))\n\n tau_u = (30.8 + (211.4 + save_exp((v + vx + 113.2) / 5.0))) / (\n 3.7 * (1 + save_exp((v + vx + 84.0) / 3.2))\n )\n\n return u_inf, tau_u\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n gLeak = params[f\"{prefix}_gLeak\"] # S/cm^2\n return gLeak * (v - params[f\"{prefix}_eLeak\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.update_states","title":"update_states(states, dt, v, params)
","text":"No state to update.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"No state to update.\"\"\"\n return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n gNa = params[f\"{prefix}_gNa\"] * (m**3) * h # S/cm^2\n\n current = gNa * (v - params[\"eNa\"])\n return current\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n return {\n f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n }\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n n = states[f\"{prefix}_n\"]\n\n gK = params[f\"{prefix}_gK\"] * (n**4) # S/cm^2\n\n return gK * (v - params[\"eK\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n n = states[f\"{prefix}_n\"]\n new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n return {f\"{prefix}_n\": new_n}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n p = states[f\"{prefix}_p\"]\n\n gKm = params[f\"{prefix}_gKm\"] * p # S/cm^2\n return gKm * (v - params[\"eK\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n p = states[f\"{prefix}_p\"]\n new_p = solve_inf_gate_exponential(\n p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n )\n return {f\"{prefix}_p\": new_p}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r # S/cm^2\n\n return gCaL * (v - params[\"eCa\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_q, beta_q = self.q_gate(v)\n alpha_r, beta_r = self.r_gate(v)\n return {\n f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n }\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n u = states[f\"{prefix}_u\"]\n s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u # S/cm^2\n\n return gCaT * (v - params[\"eCa\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n u = states[f\"{prefix}_u\"]\n new_u = solve_inf_gate_exponential(\n u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n )\n return {f\"{prefix}_u\": new_u}\n
"},{"location":"reference/mechanisms/#synapses","title":"Synapses","text":""},{"location":"reference/mechanisms/#synapse","title":"Synapse","text":"Base class for a synapse.
As in NEURON, a Synapse
is considered a point process, which means that its conductances are to be specified in uS
and its currents are to be specified in nA
.
jaxley/synapses/synapse.py
class Synapse:\n \"\"\"Base class for a synapse.\n\n As in NEURON, a `Synapse` is considered a point process, which means that its\n conductances are to be specified in `uS` and its currents are to be specified in\n `nA`.\n \"\"\"\n\n _name = None\n synapse_params = None\n synapse_states = None\n\n def __init__(self, name: Optional[str] = None):\n self._name = name if name else self.__class__.__name__\n\n @property\n def name(self) -> Optional[str]:\n return self._name\n\n def change_name(self, new_name: str):\n \"\"\"Change the synapse name.\n\n Args:\n new_name: The new name of the channel.\n\n Returns:\n Renamed channel, such that this function is chainable.\n \"\"\"\n old_prefix = self._name + \"_\"\n new_prefix = new_name + \"_\"\n\n self._name = new_name\n self.synapse_params = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.synapse_params.items()\n }\n\n self.synapse_states = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.synapse_states.items()\n }\n return self\n\n def update_states(\n states: Dict[str, jnp.ndarray],\n delta_t: float,\n pre_voltage: jnp.ndarray,\n post_voltage: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n ) -> Dict[str, jnp.ndarray]:\n \"\"\"ODE update step.\n\n Args:\n states: States of the synapse.\n delta_t: Time step in `ms`.\n pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n params: Parameters of the synapse. Conductances in `uS`.\n\n Returns:\n Updated states.\"\"\"\n raise NotImplementedError\n\n def compute_current(\n states: Dict[str, jnp.ndarray],\n pre_voltage: jnp.ndarray,\n post_voltage: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n ) -> jnp.ndarray:\n \"\"\"Return current through one synapse in `nA`.\n\n Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n Args:\n states: States of the synapse.\n pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n params: Parameters of the synapse. Conductances in `uS`.\n\n Returns:\n Current through the synapse in `nA`, shape `()`.\n \"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.change_name","title":"change_name(new_name)
","text":"Change the synapse name.
Parameters:
Name Type Description Defaultnew_name
str
The new name of the channel.
requiredReturns:
Type DescriptionRenamed channel, such that this function is chainable.
Source code injaxley/synapses/synapse.py
def change_name(self, new_name: str):\n \"\"\"Change the synapse name.\n\n Args:\n new_name: The new name of the channel.\n\n Returns:\n Renamed channel, such that this function is chainable.\n \"\"\"\n old_prefix = self._name + \"_\"\n new_prefix = new_name + \"_\"\n\n self._name = new_name\n self.synapse_params = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.synapse_params.items()\n }\n\n self.synapse_states = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.synapse_states.items()\n }\n return self\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)
","text":"Return current through one synapse in nA
.
Internally, we use jax.vmap
to vectorize this function across many synapses.
Parameters:
Name Type Description Defaultstates
Dict[str, ndarray]
States of the synapse.
requiredpre_voltage
ndarray
Voltage of the presynaptic compartment, shape ()
.
post_voltage
ndarray
Voltage of the postsynaptic compartment, shape ()
.
params
Dict[str, ndarray]
Parameters of the synapse. Conductances in uS
.
Returns:
Type Descriptionndarray
Current through the synapse in nA
, shape ()
.
jaxley/synapses/synapse.py
def compute_current(\n states: Dict[str, jnp.ndarray],\n pre_voltage: jnp.ndarray,\n post_voltage: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n) -> jnp.ndarray:\n \"\"\"Return current through one synapse in `nA`.\n\n Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n Args:\n states: States of the synapse.\n pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n params: Parameters of the synapse. Conductances in `uS`.\n\n Returns:\n Current through the synapse in `nA`, shape `()`.\n \"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)
","text":"ODE update step.
Parameters:
Name Type Description Defaultstates
Dict[str, ndarray]
States of the synapse.
requireddelta_t
float
Time step in ms
.
pre_voltage
ndarray
Voltage of the presynaptic compartment, shape ()
.
post_voltage
ndarray
Voltage of the postsynaptic compartment, shape ()
.
params
Dict[str, ndarray]
Parameters of the synapse. Conductances in uS
.
Returns:
Type DescriptionDict[str, ndarray]
Updated states.
Source code injaxley/synapses/synapse.py
def update_states(\n states: Dict[str, jnp.ndarray],\n delta_t: float,\n pre_voltage: jnp.ndarray,\n post_voltage: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n) -> Dict[str, jnp.ndarray]:\n \"\"\"ODE update step.\n\n Args:\n states: States of the synapse.\n delta_t: Time step in `ms`.\n pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n params: Parameters of the synapse. Conductances in `uS`.\n\n Returns:\n Updated states.\"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#ionotropic-synapse","title":"Ionotropic Synapse","text":" Bases: Synapse
Compute synaptic current and update synapse state for a generic ionotropic synapse.
The synapse state \u201cs\u201d is the probability that a postsynaptic receptor channel is open, and this depends on the amount of neurotransmitter released, which is in turn dependent on the presynaptic voltage.
The synaptic parameters areL. F. Abbott and E. Marder, \u201cModeling Small Networks,\u201d in Methods in Neuronal Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.
Source code injaxley/synapses/ionotropic.py
class IonotropicSynapse(Synapse):\n \"\"\"\n Compute synaptic current and update synapse state for a generic ionotropic synapse.\n\n The synapse state \"s\" is the probability that a postsynaptic receptor channel is\n open, and this depends on the amount of neurotransmitter released, which is in turn\n dependent on the presynaptic voltage.\n\n The synaptic parameters are:\n - gS: the maximal conductance across the postsynaptic membrane (uS)\n - e_syn: the reversal potential across the postsynaptic membrane (mV)\n - k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic\n receptor (s^-1)\n\n Details of this implementation can be found in the following book chapter:\n L. F. Abbott and E. Marder, \"Modeling Small Networks,\" in Methods in Neuronal\n Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.\n\n \"\"\"\n\n def __init__(self, name: Optional[str] = None):\n super().__init__(name)\n prefix = self._name\n self.synapse_params = {\n f\"{prefix}_gS\": 1e-4,\n f\"{prefix}_e_syn\": 0.0,\n f\"{prefix}_k_minus\": 0.025,\n }\n self.synapse_states = {f\"{prefix}_s\": 0.2}\n\n def update_states(\n self,\n states: Dict,\n delta_t: float,\n pre_voltage: float,\n post_voltage: float,\n params: Dict,\n ) -> Dict:\n \"\"\"Return updated synapse state and current.\"\"\"\n prefix = self._name\n v_th = -35.0 # mV\n delta = 10.0 # mV\n\n s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n slope = -1.0 / tau_s\n exp_term = save_exp(slope * delta_t)\n new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n return {f\"{prefix}_s\": new_s}\n\n def compute_current(\n self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n ) -> float:\n prefix = self._name\n g_syn = params[f\"{prefix}_gS\"] * states[f\"{prefix}_s\"]\n return g_syn * (post_voltage - params[f\"{prefix}_e_syn\"])\n
"},{"location":"reference/mechanisms/#jaxley.synapses.ionotropic.IonotropicSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)
","text":"Return updated synapse state and current.
Source code injaxley/synapses/ionotropic.py
def update_states(\n self,\n states: Dict,\n delta_t: float,\n pre_voltage: float,\n post_voltage: float,\n params: Dict,\n) -> Dict:\n \"\"\"Return updated synapse state and current.\"\"\"\n prefix = self._name\n v_th = -35.0 # mV\n delta = 10.0 # mV\n\n s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n slope = -1.0 / tau_s\n exp_term = save_exp(slope * delta_t)\n new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n return {f\"{prefix}_s\": new_s}\n
"},{"location":"reference/mechanisms/#tanh-rate-synapse","title":"TanH Rate Synapse","text":" Bases: Synapse
Compute synaptic current for tanh synapse (no state).
Source code injaxley/synapses/tanh_rate.py
class TanhRateSynapse(Synapse):\n \"\"\"\n Compute synaptic current for tanh synapse (no state).\n \"\"\"\n\n def __init__(self, name: Optional[str] = None):\n super().__init__(name)\n prefix = self._name\n self.synapse_params = {\n f\"{prefix}_gS\": 1e-4,\n f\"{prefix}_x_offset\": -70.0,\n f\"{prefix}_slope\": 1.0,\n }\n self.synapse_states = {}\n\n def update_states(\n self,\n states: Dict,\n delta_t: float,\n pre_voltage: float,\n post_voltage: float,\n params: Dict,\n ) -> Dict:\n \"\"\"Return updated synapse state and current.\"\"\"\n return {}\n\n def compute_current(\n self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n ) -> float:\n \"\"\"Return updated synapse state and current.\"\"\"\n prefix = self._name\n current = (\n -1\n * params[f\"{prefix}_gS\"]\n * jnp.tanh(\n (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n )\n )\n return current\n
"},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)
","text":"Return updated synapse state and current.
Source code injaxley/synapses/tanh_rate.py
def compute_current(\n self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n) -> float:\n \"\"\"Return updated synapse state and current.\"\"\"\n prefix = self._name\n current = (\n -1\n * params[f\"{prefix}_gS\"]\n * jnp.tanh(\n (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n )\n )\n return current\n
"},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)
","text":"Return updated synapse state and current.
Source code injaxley/synapses/tanh_rate.py
def update_states(\n self,\n states: Dict,\n delta_t: float,\n pre_voltage: float,\n post_voltage: float,\n params: Dict,\n) -> Dict:\n \"\"\"Return updated synapse state and current.\"\"\"\n return {}\n
"},{"location":"reference/modules/","title":"Modules","text":""},{"location":"reference/modules/#module","title":"Module","text":" Bases: ABC
Module base class.
Modules are everything that can be passed to jx.integrate
, i.e. compartments, branches, cells, and networks.
This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks).
Modules can be traversed and modified using the at
, cell
, branch
, comp
, edge
, and loc
methods. The scope
method can be used to toggle between global and local indices. Traversal of Modules will return a View
of itself, that has a modified set of attributes, which only consider the part of the Module that is in view.
For developers: The above has consequences for how to operate on Module
and which changes take affect where. The following guidelines should be followed (copied from View
):
self.base
. In order to enssure that these changes only affects whatever is currently in view self._nodes_in_view
, or self._edges_in_view
among others have to be used. Operating on nodes currently in view can for example be done with self.base.node.loc[self._nodes_in_view]
.xyzr
, needs to modified when View is instantiated. I.e. xyzr
of cell.branch(0)
, should be [self.base.xyzr[0]]
This could be achieved via: [self.base.xyzr[b] for b in self._branches_in_view]
.For developers: If you want to add a new method to Module
, here is an example of how to make methods of Module compatible with View:
.. code-block:: python
# Use data in view to return something.\ndef count_small_branches(self):\n # no need to use self.base.attr + viewed indices,\n # since no change is made to the attr in question (nodes)\n comp_lens = self.nodes[\"length\"]\n branch_lens = comp_lens.groupby(\"global_branch_index\").sum()\n return np.sum(branch_lens < 10)\n\n# Change data in view.\ndef change_attr_in_view(self):\n # changes to attrs have to be made via self.base.attr + viewed indices\n a = func1(self.base.attr1[self._cells_in_view])\n b = func2(self.base.attr2[self._edges_in_view])\n self.base.attr3[self._branches_in_view] = a + b\n
Source code in jaxley/modules/base.py
class Module(ABC):\n \"\"\"Module base class.\n\n Modules are everything that can be passed to `jx.integrate`, i.e. compartments,\n branches, cells, and networks.\n\n This base class defines the scaffold for all jaxley modules (compartments,\n branches, cells, networks).\n\n Modules can be traversed and modified using the `at`, `cell`, `branch`, `comp`,\n `edge`, and `loc` methods. The `scope` method can be used to toggle between\n global and local indices. Traversal of Modules will return a `View` of itself,\n that has a modified set of attributes, which only consider the part of the Module\n that is in view.\n\n For developers: The above has consequences for how to operate on `Module` and which\n changes take affect where. The following guidelines should be followed (copied from\n `View`):\n\n 1. We consider a Module to have everything in view.\n 2. Views can display and keep track of how a module is traversed. But(!),\n do not support making changes or setting variables. This still has to be\n done in the base Module, i.e. `self.base`. In order to enssure that these\n changes only affects whatever is currently in view `self._nodes_in_view`,\n or `self._edges_in_view` among others have to be used. Operating on nodes\n currently in view can for example be done with\n `self.base.node.loc[self._nodes_in_view]`.\n 3. Every attribute of Module that changes based on what's in view, i.e. `xyzr`,\n needs to modified when View is instantiated. I.e. `xyzr` of `cell.branch(0)`,\n should be `[self.base.xyzr[0]]` This could be achieved via:\n `[self.base.xyzr[b] for b in self._branches_in_view]`.\n\n For developers: If you want to add a new method to `Module`, here is an example of\n how to make methods of Module compatible with View:\n\n .. code-block:: python\n\n # Use data in view to return something.\n def count_small_branches(self):\n # no need to use self.base.attr + viewed indices,\n # since no change is made to the attr in question (nodes)\n comp_lens = self.nodes[\"length\"]\n branch_lens = comp_lens.groupby(\"global_branch_index\").sum()\n return np.sum(branch_lens < 10)\n\n # Change data in view.\n def change_attr_in_view(self):\n # changes to attrs have to be made via self.base.attr + viewed indices\n a = func1(self.base.attr1[self._cells_in_view])\n b = func2(self.base.attr2[self._edges_in_view])\n self.base.attr3[self._branches_in_view] = a + b\n \"\"\"\n\n def __init__(self):\n self.ncomp: int = None\n self.total_nbranches: int = 0\n self.nbranches_per_cell: List[int] = None\n\n self.groups = {}\n\n self.nodes: Optional[pd.DataFrame] = None\n self._scope = \"local\" # defaults to local scope\n self._nodes_in_view: np.ndarray = None\n self._edges_in_view: np.ndarray = None\n\n self.edges = pd.DataFrame(\n columns=[\n \"global_edge_index\",\n \"pre_global_comp_index\",\n \"post_global_comp_index\",\n \"pre_locs\",\n \"post_locs\",\n \"type\",\n \"type_ind\",\n ]\n )\n\n self._cumsum_nbranches: Optional[np.ndarray] = None\n\n self.comb_parents: jnp.ndarray = jnp.asarray([-1])\n\n self.initialized_morph: bool = False\n self.initialized_syns: bool = False\n\n # List of all types of `jx.Synapse`s.\n self.synapses: List = []\n self.synapse_param_names = []\n self.synapse_state_names = []\n self.synapse_names = []\n\n # List of types of all `jx.Channel`s.\n self.channels: List[Channel] = []\n self.membrane_current_names: List[str] = []\n\n # For trainable parameters.\n self.indices_set_by_trainables: List[jnp.ndarray] = []\n self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n self.allow_make_trainable: bool = True\n self.num_trainable_params: int = 0\n\n # For recordings.\n self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})\n\n # For stimuli or clamps.\n # E.g. `self.externals = {\"v\": zeros(1000,2), \"i\": ones(1000, 2)}`\n # for 1000 timesteps and two compartments.\n self.externals: Dict[str, jnp.ndarray] = {}\n # E.g. `self.external)inds = {\"v\": jnp.asarray([0,1]), \"i\": jnp.asarray([2,3])}`\n self.external_inds: Dict[str, jnp.ndarray] = {}\n\n # x, y, z coordinates and radius.\n self.xyzr: List[np.ndarray] = []\n self._radius_generating_fns = None # Defined by `.read_swc()`.\n\n # For debugging the solver. Will be empty by default and only filled if\n # `self._init_morph_for_debugging` is run.\n self.debug_states = {}\n\n # needs to be set at the end\n self.base: Module = self\n\n def __repr__(self):\n return f\"{type(self).__name__} with {len(self.channels)} different channels. Use `.nodes` for details.\"\n\n def __str__(self):\n return f\"jx.{type(self).__name__}\"\n\n def __dir__(self):\n base_dir = object.__dir__(self)\n return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys()))\n\n def __getattr__(self, key):\n # Ensure that hidden methods such as `__deepcopy__` still work.\n if key.startswith(\"__\"):\n return super().__getattribute__(key)\n\n # intercepts calls to groups\n if key in self.base.groups:\n view = (\n self.select(self.groups[key])\n if key in self.groups\n else self.select(None)\n )\n view._set_controlled_by_param(key)\n return view\n\n # intercepts calls to channels\n if key in [c._name for c in self.base.channels]:\n channel_names = [c._name for c in self.channels]\n inds = self.nodes.index[self.nodes[key]].to_numpy()\n view = self.select(inds) if key in channel_names else self.select(None)\n view._set_controlled_by_param(key)\n return view\n\n # intercepts calls to synapse types\n if key in self.base.synapse_names:\n syn_inds = self.edges[self.edges[\"type\"] == key][\n \"global_edge_index\"\n ].to_numpy()\n orig_scope = self._scope\n view = (\n self.scope(\"global\").edge(syn_inds).scope(orig_scope)\n if key in self.synapse_names\n else self.select(None)\n )\n view._set_controlled_by_param(key) # overwrites param set by edge\n # Ensure synapse param sharing works with `edge`\n # `edge` will be removed as part of #463\n view.edges[\"local_edge_index\"] = np.arange(len(view.edges))\n return view\n\n def _childviews(self) -> List[str]:\n \"\"\"Returns levels that module can be viewed at.\n\n I.e. for net -> [cell, branch, comp]. For branch -> [comp]\"\"\"\n levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n if self._current_view in levels:\n children = levels[levels.index(self._current_view) + 1 :]\n return children\n return []\n\n def _has_childview(self, key: str) -> bool:\n child_views = self._childviews()\n return key in child_views\n\n def __getitem__(self, index):\n \"\"\"Lazy indexing of the module.\"\"\"\n supported_parents = [\"network\", \"cell\", \"branch\"] # cannot index into comp\n\n not_group_view = self._current_view not in self.groups\n assert (\n self._current_view in supported_parents or not_group_view\n ), \"Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof.\"\n index = index if isinstance(index, tuple) else (index,)\n\n child_views = self._childviews()\n assert len(index) <= len(child_views), \"Too many indices.\"\n view = self\n for i, child in zip(index, child_views):\n view = view._at_nodes(child, i)\n return view\n\n def _update_local_indices(self) -> pd.DataFrame:\n \"\"\"Compute local indices from the global indices that are in view.\n This is recomputed everytime a View is created.\"\"\"\n rerank = lambda df: df.rank(method=\"dense\").astype(int) - 1\n\n def reorder_cols(\n df: pd.DataFrame, cols: List[str], first: bool = True\n ) -> pd.DataFrame:\n \"\"\"Move cols to front/back.\n\n Args:\n df: DataFrame to reorder.\n cols: List of columns to place before/after remaining columns.\n first: If True, cols are placed in front, otherwise at the end.\n\n Returns:\n DataFrame with reordered columns.\"\"\"\n new_cols = [col for col in df.columns if first == (col in cols)]\n new_cols += [col for col in df.columns if first != (col in cols)]\n return df[new_cols]\n\n def reindex_a_by_b(\n df: pd.DataFrame, a: str, b: Optional[Union[str, List[str]]] = None\n ) -> pd.DataFrame:\n \"\"\"Reindex based on a different col or several columns\n for b=[0,0,1,1,2,2,2] -> a=[0,1,0,1,0,1,2]\"\"\"\n grouped_df = df.groupby(b) if b is not None else df\n df.loc[:, a] = rerank(grouped_df[a])\n return df\n\n index_names = [\"cell_index\", \"branch_index\", \"comp_index\"] # order is important\n global_idx_cols = [f\"global_{name}\" for name in index_names]\n local_idx_cols = [f\"local_{name}\" for name in index_names]\n idcs = self.nodes[global_idx_cols]\n\n # update local indices of nodes\n idcs = reindex_a_by_b(idcs, global_idx_cols[0])\n idcs = reindex_a_by_b(idcs, global_idx_cols[1], global_idx_cols[0])\n idcs = reindex_a_by_b(idcs, global_idx_cols[2], global_idx_cols[:2])\n idcs.columns = [col.replace(\"global\", \"local\") for col in global_idx_cols]\n self.nodes[local_idx_cols] = idcs[local_idx_cols].astype(int)\n\n # move indices to the front of the dataframe; move controlled_by_param to the end\n # move indices of current scope to the front and the others to the back\n not_scope = \"global\" if self._scope == \"local\" else \"local\"\n self.nodes = reorder_cols(\n self.nodes, [f\"{self._scope}_{name}\" for name in index_names], first=True\n )\n self.nodes = reorder_cols(\n self.nodes, [f\"{not_scope}_{name}\" for name in index_names], first=False\n )\n\n self.edges = reorder_cols(self.edges, [\"global_edge_index\"])\n self.nodes = reorder_cols(self.nodes, [\"controlled_by_param\"], first=False)\n self.edges = reorder_cols(self.edges, [\"controlled_by_param\"], first=False)\n\n def _init_view(self):\n \"\"\"Init attributes critical for View.\n\n Needs to be called at init of a Module.\"\"\"\n parent = self.__class__.__name__.lower()\n self._current_view = \"comp\" if parent == \"compartment\" else parent\n self._nodes_in_view = self.nodes.index.to_numpy()\n self._edges_in_view = self.edges.index.to_numpy()\n self.nodes[\"controlled_by_param\"] = 0\n\n def _compute_coords_of_comp_centers(self) -> np.ndarray:\n \"\"\"Compute xyz coordinates of compartment centers.\n\n Centers are the midpoint between the comparment endpoints on the morphology\n as defined by xyzr.\n\n Note: For sake of performance, interpolation is not done for each branch\n individually, but only once along a concatenated (and padded) array of all branches.\n This means for ncomps = [2,4] and normalized cum_branch_lens of [[0,1],[0,1]] we would\n interpolate xyz at the locations comp_ends = [[0,0.5,1], [0,0.25,0.5,0.75,1]],\n where 0 is the start of the branch and 1 is the end point at the full branch_len.\n To avoid do this in one go we set comp_ends = [0,0.5,1,2,2.25,2.5,2.75,3], and\n norm_cum_branch_len = [0,1,2,3] incrememting and also padding them by 1 to\n avoid overlapping branch_lens i.e. norm_cum_branch_len = [0,1,1,2] for only\n incrementing.\n \"\"\"\n nodes_by_branches = self.nodes.groupby(\"global_branch_index\")\n ncomps = nodes_by_branches[\"global_comp_index\"].nunique().to_numpy()\n\n comp_ends = [\n np.linspace(0, 1, ncomp + 1) + 2 * i for i, ncomp in enumerate(ncomps)\n ]\n comp_ends = np.hstack(comp_ends)\n\n comp_ends = comp_ends.reshape(-1)\n cum_branch_lens = []\n for i, xyzr in enumerate(self.xyzr):\n branch_len = np.sqrt(np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1))\n cum_branch_len = np.cumsum(np.concatenate([np.array([0]), branch_len]))\n max_len = cum_branch_len.max()\n # add padding like above\n cum_branch_len = cum_branch_len / (max_len if max_len > 0 else 1) + 2 * i\n cum_branch_len[np.isnan(cum_branch_len)] = 0\n cum_branch_lens.append(cum_branch_len)\n cum_branch_lens = np.hstack(cum_branch_lens)\n xyz = np.vstack(self.xyzr)[:, :3]\n xyz = v_interp(comp_ends, cum_branch_lens, xyz).T\n centers = (xyz[:-1] + xyz[1:]) / 2 # unaware of inter vs intra comp centers\n cum_ncomps = np.cumsum(ncomps)\n # this means centers between comps have to be removed here\n between_comp_inds = (cum_ncomps + np.arange(len(cum_ncomps)))[:-1]\n centers = np.delete(centers, between_comp_inds, axis=0)\n return centers\n\n def compute_compartment_centers(self):\n \"\"\"Add compartment centers to nodes dataframe\"\"\"\n centers = self._compute_coords_of_comp_centers()\n self.base.nodes.loc[self._nodes_in_view, [\"x\", \"y\", \"z\"]] = centers\n\n def _reformat_index(self, idx: Any, dtype: type = int) -> np.ndarray:\n \"\"\"Transforms different types of indices into an array.\n\n Takes slice, list, array, ints, range and None and transforms\n it into array of indices. If index == \"all\" it returns \"all\"\n to be handled downstream.\n\n Args:\n idx: index that specifies at which locations to view the module.\n dtype: defaults to int, but can also reformat float for use in `loc`\n\n Returns:\n array of indices of shape (N,)\"\"\"\n if is_str_all(idx): # also asserts that the only allowed str == \"all\"\n return idx\n\n np_dtype = np.int64 if dtype is int else np.float64\n idx = np.array([], dtype=dtype) if idx is None else idx\n idx = np.array([idx]) if isinstance(idx, (dtype, np_dtype)) else idx\n idx = np.array(idx) if isinstance(idx, (list, range, pd.Index)) else idx\n\n idx = np.arange(len(self.base.nodes))[idx] if isinstance(idx, slice) else idx\n if idx.dtype == bool:\n shape = (*self.shape, len(self.edges))\n which_idx = len(idx) == np.array(shape)\n assert np.any(which_idx), \"Index not matching num of cells/branches/comps.\"\n dim = shape[np.where(which_idx)[0][0]]\n idx = np.arange(dim)[idx]\n assert isinstance(idx, np.ndarray), \"Invalid type\"\n assert idx.dtype in [np_dtype, bool], \"Invalid dtype\"\n return idx.reshape(-1)\n\n def _set_controlled_by_param(self, key: str):\n \"\"\"Determines which parameters are shared in `make_trainable`.\n\n Adds column to nodes/edges dataframes to read of shared params from.\n\n Args:\n key: key specifying group / view that is in control of the params.\"\"\"\n if key in [\"comp\", \"branch\", \"cell\"]:\n self.nodes[\"controlled_by_param\"] = self.nodes[f\"global_{key}_index\"]\n self.edges[\"controlled_by_param\"] = 0\n elif key == \"edge\":\n self.edges[\"controlled_by_param\"] = np.arange(len(self.edges))\n elif key == \"filter\":\n self.nodes[\"controlled_by_param\"] = np.arange(len(self.nodes))\n self.edges[\"controlled_by_param\"] = np.arange(len(self.edges))\n else:\n self.nodes[\"controlled_by_param\"] = 0\n self.edges[\"controlled_by_param\"] = 0\n self._current_view = key\n\n def select(\n self, nodes: np.ndarray = None, edges: np.ndarray = None, sorted: bool = False\n ) -> View:\n \"\"\"Return View of the module filtered by specific node or edges indices.\n\n Args:\n nodes: indices of nodes to view. If None, all nodes are viewed.\n edges: indices of edges to view. If None, all edges are viewed.\n sorted: if True, nodes and edges are sorted.\n\n Returns:\n View for subset of selected nodes and/or edges.\"\"\"\n\n nodes = self._reformat_index(nodes) if nodes is not None else None\n nodes = self._nodes_in_view if is_str_all(nodes) else nodes\n nodes = np.sort(nodes) if sorted else nodes\n\n edges = self._reformat_index(edges) if edges is not None else None\n edges = self._edges_in_view if is_str_all(edges) else edges\n edges = np.sort(edges) if sorted else edges\n\n view = View(self, nodes, edges)\n view._set_controlled_by_param(\"filter\")\n return view\n\n def set_scope(self, scope: str):\n \"\"\"Toggle between \"global\" or \"local\" scope.\n\n Determines if global or local indices are used for viewing the module.\n\n Args:\n scope: either \"global\" or \"local\".\"\"\"\n assert scope in [\"global\", \"local\"], \"Invalid scope.\"\n self._scope = scope\n\n def scope(self, scope: str) -> View:\n \"\"\"Return a View of the module with the specified scope.\n\n For example `cell.scope(\"global\").branch(2).scope(\"local\").comp(1)`\n will return the 1st compartment of branch 2.\n\n Args:\n scope: either \"global\" or \"local\".\n\n Returns:\n View with the specified scope.\"\"\"\n view = self.view\n view.set_scope(scope)\n return view\n\n def _at_nodes(self, key: str, idx: Any) -> View:\n \"\"\"Return a View of the module filtering `nodes` by specified key and index.\n\n Keys can be `cell`, `branch`, `comp` and determine which index is used to filter.\n \"\"\"\n base_name = self.base.__class__.__name__\n assert self.base._has_childview(key), f\"{base_name} does not support {key}.\"\n idx = self._reformat_index(idx)\n idx = self.nodes[self._scope + f\"_{key}_index\"] if is_str_all(idx) else idx\n where = self.nodes[self._scope + f\"_{key}_index\"].isin(idx)\n inds = self.nodes.index[where].to_numpy()\n\n view = View(self, nodes=inds)\n view._set_controlled_by_param(key)\n return view\n\n def _at_edges(self, key: str, idx: Any) -> View:\n \"\"\"Return a View of the module filtering `edges` by specified key and index.\n\n Keys can be `pre`, `post`, `edge` and determine which index is used to filter.\n \"\"\"\n idx = self._reformat_index(idx)\n idx = self.edges[self._scope + f\"_{key}_index\"] if is_str_all(idx) else idx\n where = self.edges[self._scope + f\"_{key}_index\"].isin(idx)\n inds = self.edges.index[where].to_numpy()\n\n view = View(self, edges=inds)\n view._set_controlled_by_param(key)\n return view\n\n def cell(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected cell(s).\n\n Args:\n idx: index of the cell to view.\n\n Returns:\n View of the module at the specified cell index.\"\"\"\n return self._at_nodes(\"cell\", idx)\n\n def branch(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected branches(s).\n\n Args:\n idx: index of the branch to view.\n\n Returns:\n View of the module at the specified branch index.\"\"\"\n return self._at_nodes(\"branch\", idx)\n\n def comp(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected compartments(s).\n\n Args:\n idx: index of the comp to view.\n\n Returns:\n View of the module at the specified compartment index.\"\"\"\n return self._at_nodes(\"comp\", idx)\n\n def edge(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected synapse edges(s).\n\n Args:\n idx: index of the edge to view.\n\n Returns:\n View of the module at the specified edge index.\"\"\"\n return self._at_edges(\"edge\", idx)\n\n def loc(self, at: Any) -> View:\n \"\"\"Return a View of the module at the selected branch location(s).\n\n Args:\n at: location along the branch.\n\n Returns:\n View of the module at the specified branch location.\"\"\"\n global_comp_idxs = []\n for i in self._branches_in_view:\n ncomp = self.base.ncomp_per_branch[i]\n comp_locs = np.linspace(0, 1, ncomp)\n at = comp_locs if is_str_all(at) else self._reformat_index(at, dtype=float)\n comp_edges = np.linspace(0, 1 + 1e-10, ncomp + 1)\n idx = np.digitize(at, comp_edges) - 1 + self.base.cumsum_ncomp[i]\n global_comp_idxs.append(idx)\n global_comp_idxs = np.concatenate(global_comp_idxs)\n orig_scope = self._scope\n # global scope needed to select correct comps, for i.e. branches w. ncomp=[1,2]\n # loc(0.9) will correspond to different local branches (0 vs 1).\n view = self.scope(\"global\").comp(global_comp_idxs).scope(orig_scope)\n view._current_view = \"loc\"\n return view\n\n @property\n def _comps_in_view(self):\n \"\"\"Lists the global compartment indices which are currently part of the view.\"\"\"\n # method also exists in View. this copy forgoes need to instantiate a View\n return self.nodes[\"global_comp_index\"].unique()\n\n @property\n def _branches_in_view(self):\n \"\"\"Lists the global branch indices which are currently part of the view.\"\"\"\n # method also exists in View. this copy forgoes need to instantiate a View\n return self.nodes[\"global_branch_index\"].unique()\n\n @property\n def _cells_in_view(self):\n \"\"\"Lists the global cell indices which are currently part of the view.\"\"\"\n # method also exists in View. this copy forgoes need to instantiate a View\n return self.nodes[\"global_cell_index\"].unique()\n\n def _iter_submodules(self, name: str):\n \"\"\"Iterate over submoduleslevel.\n\n Used for `cells`, `branches`, `comps`.\"\"\"\n col = self._scope + f\"_{name}_index\"\n idxs = self.nodes[col].unique()\n for idx in idxs:\n yield self._at_nodes(name, idx)\n\n @property\n def cells(self):\n \"\"\"Iterate over all cells in the module.\n\n Returns a generator that yields a View of each cell.\"\"\"\n yield from self._iter_submodules(\"cell\")\n\n @property\n def branches(self):\n \"\"\"Iterate over all branches in the module.\n\n Returns a generator that yields a View of each branch.\"\"\"\n yield from self._iter_submodules(\"branch\")\n\n @property\n def comps(self):\n \"\"\"Iterate over all compartments in the module.\n Can be called on any module, i.e. `net.comps`, `cell.comps` or\n `branch.comps`. `__iter__` does not allow for this.\n\n Returns a generator that yields a View of each compartment.\"\"\"\n yield from self._iter_submodules(\"comp\")\n\n def __iter__(self):\n \"\"\"Iterate over parts of the module.\n\n Internally calls `cells`, `branches`, `comps` at the appropriate level.\n\n Example:\n\n .. code-block:: python\n\n for cell in network:\n for branch in cell:\n for comp in branch:\n print(comp.nodes.shape)\n \"\"\"\n next_level = self._childviews()[0]\n yield from self._iter_submodules(next_level)\n\n @property\n def shape(self) -> Tuple[int]:\n \"\"\"Returns the number of submodules contained in a module.\n\n .. code-block:: python\n\n network.shape = (num_cells, num_branches, num_compartments)\n cell.shape = (num_branches, num_compartments)\n branch.shape = (num_compartments,)\n \"\"\"\n cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n raw_shape = self.nodes[cols].nunique().to_list()\n\n # ensure (net.shape -> dim=3, cell.shape -> dim=2, branch.shape -> dim=1, comp.shape -> dim=0)\n levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n module = self.base.__class__.__name__.lower()\n module = \"comp\" if module == \"compartment\" else module\n shape = tuple(raw_shape[levels.index(module) :])\n return shape\n\n def copy(\n self, reset_index: bool = False, as_module: bool = False\n ) -> Union[Module, View]:\n \"\"\"Extract part of a module and return a copy of its View or a new module.\n\n This can be used to call `jx.integrate` on part of a Module.\n\n Args:\n reset_index: if True, the indices of the new module are reset to start from 0.\n as_module: if True, a new module is returned instead of a View.\n\n Returns:\n A part of the module or a copied view of it.\"\"\"\n view = deepcopy(self)\n warnings.warn(\"This method is experimental, use at your own risk.\")\n # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they\n # start from 0/-1 and are contiguous\n if as_module:\n raise NotImplementedError(\"Not yet implemented.\")\n # initialize a new module with the same attributes\n return view\n\n @property\n def view(self):\n \"\"\"Return view of the module.\"\"\"\n return View(self, self._nodes_in_view, self._edges_in_view)\n\n @property\n def _module_type(self):\n \"\"\"Return type of the module (compartment, branch, cell, network) as string.\n\n This is used to perform asserts for some modules (e.g. network cannot use\n `set_ncomp`) without having to import the module in `base.py`.\"\"\"\n return self.__class__.__name__.lower()\n\n def _append_params_and_states(self, param_dict: Dict, state_dict: Dict):\n \"\"\"Insert the default params of the module (e.g. radius, length).\n\n This is run at `__init__()`. It does not deal with channels.\n \"\"\"\n for param_name, param_value in param_dict.items():\n self.base.nodes[param_name] = param_value\n for state_name, state_value in state_dict.items():\n self.base.nodes[state_name] = state_value\n\n def _gather_channels_from_constituents(self, constituents: List):\n \"\"\"Modify `self.channels` and `self.nodes` with channel info from constituents.\n\n This is run at `__init__()`. It takes all branches of constituents (e.g.\n of all branches when the are assembled into a cell) and adds columns to\n `.nodes` for the relevant channels.\n \"\"\"\n for module in constituents:\n for channel in module.channels:\n if channel._name not in [c._name for c in self.channels]:\n self.base.channels.append(channel)\n if channel.current_name not in self.membrane_current_names:\n self.base.membrane_current_names.append(channel.current_name)\n # Setting columns of channel names to `False` instead of `NaN`.\n for channel in self.base.channels:\n name = channel._name\n self.base.nodes.loc[self.nodes[name].isna(), name] = False\n\n @only_allow_module\n def to_jax(self):\n # TODO FROM #447: Make this work for View?\n \"\"\"Move `.nodes` to `.jaxnodes`.\n\n Before the actual simulation is run (via `jx.integrate`), all parameters of\n the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n they can be processed on GPU/TPU and such that the simulation can be\n differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n \"\"\"\n self.base.jaxnodes = {}\n for key, value in self.base.nodes.to_dict(orient=\"list\").items():\n inds = jnp.arange(len(value))\n self.base.jaxnodes[key] = jnp.asarray(value)[inds]\n\n # `jaxedges` contains only parameters (no indices).\n # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n # we allow parameter sharing.\n self.base.jaxedges = {}\n edges = self.base.edges.to_dict(orient=\"list\")\n for i, synapse in enumerate(self.base.synapses):\n condition = np.asarray(edges[\"type_ind\"]) == i\n for key in synapse.synapse_params:\n self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n for key in synapse.synapse_states:\n self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n\n def show(\n self,\n param_names: Optional[Union[str, List[str]]] = None,\n *,\n indices: bool = True,\n params: bool = True,\n states: bool = True,\n channel_names: Optional[List[str]] = None,\n ) -> pd.DataFrame:\n \"\"\"Print detailed information about the Module or a view of it.\n\n Args:\n param_names: The names of the parameters to show. If `None`, all parameters\n are shown.\n indices: Whether to show the indices of the compartments.\n params: Whether to show the parameters of the compartments.\n states: Whether to show the states of the compartments.\n channel_names: The names of the channels to show. If `None`, all channels are\n shown.\n\n Returns:\n A `pd.DataFrame` with the requested information.\n \"\"\"\n nodes = self.nodes.copy() # prevents this from being edited\n\n cols = []\n inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n scopes = [\"local\", \"global\"]\n inds = [f\"{s}_{i}\" for i in inds for s in scopes] if indices else []\n cols += inds\n cols += [ch._name for ch in self.channels] if channel_names else []\n cols += (\n sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n )\n cols += (\n sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n )\n\n if not param_names is None:\n cols = (\n inds + [c for c in cols if c in param_names]\n if params\n else list(param_names)\n )\n\n return nodes[cols]\n\n @only_allow_module\n def _init_morph(self):\n \"\"\"Initialize the morphology such that it can be processed by the solvers.\"\"\"\n self._init_morph_jaxley_spsolve()\n self._init_morph_jax_spsolve()\n self.initialized_morph = True\n\n @abstractmethod\n def _init_morph_jax_spsolve(self):\n \"\"\"Initialize the morphology for the JAX sparse solver.\"\"\"\n raise NotImplementedError\n\n @abstractmethod\n def _init_morph_jaxley_spsolve(self):\n \"\"\"Initialize the morphology for the custom Jaxley solver.\"\"\"\n raise NotImplementedError\n\n def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]):\n \"\"\"Given radius, length, r_a, compute the axial coupling conductances.\"\"\"\n return compute_axial_conductances(self._comp_edges, params)\n\n def set(self, key: str, val: Union[float, jnp.ndarray]):\n \"\"\"Set parameter of module (or its view) to a new value.\n\n Note that this function can not be called within `jax.jit` or `jax.grad`.\n Instead, it should be used set the parameters of the module **before** the\n simulation. Use `.data_set()` to set parameters during `jax.jit` or\n `jax.grad`.\n\n Args:\n key: The name of the parameter to set.\n val: The value to set the parameter to. If it is `jnp.ndarray` then it\n must be of shape `(len(num_compartments))`.\n \"\"\"\n if key in self.nodes.columns:\n not_nan = ~self.nodes[key].isna().to_numpy()\n self.base.nodes.loc[self._nodes_in_view[not_nan], key] = val\n elif key in self.edges.columns:\n not_nan = ~self.edges[key].isna().to_numpy()\n self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n else:\n raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n\n def data_set(\n self,\n key: str,\n val: Union[float, jnp.ndarray],\n param_state: Optional[List[Dict]],\n ):\n \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n Args:\n key: The name of the parameter to set.\n val: The value to set the parameter to. If it is `jnp.ndarray` then it\n must be of shape `(len(num_compartments))`.\n param_state: State of the setted parameters, internally used such that this\n function does not modify global state.\n \"\"\"\n # Note: `data_set` does not support arrays for `val`.\n is_node_param = key in self.nodes.columns\n data = self.nodes if is_node_param else self.edges\n viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view\n if key in data.columns:\n not_nan = ~data[key].isna()\n added_param_state = [\n {\n \"indices\": np.atleast_2d(viewed_inds[not_nan]),\n \"key\": key,\n \"val\": jnp.atleast_1d(jnp.asarray(val)),\n }\n ]\n if param_state is not None:\n param_state += added_param_state\n else:\n param_state = added_param_state\n else:\n raise KeyError(\"Key not recognized.\")\n return param_state\n\n def set_ncomp(\n self,\n ncomp: int,\n min_radius: Optional[float] = None,\n ):\n \"\"\"Set the number of compartments with which the branch is discretized.\n\n Args:\n ncomp: The number of compartments that the branch should be discretized\n into.\n min_radius: Only used if the morphology was read from an SWC file. If passed\n the radius is capped to be at least this value.\n\n Raises:\n - When there are stimuli in any compartment in the module.\n - When there are recordings in any compartment in the module.\n - When the channels of the compartments are not the same within the branch\n that is modified.\n - When the lengths of the compartments are not the same within the branch\n that is modified.\n - Unless the morphology was read from an SWC file, when the radiuses of the\n compartments are not the same within the branch that is modified.\n \"\"\"\n assert len(self.base.externals) == 0, \"No stimuli allowed!\"\n assert len(self.base.recordings) == 0, \"No recordings allowed!\"\n assert len(self.base.trainable_params) == 0, \"No trainables allowed!\"\n\n assert self.base._module_type != \"network\", \"This is not allowed for networks.\"\n assert not (\n self.base._module_type == \"cell\"\n and len(self._branches_in_view) == len(self.base._branches_in_view)\n ), \"This is not allowed for cells.\"\n\n # Update all attributes that are affected by compartment structure.\n view = self.nodes.copy()\n all_nodes = self.base.nodes\n start_idx = self.nodes[\"global_comp_index\"].to_numpy()[0]\n ncomp_per_branch = self.base.ncomp_per_branch\n channel_names = [c._name for c in self.base.channels]\n channel_param_names = list(\n chain(*[c.channel_params for c in self.base.channels])\n )\n channel_state_names = list(\n chain(*[c.channel_states for c in self.base.channels])\n )\n radius_generating_fns = self.base._radius_generating_fns\n\n within_branch_radiuses = view[\"radius\"].to_numpy()\n compartment_lengths = view[\"length\"].to_numpy()\n num_previous_ncomp = len(within_branch_radiuses)\n branch_indices = pd.unique(view[\"global_branch_index\"])\n\n error_msg = lambda name: (\n f\"You previously modified the {name} of individual compartments, but \"\n f\"now you are modifying the number of compartments in this branch. \"\n f\"This is not allowed. First build the morphology with `set_ncomp()` and \"\n f\"then modify the radiuses and lengths of compartments.\"\n )\n\n if (\n ~np.all(within_branch_radiuses == within_branch_radiuses[0])\n and radius_generating_fns is None\n ):\n raise ValueError(error_msg(\"radius\"))\n\n for property_name in [\"length\", \"capacitance\", \"axial_resistivity\"]:\n compartment_properties = view[property_name].to_numpy()\n if ~np.all(compartment_properties == compartment_properties[0]):\n raise ValueError(error_msg(property_name))\n\n if not (self.nodes[channel_names].var() == 0.0).all():\n raise ValueError(\n \"Some channel exists only in some compartments of the branch which you\"\n \"are trying to modify. This is not allowed. First specify the number\"\n \"of compartments with `.set_ncomp()` and then insert the channels\"\n \"accordingly.\"\n )\n\n if not (\n self.nodes[channel_param_names + channel_state_names].var() == 0.0\n ).all():\n raise ValueError(\n \"Some channel has different parameters or states between the \"\n \"different compartments of the branch which you are trying to modify. \"\n \"This is not allowed. First specify the number of compartments with \"\n \"`.set_ncomp()` and then insert the channels accordingly.\"\n )\n\n # Add new rows as the average of all rows. Special case for the length is below.\n average_row = self.nodes.mean(skipna=False)\n average_row = average_row.to_frame().T\n view = pd.concat([*[average_row] * ncomp], axis=\"rows\")\n\n # Set the correct datatype after having performed an average which cast\n # everything to float.\n integer_cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n view[integer_cols] = view[integer_cols].astype(int)\n\n # Whether or not a channel exists in a compartment is a boolean.\n boolean_cols = channel_names\n view[boolean_cols] = view[boolean_cols].astype(bool)\n\n # Special treatment for the lengths and radiuses. These are not being set as\n # the average because we:\n # 1) Want to maintain the total length of a branch.\n # 2) Want to use the SWC inferred radius.\n #\n # Compute new compartment lengths.\n comp_lengths = np.sum(compartment_lengths) / ncomp\n view[\"length\"] = comp_lengths\n\n # Compute new compartment radiuses.\n if radius_generating_fns is not None:\n view[\"radius\"] = build_radiuses_from_xyzr(\n radius_fns=radius_generating_fns,\n branch_indices=branch_indices,\n min_radius=min_radius,\n ncomp=ncomp,\n )\n else:\n view[\"radius\"] = within_branch_radiuses[0] * np.ones(ncomp)\n\n # Update `.nodes`.\n # 1) Delete N rows starting from start_idx\n number_deleted = num_previous_ncomp\n all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))\n\n # 2) Insert M new rows at the same location\n df1 = all_nodes.iloc[:start_idx] # Rows before the insertion point\n df2 = all_nodes.iloc[start_idx:] # Rows after the insertion point\n\n # 3) Combine the parts: before, new rows, and after\n all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)\n\n # Override `comp_index` to just be a consecutive list.\n all_nodes[\"global_comp_index\"] = np.arange(len(all_nodes))\n\n # Update compartment structure arguments.\n ncomp_per_branch[branch_indices] = ncomp\n ncomp = int(np.max(ncomp_per_branch))\n cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n internal_node_inds = np.arange(cumsum_ncomp[-1])\n\n self.base.nodes = all_nodes\n self.base.ncomp_per_branch = ncomp_per_branch\n self.base.ncomp = ncomp\n self.base.cumsum_ncomp = cumsum_ncomp\n self.base._internal_node_inds = internal_node_inds\n\n # Update the morphology indexing (e.g., `.comp_edges`).\n self.base._initialize()\n self.base._init_view()\n self.base._update_local_indices()\n\n def make_trainable(\n self,\n key: str,\n init_val: Optional[Union[float, list]] = None,\n verbose: bool = True,\n ):\n \"\"\"Make a parameter trainable.\n\n If a parameter is made trainable, it will be returned by `get_parameters()`\n and should then be passed to `jx.integrate(..., params=params)`.\n\n Args:\n key: Name of the parameter to make trainable.\n init_val: Initial value of the parameter. If `float`, the same value is\n used for every created parameter. If `list`, the length of the list has\n to match the number of created parameters. If `None`, the current\n parameter value is used and if parameter sharing is performed that the\n current parameter value is averaged over all shared parameters.\n verbose: Whether to print the number of parameters that are added and the\n total number of parameters.\n \"\"\"\n assert (\n self.allow_make_trainable\n ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n ncomps_per_branch = (\n self.base.nodes[\"global_branch_index\"].value_counts().to_numpy()\n )\n assert np.all(\n ncomps_per_branch == ncomps_per_branch[0]\n ), \"Parameter sharing is not allowed for modules containing branches with different numbers of compartments.\"\n\n data = self.nodes if key in self.nodes.columns else None\n data = self.edges if key in self.edges.columns else data\n\n assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n not_nan = ~data[key].isna()\n data = data.loc[not_nan]\n assert (\n len(data) > 0\n ), \"No settable parameters found in the selected compartments.\"\n\n grouped_view = data.groupby(\"controlled_by_param\")\n # Because of this `x.index.values` we cannot support `make_trainable()` on\n # the module level for synapse parameters (but only for `SynapseView`).\n inds_of_comps = list(\n grouped_view.apply(lambda x: x.index.values, include_groups=False)\n )\n indices_per_param = jnp.stack(inds_of_comps)\n # Sorted inds are only used to infer the correct starting values.\n param_vals = jnp.asarray(\n [data.loc[inds, key].to_numpy() for inds in inds_of_comps]\n )\n\n # Set the value which the trainable parameter should take.\n num_created_parameters = len(indices_per_param)\n if init_val is not None:\n if isinstance(init_val, float):\n new_params = jnp.asarray([init_val] * num_created_parameters)\n elif isinstance(init_val, list):\n assert (\n len(init_val) == num_created_parameters\n ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n new_params = jnp.asarray(init_val)\n else:\n raise ValueError(\n f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n )\n else:\n new_params = jnp.mean(param_vals, axis=1)\n self.base.trainable_params.append({key: new_params})\n self.base.indices_set_by_trainables.append(indices_per_param)\n self.base.num_trainable_params += num_created_parameters\n if verbose:\n print(\n f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}\"\n )\n\n def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):\n \"\"\"Write the trainables into `.nodes` and `.edges`.\n\n This allows to, e.g., visualize trained networks with `.vis()`.\n\n Args:\n trainable_params: The trainable parameters returned by `get_parameters()`.\n \"\"\"\n # We do not support views. Why? `jaxedges` does not have any NaN\n # elements, whereas edges does. Because of this, we already need special\n # treatment to make this function work, and it would be an even bigger hassle\n # if we wanted to support this.\n assert self.__class__.__name__ in [\n \"Compartment\",\n \"Branch\",\n \"Cell\",\n \"Network\",\n ], \"Only supports modules.\"\n\n # We could also implement this without casting the module to jax.\n # However, I think it allows us to reuse as much code as possible and it avoids\n # any kind of issues with indexing or parameter sharing (as this is fully\n # taken care of by `get_all_parameters()`).\n self.base.to_jax()\n pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)\n all_params = self.base.get_all_parameters(pstate, voltage_solver=\"jaxley.stone\")\n\n # The value for `delta_t` does not matter here because it is only used to\n # compute the initial current. However, the initial current cannot be made\n # trainable and so its value never gets used below.\n all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)\n\n # Loop only over the keys in `pstate` to avoid unnecessary computation.\n for parameter in pstate:\n key = parameter[\"key\"]\n if key in self.base.nodes.columns:\n vals_to_set = all_params if key in all_params.keys() else all_states\n self.base.nodes[key] = vals_to_set[key]\n\n # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n # we allow parameter sharing.\n edges = self.base.edges.to_dict(orient=\"list\")\n for i, synapse in enumerate(self.base.synapses):\n condition = np.asarray(edges[\"type_ind\"]) == i\n for key in list(synapse.synapse_params.keys()):\n self.base.edges.loc[condition, key] = all_params[key]\n for key in list(synapse.synapse_states.keys()):\n self.base.edges.loc[condition, key] = all_states[key]\n\n def distance(self, endpoint: \"View\") -> float:\n \"\"\"Return the direct distance between two compartments.\n This does not compute the pathwise distance (which is currently not\n implemented).\n Args:\n endpoint: The compartment to which to compute the distance to.\n \"\"\"\n assert len(self.xyzr) == 1 and len(endpoint.xyzr) == 1\n start_xyz = np.mean(self.xyzr[0][:, :3], axis=0)\n end_xyz = np.mean(endpoint.xyzr[0][:, :3], axis=0)\n return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))\n\n def delete_trainables(self):\n \"\"\"Removes all trainable parameters from the module.\"\"\"\n\n if isinstance(self, View):\n trainables_and_inds = self._filter_trainables(is_viewed=False)\n self.base.indices_set_by_trainables = trainables_and_inds[0]\n self.base.trainable_params = trainables_and_inds[1]\n self.base.num_trainable_params -= self.num_trainable_params\n else:\n self.base.indices_set_by_trainables = []\n self.base.trainable_params = []\n self.base.num_trainable_params = 0\n self._update_view()\n\n def add_to_group(self, group_name: str):\n \"\"\"Add a view of the module to a group.\n\n Groups can then be indexed. For example:\n\n .. code-block:: python\n\n net.cell(0).add_to_group(\"excitatory\")\n net.excitatory.set(\"radius\", 0.1)\n\n Args:\n group_name: The name of the group.\n \"\"\"\n if group_name not in self.base.groups:\n self.base.groups[group_name] = self._nodes_in_view\n else:\n self.base.groups[group_name] = np.unique(\n np.concatenate([self.base.groups[group_name], self._nodes_in_view])\n )\n\n def _get_state_names(self) -> Tuple[List, List]:\n \"\"\"Collect all recordable / clampable states in the membrane and synapses.\n\n Returns states seperated by comps and edges.\"\"\"\n channel_states = [name for c in self.channels for name in c.channel_states]\n synapse_states = [name for s in self.synapses for name in s.synapse_states]\n membrane_states = [\"v\", \"i\"] + self.membrane_current_names\n return channel_states + membrane_states, synapse_states\n\n def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n \"\"\"Get all trainable parameters.\n\n The returned parameters should be passed to `jx.integrate(..., params=params).\n\n Returns:\n A list of all trainable parameters in the form of\n [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n \"\"\"\n return self.trainable_params\n\n @only_allow_module\n def get_all_parameters(\n self, pstate: List[Dict], voltage_solver: str\n ) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n Runs `_compute_axial_conductances()` and return every parameter that is needed\n to solve the ODE. This includes conductances, radiuses, lengths,\n axial_resistivities, but also coupling conductances.\n\n This is done by first obtaining the current value of every parameter (not only\n the trainable ones) and then replacing the trainable ones with the value\n in `trainable_params()`. This function is run within `jx.integrate()`.\n\n pstate can be obtained by calling `params_to_pstate()`.\n\n .. code-block:: python\n\n params = module.get_parameters() # i.e. [0, 1, 2]\n pstate = params_to_pstate(params, module.indices_set_by_trainables)\n module.to_jax() # needed for call to module.jaxnodes\n\n Args:\n pstate: The state of the trainable parameters. pstate takes the form\n [{\n \"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]),\n \"val\": jnp.array([0.1, 0.2, 0.3])\n }, ...].\n voltage_solver: The voltage solver that is used. Since `jax.sparse` and\n `jaxley.xyz` require different formats of the axial conductances, this\n function will default to different building methods.\n\n Returns:\n A dictionary of all module parameters.\n \"\"\"\n params = {}\n for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n params[key] = self.base.jaxnodes[key]\n\n for channel in self.base.channels:\n for channel_params in channel.channel_params:\n params[channel_params] = self.base.jaxnodes[channel_params]\n\n for synapse_params in self.base.synapse_param_names:\n params[synapse_params] = self.base.jaxedges[synapse_params]\n\n # Override with those parameters set by `.make_trainable()`.\n for parameter in pstate:\n key = parameter[\"key\"]\n inds = parameter[\"indices\"]\n set_param = parameter[\"val\"]\n\n # This is needed since SynapseViews worked differently before.\n # This mimics the old behaviour and tranformes the new indices\n # to the old indices.\n # TODO FROM #447: Longterm this should be gotten rid of.\n # Instead edges should work similar to nodes (would also allow for\n # param sharing).\n synapse_inds = self.base.edges.groupby(\"type\").rank()[\"global_edge_index\"]\n synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()\n if key in self.base.synapse_param_names:\n inds = synapse_inds[inds]\n\n if key in params: # Only parameters, not initial states.\n # `inds` is of shape `(num_params, num_comps_per_param)`.\n # `set_param` is of shape `(num_params,)`\n # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n # `.set()` to work. This is done with `[:, None]`.\n params[key] = params[key].at[inds].set(set_param[:, None])\n\n # Compute conductance params and add them to the params dictionary.\n params[\"axial_conductances\"] = self.base._compute_axial_conductances(\n params=params\n )\n return params\n\n @only_allow_module\n def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Return states as they are set in the `.nodes` and `.edges` tables.\"\"\"\n self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n states = {\"v\": self.base.jaxnodes[\"v\"]}\n # Join node and edge states into a single state dictionary.\n for channel in self.base.channels:\n for channel_states in channel.channel_states:\n states[channel_states] = self.base.jaxnodes[channel_states]\n for synapse_states in self.base.synapse_state_names:\n states[synapse_states] = self.base.jaxedges[synapse_states]\n return states\n\n @only_allow_module\n def get_all_states(\n self, pstate: List[Dict], all_params, delta_t: float\n ) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n Args:\n pstate: The state of the trainable parameters.\n all_params: All parameters of the module.\n delta_t: The time step.\n\n Returns:\n A dictionary of all states of the module.\n \"\"\"\n states = self.base._get_states_from_nodes_and_edges()\n\n # Override with the initial states set by `.make_trainable()`.\n for parameter in pstate:\n key = parameter[\"key\"]\n inds = parameter[\"indices\"]\n set_param = parameter[\"val\"]\n if key in states: # Only initial states, not parameters.\n # `inds` is of shape `(num_params, num_comps_per_param)`.\n # `set_param` is of shape `(num_params,)`\n # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n # `.set()` to work. This is done with `[:, None]`.\n states[key] = states[key].at[inds].set(set_param[:, None])\n\n # Add to the states the initial current through every channel.\n states, _ = self.base._channel_currents(\n states, delta_t, self.channels, self.nodes, all_params\n )\n\n # Add to the states the initial current through every synapse.\n states, _ = self.base._synapse_currents(\n states, self.synapses, all_params, delta_t, self.edges\n )\n return states\n\n @property\n def initialized(self) -> bool:\n \"\"\"Whether the `Module` is ready to be solved or not.\"\"\"\n return self.initialized_morph\n\n def _initialize(self):\n \"\"\"Initialize the module.\"\"\"\n self._init_morph()\n return self\n\n @only_allow_module\n def init_states(self, delta_t: float = 0.025):\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Initialize all mechanisms in their steady state.\n\n This considers the voltages and parameters of each compartment.\n\n Args:\n delta_t: Passed on to `channel.init_state()`.\n \"\"\"\n # Update states of the channels.\n channel_nodes = self.base.nodes\n states = self.base._get_states_from_nodes_and_edges()\n\n # We do not use any `pstate` for initializing. In principle, we could change\n # that by allowing an input `params` and `pstate` to this function.\n # `voltage_solver` could also be `jax.sparse` here, because both of them\n # build the channel parameters in the same way.\n params = self.base.get_all_parameters([], voltage_solver=\"jaxley.thomas\")\n\n for channel in self.base.channels:\n name = channel._name\n channel_indices = channel_nodes.loc[channel_nodes[name]][\n \"global_comp_index\"\n ].to_numpy()\n voltages = channel_nodes.loc[channel_indices, \"v\"].to_numpy()\n\n channel_param_names = list(channel.channel_params.keys())\n channel_state_names = list(channel.channel_states.keys())\n channel_states = query_channel_states_and_params(\n states, channel_state_names, channel_indices\n )\n channel_params = query_channel_states_and_params(\n params, channel_param_names, channel_indices\n )\n\n init_state = channel.init_state(\n channel_states, voltages, channel_params, delta_t\n )\n\n # `init_state` might not return all channel states. Only the ones that are\n # returned are updated here.\n for key, val in init_state.items():\n # Note that we are overriding `self.nodes` here, but `self.nodes` is\n # not used above to actually compute the current states (so there are\n # no issues with overriding states).\n self.nodes.loc[channel_indices, key] = val\n\n def _init_morph_for_debugging(self):\n \"\"\"Instandiates row and column inds which can be used to solve the voltage eqs.\n\n This is important only for expert users who try to modify the solver for the\n voltage equations. By default, this function is never run.\n\n This is useful for debugging the solver because one can use\n `scipy.linalg.sparse.spsolve` after every step of the solve.\n\n Here is the code snippet that can be used for debugging then (to be inserted in\n `solver_voltage`):\n ```python\n from scipy.sparse import csc_matrix\n from scipy.sparse.linalg import spsolve\n from jaxley.utils.debug_solver import build_voltage_matrix_elements\n\n elements, solve, num_entries, start_ind_for_branchpoints = (\n build_voltage_matrix_elements(\n uppers,\n lowers,\n diags,\n solves,\n branchpoint_conds_children[debug_states[\"child_inds\"]],\n branchpoint_conds_parents[debug_states[\"par_inds\"]],\n branchpoint_weights_children[debug_states[\"child_inds\"]],\n branchpoint_weights_parents[debug_states[\"par_inds\"]],\n branchpoint_diags,\n branchpoint_solves,\n debug_states[\"ncomp\"],\n nbranches,\n )\n )\n sparse_matrix = csc_matrix(\n (elements, (debug_states[\"row_inds\"], debug_states[\"col_inds\"])),\n shape=(num_entries, num_entries),\n )\n solution = spsolve(sparse_matrix, solve)\n solution = solution[:start_ind_for_branchpoints] # Delete branchpoint voltages.\n solves = jnp.reshape(solution, (debug_states[\"ncomp\"], nbranches))\n return solves\n ```\n \"\"\"\n # For scipy and jax.scipy.\n row_and_col_inds = compute_morphology_indices(\n len(self.base._par_inds),\n self.base._child_belongs_to_branchpoint,\n self.base._par_inds,\n self.base._child_inds,\n self.base.ncomp,\n self.base.total_nbranches,\n )\n\n num_elements = len(row_and_col_inds[\"row_inds\"])\n data_inds, indices, indptr = convert_to_csc(\n num_elements=num_elements,\n row_ind=row_and_col_inds[\"row_inds\"],\n col_ind=row_and_col_inds[\"col_inds\"],\n )\n self.base.debug_states[\"row_inds\"] = row_and_col_inds[\"row_inds\"]\n self.base.debug_states[\"col_inds\"] = row_and_col_inds[\"col_inds\"]\n self.base.debug_states[\"data_inds\"] = data_inds\n self.base.debug_states[\"indices\"] = indices\n self.base.debug_states[\"indptr\"] = indptr\n\n self.base.debug_states[\"ncomp\"] = self.base.ncomp\n self.base.debug_states[\"child_inds\"] = self.base._child_inds\n self.base.debug_states[\"par_inds\"] = self.base._par_inds\n\n def record(self, state: str = \"v\", verbose=True):\n comp_states, edge_states = self._get_state_names()\n if state not in comp_states + edge_states:\n raise KeyError(f\"{state} is not a recognized state in this module.\")\n in_view = self._nodes_in_view if state in comp_states else self._edges_in_view\n\n new_recs = pd.DataFrame(in_view, columns=[\"rec_index\"])\n new_recs[\"state\"] = state\n self.base.recordings = pd.concat([self.base.recordings, new_recs])\n has_duplicates = self.base.recordings.duplicated()\n self.base.recordings = self.base.recordings.loc[~has_duplicates]\n if verbose:\n print(\n f\"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details.\"\n )\n\n def _update_view(self):\n \"\"\"Update the attrs of the view after changes in the base module.\"\"\"\n if isinstance(self, View):\n scope = self._scope\n current_view = self._current_view\n # copy dict of new View. For some reason doing self = View(self)\n # did not work.\n self.__dict__ = View(\n self.base, self._nodes_in_view, self._edges_in_view\n ).__dict__\n\n # retain the scope and current_view of the previous view\n self._scope = scope\n self._current_view = current_view\n\n def delete_recordings(self):\n \"\"\"Removes all recordings from the module.\"\"\"\n if isinstance(self, View):\n base_recs = self.base.recordings\n self.base.recordings = base_recs[\n ~base_recs.isin(self.recordings).all(axis=1)\n ]\n self._update_view()\n else:\n self.base.recordings = pd.DataFrame().from_dict({})\n\n def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n \"\"\"Insert a stimulus into the compartment.\n\n current must be a 1d array or have batch dimension of size `(num_compartments, )`\n or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n it should only be used for static stimuli (i.e., stimuli that do not depend\n on the data and that should not be learned). For stimuli that depend on data\n (or that should be learned), please use `data_stimulate()`.\n\n Args:\n current: Current in `nA`.\n \"\"\"\n self._external_input(\"i\", current, verbose=verbose)\n\n def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n \"\"\"Clamp a state to a given value across specified compartments.\n\n Args:\n state_name: The name of the state to clamp.\n state_array (jnp.nd: Array of values to clamp the state to.\n verbose : If True, prints details about the clamping.\n\n This function sets external states for the compartments.\n \"\"\"\n self._external_input(state_name, state_array, verbose=verbose)\n\n def _external_input(\n self,\n key: str,\n values: Optional[jnp.ndarray],\n verbose: bool = True,\n ):\n comp_states, edge_states = self._get_state_names()\n if key not in comp_states + edge_states:\n raise KeyError(f\"{key} is not a recognized state in this module.\")\n values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)\n batch_size = values.shape[0]\n num_inserted = (\n len(self._nodes_in_view) if key in comp_states else len(self._edges_in_view)\n )\n is_multiple = num_inserted == batch_size\n values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0)\n assert batch_size in [\n 1,\n num_inserted,\n ], \"Number of comps and stimuli do not match.\"\n\n if key in self.base.externals.keys():\n self.base.externals[key] = jnp.concatenate(\n [self.base.externals[key], values]\n )\n self.base.external_inds[key] = jnp.concatenate(\n [self.base.external_inds[key], self._nodes_in_view]\n )\n else:\n if key in comp_states:\n self.base.externals[key] = values\n self.base.external_inds[key] = self._nodes_in_view\n else:\n self.base.externals[key] = values\n self.base.external_inds[key] = self._edges_in_view\n if verbose:\n print(\n f\"Added {num_inserted} external_states. See `.externals` for details.\"\n )\n\n def data_stimulate(\n self,\n current: jnp.ndarray,\n data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n verbose: bool = False,\n ) -> Tuple[jnp.ndarray, pd.DataFrame]:\n \"\"\"Insert a stimulus into the module within jit (or grad).\n\n Args:\n current: Current in `nA`.\n verbose: Whether or not to print the number of inserted stimuli. `False`\n by default because this method is meant to be jitted.\n \"\"\"\n return self._data_external_input(\n \"i\", current, data_stimuli, self.nodes, verbose=verbose\n )\n\n def data_clamp(\n self,\n state_name: str,\n state_array: jnp.ndarray,\n data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n verbose: bool = False,\n ):\n \"\"\"Insert a clamp into the module within jit (or grad).\n\n Args:\n state_name: Name of the state variable to set.\n state_array: Time series of the state variable in the default Jaxley unit.\n State array should be of shape (num_clamps, simulation_time) or\n (simulation_time, ) for a single clamp.\n verbose: Whether or not to print the number of inserted clamps. `False`\n by default because this method is meant to be jitted.\n \"\"\"\n comp_states, edge_states = self._get_state_names()\n if state_name not in comp_states + edge_states:\n raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n data = self.nodes if state_name in comp_states else self.edges\n return self._data_external_input(\n state_name, state_array, data_clamps, data, verbose=verbose\n )\n\n def _data_external_input(\n self,\n state_name: str,\n state_array: jnp.ndarray,\n data_external_input: Optional[Tuple[jnp.ndarray, pd.DataFrame]],\n view: pd.DataFrame,\n verbose: bool = False,\n ):\n comp_states, edge_states = self._get_state_names()\n state_array = (\n state_array\n if state_array.ndim == 2\n else jnp.expand_dims(state_array, axis=0)\n )\n batch_size = state_array.shape[0]\n num_inserted = (\n len(self._nodes_in_view)\n if state_name in comp_states\n else len(self._edges_in_view)\n )\n is_multiple = num_inserted == batch_size\n state_array = (\n state_array\n if is_multiple\n else jnp.repeat(state_array, num_inserted, axis=0)\n )\n assert batch_size in [\n 1,\n num_inserted,\n ], \"Number of comps and clamps do not match.\"\n\n if data_external_input is not None:\n external_input = data_external_input[1]\n external_input = jnp.concatenate([external_input, state_array])\n inds = data_external_input[2]\n else:\n external_input = state_array\n inds = pd.DataFrame().from_dict({})\n\n inds = pd.concat([inds, view])\n\n if verbose:\n if state_name == \"i\":\n print(f\"Added {len(view)} stimuli.\")\n else:\n print(f\"Added {len(view)} clamps.\")\n\n return (state_name, external_input, inds)\n\n def delete_stimuli(self):\n \"\"\"Removes all stimuli from the module.\"\"\"\n self.delete_clamps(\"i\")\n\n def delete_clamps(self, state_name: Optional[str] = None):\n \"\"\"Removes all clamps of the given state from the module.\"\"\"\n all_externals = list(self.externals.keys())\n if \"i\" in all_externals:\n all_externals.remove(\"i\")\n state_names = all_externals if state_name is None else [state_name]\n for state_name in state_names:\n if state_name in self.externals:\n keep_inds = ~np.isin(\n self.base.external_inds[state_name], self._nodes_in_view\n )\n base_exts = self.base.externals\n base_exts_inds = self.base.external_inds\n if np.all(~keep_inds):\n base_exts.pop(state_name, None)\n base_exts_inds.pop(state_name, None)\n else:\n base_exts[state_name] = base_exts[state_name][keep_inds]\n base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]\n self._update_view()\n else:\n pass # does not have to be deleted if not in externals\n\n def insert(self, channel: Channel):\n \"\"\"Insert a channel into the module.\n\n Args:\n channel: The channel to insert.\"\"\"\n name = channel._name\n\n # Channel does not yet exist in the `jx.Module` at all.\n if name not in [c._name for c in self.base.channels]:\n self.base.channels.append(channel)\n self.base.nodes[name] = (\n False # Previous columns do not have the new channel.\n )\n\n if channel.current_name not in self.base.membrane_current_names:\n self.base.membrane_current_names.append(channel.current_name)\n\n # Add a binary column that indicates if a channel is present.\n self.base.nodes.loc[self._nodes_in_view, name] = True\n\n # Loop over all new parameters, e.g. gNa, eNa.\n for key in channel.channel_params:\n self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]\n\n # Loop over all new parameters, e.g. gNa, eNa.\n for key in channel.channel_states:\n self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]\n\n def delete_channel(self, channel: Channel):\n \"\"\"Remove a channel from the module.\n\n Args:\n channel: The channel to remove.\"\"\"\n name = channel._name\n channel_names = [c._name for c in self.channels]\n all_channel_names = [c._name for c in self.base.channels]\n if name in channel_names:\n channel_cols = list(channel.channel_params.keys())\n channel_cols += list(channel.channel_states.keys())\n self.base.nodes.loc[self._nodes_in_view, channel_cols] = float(\"nan\")\n self.base.nodes.loc[self._nodes_in_view, name] = False\n\n # only delete cols if no other comps in the module have the same channel\n if np.all(~self.base.nodes[name]):\n self.base.channels.pop(all_channel_names.index(name))\n self.base.membrane_current_names.remove(channel.current_name)\n self.base.nodes.drop(columns=channel_cols + [name], inplace=True)\n else:\n raise ValueError(f\"Channel {name} not found in the module.\")\n\n @only_allow_module\n def step(\n self,\n u: Dict[str, jnp.ndarray],\n delta_t: float,\n external_inds: Dict[str, jnp.ndarray],\n externals: Dict[str, jnp.ndarray],\n params: Dict[str, jnp.ndarray],\n solver: str = \"bwd_euler\",\n voltage_solver: str = \"jaxley.stone\",\n ) -> Dict[str, jnp.ndarray]:\n \"\"\"One step of solving the Ordinary Differential Equation.\n\n This function is called inside of `integrate` and increments the state of the\n module by one time step. Calls `_step_channels` and `_step_synapse` to update\n the states of the channels and synapses using fwd_euler.\n\n Args:\n u: The state of the module. voltages = u[\"v\"]\n delta_t: The time step.\n external_inds: The indices of the external inputs.\n externals: The external inputs.\n params: The parameters of the module.\n solver: The solver to use for the voltages. Either of [\"bwd_euler\",\n \"fwd_euler\", \"crank_nicolson\"].\n voltage_solver: The tridiagonal solver used to diagonalize the\n coefficient matrix of the ODE system. Either of [\"jaxley.thomas\",\n \"jaxley.stone\"].\n\n Returns:\n The updated state of the module.\n \"\"\"\n\n # Extract the voltages\n voltages = u[\"v\"]\n\n # Extract the external inputs\n if \"i\" in externals.keys():\n i_current = externals[\"i\"]\n i_inds = external_inds[\"i\"]\n i_ext = self._get_external_input(\n voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n )\n else:\n i_ext = 0.0\n\n # Step of the channels.\n u, (v_terms, const_terms) = self._step_channels(\n u, delta_t, self.channels, self.nodes, params\n )\n\n # Step of the synapse.\n u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n u,\n self.synapses,\n params,\n delta_t,\n self.edges,\n )\n\n # Clamp for channels and synapses.\n for key in externals.keys():\n if key not in [\"i\", \"v\"]:\n u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n # Voltage steps.\n cm = params[\"capacitance\"] # Abbreviation.\n\n # Arguments used by all solvers.\n solver_kwargs = {\n \"voltages\": voltages,\n \"voltage_terms\": (v_terms + syn_v_terms) / cm,\n \"constant_terms\": (const_terms + i_ext + syn_const_terms) / cm,\n \"axial_conductances\": params[\"axial_conductances\"],\n \"internal_node_inds\": self._internal_node_inds,\n }\n\n # Add solver specific arguments.\n if voltage_solver == \"jax.sparse\":\n solver_kwargs.update(\n {\n \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n \"data_inds\": self._data_inds,\n \"indices\": self._indices_jax_spsolve,\n \"indptr\": self._indptr_jax_spsolve,\n \"n_nodes\": self._n_nodes,\n }\n )\n # Only for `bwd_euler` and `cranck-nicolson`.\n step_voltage_implicit = step_voltage_implicit_with_jax_spsolve\n else:\n # Our custom sparse solver requires a different format of all conductance\n # values to perform triangulation and backsubstution optimally.\n #\n # Currently, the forward Euler solver also uses this format. However,\n # this is only for historical reasons and we are planning to change this in\n # the future.\n solver_kwargs.update(\n {\n \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n \"sources\": np.asarray(self._comp_edges[\"source\"].to_list()),\n \"types\": np.asarray(self._comp_edges[\"type\"].to_list()),\n \"ncomp_per_branch\": self.ncomp_per_branch,\n \"par_inds\": self._par_inds,\n \"child_inds\": self._child_inds,\n \"nbranches\": self.total_nbranches,\n \"solver\": voltage_solver,\n \"idx\": self._solve_indexer,\n \"debug_states\": self.debug_states,\n }\n )\n # Only for `bwd_euler` and `cranck-nicolson`.\n step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve\n\n if solver == \"bwd_euler\":\n u[\"v\"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)\n elif solver == \"crank_nicolson\":\n # Crank-Nicolson advances by half a step of backward and half a step of\n # forward Euler.\n half_step_delta_t = delta_t / 2\n half_step_voltages = step_voltage_implicit(\n **solver_kwargs, delta_t=half_step_delta_t\n )\n # The forward Euler step in Crank-Nicolson can be performed easily as\n # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.\n u[\"v\"] = 2 * half_step_voltages - voltages\n elif solver == \"fwd_euler\":\n u[\"v\"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)\n else:\n raise ValueError(\n f\"You specified `solver={solver}`. The only allowed solvers are \"\n \"['bwd_euler', 'fwd_euler', 'crank_nicolson'].\"\n )\n\n # Clamp for voltages.\n if \"v\" in externals.keys():\n u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n return u\n\n def _step_channels(\n self,\n states: Dict[str, jnp.ndarray],\n delta_t: float,\n channels: List[Channel],\n channel_nodes: pd.DataFrame,\n params: Dict[str, jnp.ndarray],\n ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"One step of integration of the channels and of computing their current.\"\"\"\n states = self._step_channels_state(\n states, delta_t, channels, channel_nodes, params\n )\n states, current_terms = self._channel_currents(\n states, delta_t, channels, channel_nodes, params\n )\n return states, current_terms\n\n def _step_channels_state(\n self,\n states,\n delta_t,\n channels: List[Channel],\n channel_nodes: pd.DataFrame,\n params: Dict[str, jnp.ndarray],\n ) -> Dict[str, jnp.ndarray]:\n \"\"\"One integration step of the channels.\"\"\"\n voltages = states[\"v\"]\n\n # Update states of the channels.\n indices = channel_nodes[\"global_comp_index\"].to_numpy()\n for channel in channels:\n channel_param_names = list(channel.channel_params)\n channel_param_names += [\n \"radius\",\n \"length\",\n \"axial_resistivity\",\n \"capacitance\",\n ]\n channel_state_names = list(channel.channel_states)\n channel_state_names += self.membrane_current_names\n channel_indices = indices[channel_nodes[channel._name].astype(bool)]\n\n channel_params = query_channel_states_and_params(\n params, channel_param_names, channel_indices\n )\n channel_states = query_channel_states_and_params(\n states, channel_state_names, channel_indices\n )\n\n states_updated = channel.update_states(\n channel_states, delta_t, voltages[channel_indices], channel_params\n )\n # Rebuild state. This has to be done within the loop over channels to allow\n # multiple channels which modify the same state.\n for key, val in states_updated.items():\n states[key] = states[key].at[channel_indices].set(val)\n\n return states\n\n def _channel_currents(\n self,\n states: Dict[str, jnp.ndarray],\n delta_t: float,\n channels: List[Channel],\n channel_nodes: pd.DataFrame,\n params: Dict[str, jnp.ndarray],\n ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"Return the current through each channel.\n\n This is also updates `state` because the `state` also contains the current.\n \"\"\"\n voltages = states[\"v\"]\n\n # Compute current through channels.\n voltage_terms = jnp.zeros_like(voltages)\n constant_terms = jnp.zeros_like(voltages)\n # Run with two different voltages that are `diff` apart to infer the slope and\n # offset.\n diff = 1e-3\n\n current_states = {}\n for name in self.membrane_current_names:\n current_states[name] = jnp.zeros_like(voltages)\n\n for channel in channels:\n name = channel._name\n channel_param_names = list(channel.channel_params.keys())\n channel_state_names = list(channel.channel_states.keys())\n indices = channel_nodes.loc[channel_nodes[name]][\n \"global_comp_index\"\n ].to_numpy()\n\n channel_params = {}\n for p in channel_param_names:\n channel_params[p] = params[p][indices]\n channel_params[\"radius\"] = params[\"radius\"][indices]\n channel_params[\"length\"] = params[\"length\"][indices]\n channel_params[\"axial_resistivity\"] = params[\"axial_resistivity\"][indices]\n\n channel_states = {}\n for s in channel_state_names:\n channel_states[s] = states[s][indices]\n\n v_and_perturbed = jnp.stack([voltages[indices], voltages[indices] + diff])\n membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))(\n channel_states, v_and_perturbed, channel_params\n )\n voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff\n constant_term = membrane_currents[0] - voltage_term * voltages[indices]\n\n # * 1000 to convert from mA/cm^2 to uA/cm^2.\n voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0)\n constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0)\n\n # Save the current (for the unperturbed voltage) as a state that will\n # also be passed to the state update.\n current_states[channel.current_name] = (\n current_states[channel.current_name]\n .at[indices]\n .add(membrane_currents[0])\n )\n\n # Copy the currents into the `state` dictionary such that they can be\n # recorded and used by `Channel.update_states()`.\n for name in self.membrane_current_names:\n states[name] = current_states[name]\n\n return states, (voltage_terms, constant_terms)\n\n def _step_synapse(\n self,\n u: Dict[str, jnp.ndarray],\n syn_channels: List[Channel],\n params: Dict[str, jnp.ndarray],\n delta_t: float,\n edges: pd.DataFrame,\n ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"One step of integration of the channels.\n\n `Network` overrides this method (because it actually has synapses), whereas\n `Compartment`, `Branch`, and `Cell` do not override this.\n \"\"\"\n voltages = u[\"v\"]\n return u, (jnp.zeros_like(voltages), jnp.zeros_like(voltages))\n\n def _synapse_currents(\n self, states, syn_channels, params, delta_t, edges: pd.DataFrame\n ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n return states, (None, None)\n\n @staticmethod\n def _get_external_input(\n voltages: jnp.ndarray,\n i_inds: jnp.ndarray,\n i_stim: jnp.ndarray,\n radius: float,\n length_single_compartment: float,\n ) -> jnp.ndarray:\n \"\"\"\n Return external input to each compartment in uA / cm^2.\n\n Args:\n voltages: mV.\n i_stim: nA.\n radius: um.\n length_single_compartment: um.\n \"\"\"\n zero_vec = jnp.zeros_like(voltages)\n current = convert_point_process_to_distributed(\n i_stim, radius[i_inds], length_single_compartment[i_inds]\n )\n\n dnums = ScatterDimensionNumbers(\n update_window_dims=(),\n inserted_window_dims=(0,),\n scatter_dims_to_operand_dims=(0,),\n )\n stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)\n return stim_at_timestep\n\n def vis(\n self,\n ax: Optional[Axes] = None,\n col: str = \"k\",\n dims: Tuple[int] = (0, 1),\n type: str = \"line\",\n morph_plot_kwargs: Dict = {},\n ) -> Axes:\n \"\"\"Visualize the module.\n\n Modules can be visualized on one of the cardinal planes (xy, xz, yz) or\n even in 3D.\n\n Several options are available:\n - `line`: All points from the traced morphology (`xyzr`), are connected\n with a line plot.\n - `scatter`: All traced points, are plotted as scatter points.\n - `comp`: Plots the compartmentalized morphology, including radius\n and shape. (shows the true compartment lengths per default, but this can\n be changed via the `morph_plot_kwargs`, for details see\n `jaxley.utils.plot_utils.plot_comps`).\n - `morph`: Reconstructs the 3D shape of the traced morphology. For details see\n `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies\n with many traced points this can be very slow.\n\n Args:\n ax: An axis into which to plot.\n col: The color for all branches.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two of them.\n type: The type of plot. One of [\"line\", \"scatter\", \"comp\", \"morph\"].\n morph_plot_kwargs: Keyword arguments passed to the plotting function.\n \"\"\"\n if \"comp\" in type.lower():\n return plot_comps(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)\n if \"morph\" in type.lower():\n return plot_morph(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)\n\n assert not np.any(\n [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]\n ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n ax = plot_graph(\n self.xyzr,\n dims=dims,\n col=col,\n ax=ax,\n type=type,\n morph_plot_kwargs=morph_plot_kwargs,\n )\n\n return ax\n\n def compute_xyz(self):\n \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n This function should not be called if the morphology was read from an `.swc`\n file. However, for morphologies that were constructed from scratch, this\n function **must** be called before `.vis()`. The computed `xyz` coordinates\n are only used for plotting.\n \"\"\"\n max_y_multiplier = 5.0\n min_y_multiplier = 0.5\n\n parents = self.comb_parents\n num_children = _compute_num_children(parents)\n index_of_child = _compute_index_of_child(parents)\n levels = compute_levels(parents)\n\n # Extract branch.\n inds_branch = self.nodes.groupby(\"global_branch_index\")[\n \"global_comp_index\"\n ].apply(list)\n branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n endpoints = []\n\n # Different levels will get a different \"angle\" at which the children emerge from\n # the parents. This angle is defined by the `y_offset_multiplier`. This value\n # defines the range between y-location of the first and of the last child of a\n # parent.\n y_offset_multiplier = np.linspace(\n max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n )\n\n for b in range(self.total_nbranches):\n # For networks with mixed SWC and from-scatch neurons, only update those\n # branches that do not have coordingates yet.\n if np.any(np.isnan(self.xyzr[b])):\n if parents[b] > -1:\n start_point = endpoints[parents[b]]\n num_children_of_parent = num_children[parents[b]]\n if num_children_of_parent > 1:\n y_offset = (\n ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n ) * y_offset_multiplier[levels[b]]\n else:\n y_offset = 0.0\n else:\n start_point = [0, 0, 0]\n y_offset = 0.0\n\n len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n end_point = [\n start_point[0] + branch_lens[b] / len_of_path * 1.0,\n start_point[1] + branch_lens[b] / len_of_path * y_offset,\n start_point[2],\n ]\n endpoints.append(end_point)\n\n self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n else:\n # Dummy to keey the index `endpoints[parent[b]]` above working.\n endpoints.append(np.zeros((2,)))\n\n def move(\n self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = False\n ):\n \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n This function is used only for visualization. It does not affect the simulation.\n\n Args:\n x: The amount to move in the x direction in um.\n y: The amount to move in the y direction in um.\n z: The amount to move in the z direction in um.\n update_nodes: Whether `.nodes` should be updated or not. Setting this to\n `False` largely speeds up moving, especially for big networks, but\n `.nodes` or `.show` will not show the new xyz coordinates.\n \"\"\"\n for i in self._branches_in_view:\n self.base.xyzr[i][:, :3] += np.array([x, y, z])\n if update_nodes:\n self.compute_compartment_centers()\n\n def move_to(\n self,\n x: Union[float, np.ndarray] = 0.0,\n y: Union[float, np.ndarray] = 0.0,\n z: Union[float, np.ndarray] = 0.0,\n update_nodes: bool = False,\n ):\n \"\"\"Move cells or networks to a location (x, y, z).\n\n If x, y, and z are floats, then the first compartment of the first branch\n of the first cell is moved to that float coordinate, and everything else is\n shifted by the difference between that compartment's previous coordinate and\n the new float location.\n\n If x, y, and z are arrays, then they must each have a length equal to the number\n of cells being moved. Then the first compartment of the first branch of each\n cell is moved to the specified location.\n\n Args:\n update_nodes: Whether `.nodes` should be updated or not. Setting this to\n `False` largely speeds up moving, especially for big networks, but\n `.nodes` or `.show` will not show the new xyz coordinates.\n \"\"\"\n # Test if any coordinate values are NaN which would greatly affect moving\n if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n raise ValueError(\n \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n )\n\n # can only iterate over cells for networks\n # lambda makes sure that generator can be created multiple times\n base_is_net = self.base._current_view == \"network\"\n cells = lambda: (self.cells if base_is_net else [self])\n\n root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])\n root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells\n move_by = np.array([x, y, z]).T - root_xyz\n\n if len(move_by.shape) == 1:\n move_by = np.tile(move_by, (len(self._cells_in_view), 1))\n\n for cell, offset in zip(cells(), move_by):\n for idx in cell._branches_in_view:\n self.base.xyzr[idx][:, :3] += offset\n if update_nodes:\n self.compute_compartment_centers()\n\n def rotate(\n self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = False\n ):\n \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n This function is used only for visualization. It does not affect the simulation.\n\n Args:\n degrees: How many degrees to rotate the module by.\n rotation_axis: Either of {`xy` | `xz` | `yz`}.\n \"\"\"\n degrees = degrees / 180 * np.pi\n if rotation_axis == \"xy\":\n dims = [0, 1]\n elif rotation_axis == \"xz\":\n dims = [0, 2]\n elif rotation_axis == \"yz\":\n dims = [1, 2]\n else:\n raise ValueError\n\n rotation_matrix = np.asarray(\n [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n )\n for i in self._branches_in_view:\n rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T\n self.base.xyzr[i][:, dims] = rot\n if update_nodes:\n self.compute_compartment_centers()\n\n def copy_node_property_to_edges(\n self,\n properties_to_import: Union[str, List[str]],\n pre_or_post: Union[str, List[str]] = [\"pre\", \"post\"],\n ) -> Module:\n \"\"\"Copy a property that is in `node` over to `edges`.\n\n By default, `.edges` does not contain the properties (radius, length, cm,\n channel properties,...) of the pre- and post-synaptic compartments. This\n method allows to copy a property of the pre- and/or post-synaptic compartment\n to the edges. It is then accessible as `module.edges.pre_property_name` or\n `module.edges.post_property_name`.\n\n Note that, if you modify the node property _after_ having run\n `copy_node_property_to_edges`, it will not automatically update the value in\n `.edges`.\n\n Note that, if this method is called on a View (e.g.\n `net.cell(0).copy_node_property_to_edges`), then it will return a View, but\n it will _not_ modify the module itself.\n\n Args:\n properties_to_import: The name of the node properties that should be\n imported. To list all available properties, look at\n `module.nodes.columns`.\n pre_or_post: Whether to import only the pre-synaptic property ('pre'), only\n the post-synaptic property ('post'), or both (['pre', 'post']).\n\n Returns:\n A new module which has the property copied to the `nodes`.\n \"\"\"\n # If a string is passed, wrap it as a list.\n if isinstance(pre_or_post, str):\n pre_or_post = [pre_or_post]\n if isinstance(properties_to_import, str):\n properties_to_import = [properties_to_import]\n\n for pre_or_post_val in pre_or_post:\n assert pre_or_post_val in [\"pre\", \"post\"]\n for property_to_import in properties_to_import:\n # Delete the column if it already exists. Otherwise it would exist\n # twice.\n if f\"{pre_or_post_val}_{property_to_import}\" in self.edges.columns:\n self.edges.drop(\n columns=f\"{pre_or_post_val}_{property_to_import}\", inplace=True\n )\n\n self.edges = self.edges.join(\n self.nodes[[property_to_import, \"global_comp_index\"]].set_index(\n \"global_comp_index\"\n ),\n on=f\"{pre_or_post_val}_global_comp_index\",\n )\n self.edges = self.edges.rename(\n columns={\n property_to_import: f\"{pre_or_post_val}_{property_to_import}\"\n }\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.branches","title":"branches
property
","text":"Iterate over all branches in the module.
Returns a generator that yields a View of each branch.
"},{"location":"reference/modules/#jaxley.modules.base.Module.cells","title":"cells
property
","text":"Iterate over all cells in the module.
Returns a generator that yields a View of each cell.
"},{"location":"reference/modules/#jaxley.modules.base.Module.comps","title":"comps
property
","text":"Iterate over all compartments in the module. Can be called on any module, i.e. net.comps
, cell.comps
or branch.comps
. __iter__
does not allow for this.
Returns a generator that yields a View of each compartment.
"},{"location":"reference/modules/#jaxley.modules.base.Module.initialized","title":"initialized: bool
property
","text":"Whether the Module
is ready to be solved or not.
shape: Tuple[int]
property
","text":"Returns the number of submodules contained in a module.
.. code-block:: python
network.shape = (num_cells, num_branches, num_compartments)\ncell.shape = (num_branches, num_compartments)\nbranch.shape = (num_compartments,)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.view","title":"view
property
","text":"Return view of the module.
"},{"location":"reference/modules/#jaxley.modules.base.Module.__getitem__","title":"__getitem__(index)
","text":"Lazy indexing of the module.
Source code injaxley/modules/base.py
def __getitem__(self, index):\n \"\"\"Lazy indexing of the module.\"\"\"\n supported_parents = [\"network\", \"cell\", \"branch\"] # cannot index into comp\n\n not_group_view = self._current_view not in self.groups\n assert (\n self._current_view in supported_parents or not_group_view\n ), \"Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof.\"\n index = index if isinstance(index, tuple) else (index,)\n\n child_views = self._childviews()\n assert len(index) <= len(child_views), \"Too many indices.\"\n view = self\n for i, child in zip(index, child_views):\n view = view._at_nodes(child, i)\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.__iter__","title":"__iter__()
","text":"Iterate over parts of the module.
Internally calls cells
, branches
, comps
at the appropriate level.
Example:
.. code-block:: python
for cell in network:\n for branch in cell:\n for comp in branch:\n print(comp.nodes.shape)\n
Source code in jaxley/modules/base.py
def __iter__(self):\n \"\"\"Iterate over parts of the module.\n\n Internally calls `cells`, `branches`, `comps` at the appropriate level.\n\n Example:\n\n .. code-block:: python\n\n for cell in network:\n for branch in cell:\n for comp in branch:\n print(comp.nodes.shape)\n \"\"\"\n next_level = self._childviews()[0]\n yield from self._iter_submodules(next_level)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.add_to_group","title":"add_to_group(group_name)
","text":"Add a view of the module to a group.
Groups can then be indexed. For example:
.. code-block:: python
net.cell(0).add_to_group(\"excitatory\")\nnet.excitatory.set(\"radius\", 0.1)\n
Parameters:
Name Type Description Defaultgroup_name
str
The name of the group.
required Source code injaxley/modules/base.py
def add_to_group(self, group_name: str):\n \"\"\"Add a view of the module to a group.\n\n Groups can then be indexed. For example:\n\n .. code-block:: python\n\n net.cell(0).add_to_group(\"excitatory\")\n net.excitatory.set(\"radius\", 0.1)\n\n Args:\n group_name: The name of the group.\n \"\"\"\n if group_name not in self.base.groups:\n self.base.groups[group_name] = self._nodes_in_view\n else:\n self.base.groups[group_name] = np.unique(\n np.concatenate([self.base.groups[group_name], self._nodes_in_view])\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.branch","title":"branch(idx)
","text":"Return a View of the module at the selected branches(s).
Parameters:
Name Type Description Defaultidx
Any
index of the branch to view.
requiredReturns:
Type DescriptionView
View of the module at the specified branch index.
Source code injaxley/modules/base.py
def branch(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected branches(s).\n\n Args:\n idx: index of the branch to view.\n\n Returns:\n View of the module at the specified branch index.\"\"\"\n return self._at_nodes(\"branch\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.cell","title":"cell(idx)
","text":"Return a View of the module at the selected cell(s).
Parameters:
Name Type Description Defaultidx
Any
index of the cell to view.
requiredReturns:
Type DescriptionView
View of the module at the specified cell index.
Source code injaxley/modules/base.py
def cell(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected cell(s).\n\n Args:\n idx: index of the cell to view.\n\n Returns:\n View of the module at the specified cell index.\"\"\"\n return self._at_nodes(\"cell\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.clamp","title":"clamp(state_name, state_array, verbose=True)
","text":"Clamp a state to a given value across specified compartments.
Parameters:
Name Type Description Defaultstate_name
str
The name of the state to clamp.
requiredstate_array
nd
Array of values to clamp the state to.
requiredverbose
If True, prints details about the clamping.
True
This function sets external states for the compartments.
Source code injaxley/modules/base.py
def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n \"\"\"Clamp a state to a given value across specified compartments.\n\n Args:\n state_name: The name of the state to clamp.\n state_array (jnp.nd: Array of values to clamp the state to.\n verbose : If True, prints details about the clamping.\n\n This function sets external states for the compartments.\n \"\"\"\n self._external_input(state_name, state_array, verbose=verbose)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.comp","title":"comp(idx)
","text":"Return a View of the module at the selected compartments(s).
Parameters:
Name Type Description Defaultidx
Any
index of the comp to view.
requiredReturns:
Type DescriptionView
View of the module at the specified compartment index.
Source code injaxley/modules/base.py
def comp(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected compartments(s).\n\n Args:\n idx: index of the comp to view.\n\n Returns:\n View of the module at the specified compartment index.\"\"\"\n return self._at_nodes(\"comp\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.compute_compartment_centers","title":"compute_compartment_centers()
","text":"Add compartment centers to nodes dataframe
Source code injaxley/modules/base.py
def compute_compartment_centers(self):\n \"\"\"Add compartment centers to nodes dataframe\"\"\"\n centers = self._compute_coords_of_comp_centers()\n self.base.nodes.loc[self._nodes_in_view, [\"x\", \"y\", \"z\"]] = centers\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.compute_xyz","title":"compute_xyz()
","text":"Return xyz coordinates of every branch, based on the branch length.
This function should not be called if the morphology was read from an .swc
file. However, for morphologies that were constructed from scratch, this function must be called before .vis()
. The computed xyz
coordinates are only used for plotting.
jaxley/modules/base.py
def compute_xyz(self):\n \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n This function should not be called if the morphology was read from an `.swc`\n file. However, for morphologies that were constructed from scratch, this\n function **must** be called before `.vis()`. The computed `xyz` coordinates\n are only used for plotting.\n \"\"\"\n max_y_multiplier = 5.0\n min_y_multiplier = 0.5\n\n parents = self.comb_parents\n num_children = _compute_num_children(parents)\n index_of_child = _compute_index_of_child(parents)\n levels = compute_levels(parents)\n\n # Extract branch.\n inds_branch = self.nodes.groupby(\"global_branch_index\")[\n \"global_comp_index\"\n ].apply(list)\n branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n endpoints = []\n\n # Different levels will get a different \"angle\" at which the children emerge from\n # the parents. This angle is defined by the `y_offset_multiplier`. This value\n # defines the range between y-location of the first and of the last child of a\n # parent.\n y_offset_multiplier = np.linspace(\n max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n )\n\n for b in range(self.total_nbranches):\n # For networks with mixed SWC and from-scatch neurons, only update those\n # branches that do not have coordingates yet.\n if np.any(np.isnan(self.xyzr[b])):\n if parents[b] > -1:\n start_point = endpoints[parents[b]]\n num_children_of_parent = num_children[parents[b]]\n if num_children_of_parent > 1:\n y_offset = (\n ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n ) * y_offset_multiplier[levels[b]]\n else:\n y_offset = 0.0\n else:\n start_point = [0, 0, 0]\n y_offset = 0.0\n\n len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n end_point = [\n start_point[0] + branch_lens[b] / len_of_path * 1.0,\n start_point[1] + branch_lens[b] / len_of_path * y_offset,\n start_point[2],\n ]\n endpoints.append(end_point)\n\n self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n else:\n # Dummy to keey the index `endpoints[parent[b]]` above working.\n endpoints.append(np.zeros((2,)))\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.copy","title":"copy(reset_index=False, as_module=False)
","text":"Extract part of a module and return a copy of its View or a new module.
This can be used to call jx.integrate
on part of a Module.
Parameters:
Name Type Description Defaultreset_index
bool
if True, the indices of the new module are reset to start from 0.
False
as_module
bool
if True, a new module is returned instead of a View.
False
Returns:
Type DescriptionUnion[Module, View]
A part of the module or a copied view of it.
Source code injaxley/modules/base.py
def copy(\n self, reset_index: bool = False, as_module: bool = False\n) -> Union[Module, View]:\n \"\"\"Extract part of a module and return a copy of its View or a new module.\n\n This can be used to call `jx.integrate` on part of a Module.\n\n Args:\n reset_index: if True, the indices of the new module are reset to start from 0.\n as_module: if True, a new module is returned instead of a View.\n\n Returns:\n A part of the module or a copied view of it.\"\"\"\n view = deepcopy(self)\n warnings.warn(\"This method is experimental, use at your own risk.\")\n # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they\n # start from 0/-1 and are contiguous\n if as_module:\n raise NotImplementedError(\"Not yet implemented.\")\n # initialize a new module with the same attributes\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.copy_node_property_to_edges","title":"copy_node_property_to_edges(properties_to_import, pre_or_post=['pre', 'post'])
","text":"Copy a property that is in node
over to edges
.
By default, .edges
does not contain the properties (radius, length, cm, channel properties,\u2026) of the pre- and post-synaptic compartments. This method allows to copy a property of the pre- and/or post-synaptic compartment to the edges. It is then accessible as module.edges.pre_property_name
or module.edges.post_property_name
.
Note that, if you modify the node property after having run copy_node_property_to_edges
, it will not automatically update the value in .edges
.
Note that, if this method is called on a View (e.g. net.cell(0).copy_node_property_to_edges
), then it will return a View, but it will not modify the module itself.
Parameters:
Name Type Description Defaultproperties_to_import
Union[str, List[str]]
The name of the node properties that should be imported. To list all available properties, look at module.nodes.columns
.
pre_or_post
Union[str, List[str]]
Whether to import only the pre-synaptic property (\u2018pre\u2019), only the post-synaptic property (\u2018post\u2019), or both ([\u2018pre\u2019, \u2018post\u2019]).
['pre', 'post']
Returns:
Type DescriptionModule
A new module which has the property copied to the nodes
.
jaxley/modules/base.py
def copy_node_property_to_edges(\n self,\n properties_to_import: Union[str, List[str]],\n pre_or_post: Union[str, List[str]] = [\"pre\", \"post\"],\n) -> Module:\n \"\"\"Copy a property that is in `node` over to `edges`.\n\n By default, `.edges` does not contain the properties (radius, length, cm,\n channel properties,...) of the pre- and post-synaptic compartments. This\n method allows to copy a property of the pre- and/or post-synaptic compartment\n to the edges. It is then accessible as `module.edges.pre_property_name` or\n `module.edges.post_property_name`.\n\n Note that, if you modify the node property _after_ having run\n `copy_node_property_to_edges`, it will not automatically update the value in\n `.edges`.\n\n Note that, if this method is called on a View (e.g.\n `net.cell(0).copy_node_property_to_edges`), then it will return a View, but\n it will _not_ modify the module itself.\n\n Args:\n properties_to_import: The name of the node properties that should be\n imported. To list all available properties, look at\n `module.nodes.columns`.\n pre_or_post: Whether to import only the pre-synaptic property ('pre'), only\n the post-synaptic property ('post'), or both (['pre', 'post']).\n\n Returns:\n A new module which has the property copied to the `nodes`.\n \"\"\"\n # If a string is passed, wrap it as a list.\n if isinstance(pre_or_post, str):\n pre_or_post = [pre_or_post]\n if isinstance(properties_to_import, str):\n properties_to_import = [properties_to_import]\n\n for pre_or_post_val in pre_or_post:\n assert pre_or_post_val in [\"pre\", \"post\"]\n for property_to_import in properties_to_import:\n # Delete the column if it already exists. Otherwise it would exist\n # twice.\n if f\"{pre_or_post_val}_{property_to_import}\" in self.edges.columns:\n self.edges.drop(\n columns=f\"{pre_or_post_val}_{property_to_import}\", inplace=True\n )\n\n self.edges = self.edges.join(\n self.nodes[[property_to_import, \"global_comp_index\"]].set_index(\n \"global_comp_index\"\n ),\n on=f\"{pre_or_post_val}_global_comp_index\",\n )\n self.edges = self.edges.rename(\n columns={\n property_to_import: f\"{pre_or_post_val}_{property_to_import}\"\n }\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_clamp","title":"data_clamp(state_name, state_array, data_clamps=None, verbose=False)
","text":"Insert a clamp into the module within jit (or grad).
Parameters:
Name Type Description Defaultstate_name
str
Name of the state variable to set.
requiredstate_array
ndarray
Time series of the state variable in the default Jaxley unit. State array should be of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.
requiredverbose
bool
Whether or not to print the number of inserted clamps. False
by default because this method is meant to be jitted.
False
Source code in jaxley/modules/base.py
def data_clamp(\n self,\n state_name: str,\n state_array: jnp.ndarray,\n data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n verbose: bool = False,\n):\n \"\"\"Insert a clamp into the module within jit (or grad).\n\n Args:\n state_name: Name of the state variable to set.\n state_array: Time series of the state variable in the default Jaxley unit.\n State array should be of shape (num_clamps, simulation_time) or\n (simulation_time, ) for a single clamp.\n verbose: Whether or not to print the number of inserted clamps. `False`\n by default because this method is meant to be jitted.\n \"\"\"\n comp_states, edge_states = self._get_state_names()\n if state_name not in comp_states + edge_states:\n raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n data = self.nodes if state_name in comp_states else self.edges\n return self._data_external_input(\n state_name, state_array, data_clamps, data, verbose=verbose\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_set","title":"data_set(key, val, param_state)
","text":"Set parameter of module (or its view) to a new value within jit
.
Parameters:
Name Type Description Defaultkey
str
The name of the parameter to set.
requiredval
Union[float, ndarray]
The value to set the parameter to. If it is jnp.ndarray
then it must be of shape (len(num_compartments))
.
param_state
Optional[List[Dict]]
State of the setted parameters, internally used such that this function does not modify global state.
required Source code injaxley/modules/base.py
def data_set(\n self,\n key: str,\n val: Union[float, jnp.ndarray],\n param_state: Optional[List[Dict]],\n):\n \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n Args:\n key: The name of the parameter to set.\n val: The value to set the parameter to. If it is `jnp.ndarray` then it\n must be of shape `(len(num_compartments))`.\n param_state: State of the setted parameters, internally used such that this\n function does not modify global state.\n \"\"\"\n # Note: `data_set` does not support arrays for `val`.\n is_node_param = key in self.nodes.columns\n data = self.nodes if is_node_param else self.edges\n viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view\n if key in data.columns:\n not_nan = ~data[key].isna()\n added_param_state = [\n {\n \"indices\": np.atleast_2d(viewed_inds[not_nan]),\n \"key\": key,\n \"val\": jnp.atleast_1d(jnp.asarray(val)),\n }\n ]\n if param_state is not None:\n param_state += added_param_state\n else:\n param_state = added_param_state\n else:\n raise KeyError(\"Key not recognized.\")\n return param_state\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_stimulate","title":"data_stimulate(current, data_stimuli=None, verbose=False)
","text":"Insert a stimulus into the module within jit (or grad).
Parameters:
Name Type Description Defaultcurrent
ndarray
Current in nA
.
verbose
bool
Whether or not to print the number of inserted stimuli. False
by default because this method is meant to be jitted.
False
Source code in jaxley/modules/base.py
def data_stimulate(\n self,\n current: jnp.ndarray,\n data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n verbose: bool = False,\n) -> Tuple[jnp.ndarray, pd.DataFrame]:\n \"\"\"Insert a stimulus into the module within jit (or grad).\n\n Args:\n current: Current in `nA`.\n verbose: Whether or not to print the number of inserted stimuli. `False`\n by default because this method is meant to be jitted.\n \"\"\"\n return self._data_external_input(\n \"i\", current, data_stimuli, self.nodes, verbose=verbose\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_channel","title":"delete_channel(channel)
","text":"Remove a channel from the module.
Parameters:
Name Type Description Defaultchannel
Channel
The channel to remove.
required Source code injaxley/modules/base.py
def delete_channel(self, channel: Channel):\n \"\"\"Remove a channel from the module.\n\n Args:\n channel: The channel to remove.\"\"\"\n name = channel._name\n channel_names = [c._name for c in self.channels]\n all_channel_names = [c._name for c in self.base.channels]\n if name in channel_names:\n channel_cols = list(channel.channel_params.keys())\n channel_cols += list(channel.channel_states.keys())\n self.base.nodes.loc[self._nodes_in_view, channel_cols] = float(\"nan\")\n self.base.nodes.loc[self._nodes_in_view, name] = False\n\n # only delete cols if no other comps in the module have the same channel\n if np.all(~self.base.nodes[name]):\n self.base.channels.pop(all_channel_names.index(name))\n self.base.membrane_current_names.remove(channel.current_name)\n self.base.nodes.drop(columns=channel_cols + [name], inplace=True)\n else:\n raise ValueError(f\"Channel {name} not found in the module.\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_clamps","title":"delete_clamps(state_name=None)
","text":"Removes all clamps of the given state from the module.
Source code injaxley/modules/base.py
def delete_clamps(self, state_name: Optional[str] = None):\n \"\"\"Removes all clamps of the given state from the module.\"\"\"\n all_externals = list(self.externals.keys())\n if \"i\" in all_externals:\n all_externals.remove(\"i\")\n state_names = all_externals if state_name is None else [state_name]\n for state_name in state_names:\n if state_name in self.externals:\n keep_inds = ~np.isin(\n self.base.external_inds[state_name], self._nodes_in_view\n )\n base_exts = self.base.externals\n base_exts_inds = self.base.external_inds\n if np.all(~keep_inds):\n base_exts.pop(state_name, None)\n base_exts_inds.pop(state_name, None)\n else:\n base_exts[state_name] = base_exts[state_name][keep_inds]\n base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]\n self._update_view()\n else:\n pass # does not have to be deleted if not in externals\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_recordings","title":"delete_recordings()
","text":"Removes all recordings from the module.
Source code injaxley/modules/base.py
def delete_recordings(self):\n \"\"\"Removes all recordings from the module.\"\"\"\n if isinstance(self, View):\n base_recs = self.base.recordings\n self.base.recordings = base_recs[\n ~base_recs.isin(self.recordings).all(axis=1)\n ]\n self._update_view()\n else:\n self.base.recordings = pd.DataFrame().from_dict({})\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_stimuli","title":"delete_stimuli()
","text":"Removes all stimuli from the module.
Source code injaxley/modules/base.py
def delete_stimuli(self):\n \"\"\"Removes all stimuli from the module.\"\"\"\n self.delete_clamps(\"i\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_trainables","title":"delete_trainables()
","text":"Removes all trainable parameters from the module.
Source code injaxley/modules/base.py
def delete_trainables(self):\n \"\"\"Removes all trainable parameters from the module.\"\"\"\n\n if isinstance(self, View):\n trainables_and_inds = self._filter_trainables(is_viewed=False)\n self.base.indices_set_by_trainables = trainables_and_inds[0]\n self.base.trainable_params = trainables_and_inds[1]\n self.base.num_trainable_params -= self.num_trainable_params\n else:\n self.base.indices_set_by_trainables = []\n self.base.trainable_params = []\n self.base.num_trainable_params = 0\n self._update_view()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.distance","title":"distance(endpoint)
","text":"Return the direct distance between two compartments. This does not compute the pathwise distance (which is currently not implemented). Args: endpoint: The compartment to which to compute the distance to.
Source code injaxley/modules/base.py
def distance(self, endpoint: \"View\") -> float:\n \"\"\"Return the direct distance between two compartments.\n This does not compute the pathwise distance (which is currently not\n implemented).\n Args:\n endpoint: The compartment to which to compute the distance to.\n \"\"\"\n assert len(self.xyzr) == 1 and len(endpoint.xyzr) == 1\n start_xyz = np.mean(self.xyzr[0][:, :3], axis=0)\n end_xyz = np.mean(endpoint.xyzr[0][:, :3], axis=0)\n return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.edge","title":"edge(idx)
","text":"Return a View of the module at the selected synapse edges(s).
Parameters:
Name Type Description Defaultidx
Any
index of the edge to view.
requiredReturns:
Type DescriptionView
View of the module at the specified edge index.
Source code injaxley/modules/base.py
def edge(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected synapse edges(s).\n\n Args:\n idx: index of the edge to view.\n\n Returns:\n View of the module at the specified edge index.\"\"\"\n return self._at_edges(\"edge\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_parameters","title":"get_all_parameters(pstate, voltage_solver)
","text":"Return all parameters (and coupling conductances) needed to simulate.
Runs _compute_axial_conductances()
and return every parameter that is needed to solve the ODE. This includes conductances, radiuses, lengths, axial_resistivities, but also coupling conductances.
This is done by first obtaining the current value of every parameter (not only the trainable ones) and then replacing the trainable ones with the value in trainable_params()
. This function is run within jx.integrate()
.
pstate can be obtained by calling params_to_pstate()
.
.. code-block:: python
params = module.get_parameters() # i.e. [0, 1, 2]\npstate = params_to_pstate(params, module.indices_set_by_trainables)\nmodule.to_jax() # needed for call to module.jaxnodes\n
Parameters:
Name Type Description Defaultpstate
List[Dict]
The state of the trainable parameters. pstate takes the form [{ \u201ckey\u201d: \u201cgNa\u201d, \u201cindices\u201d: jnp.array([0, 1, 2]), \u201cval\u201d: jnp.array([0.1, 0.2, 0.3]) }, \u2026].
requiredvoltage_solver
str
The voltage solver that is used. Since jax.sparse
and jaxley.xyz
require different formats of the axial conductances, this function will default to different building methods.
Returns:
Type DescriptionDict[str, ndarray]
A dictionary of all module parameters.
Source code injaxley/modules/base.py
@only_allow_module\ndef get_all_parameters(\n self, pstate: List[Dict], voltage_solver: str\n) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n Runs `_compute_axial_conductances()` and return every parameter that is needed\n to solve the ODE. This includes conductances, radiuses, lengths,\n axial_resistivities, but also coupling conductances.\n\n This is done by first obtaining the current value of every parameter (not only\n the trainable ones) and then replacing the trainable ones with the value\n in `trainable_params()`. This function is run within `jx.integrate()`.\n\n pstate can be obtained by calling `params_to_pstate()`.\n\n .. code-block:: python\n\n params = module.get_parameters() # i.e. [0, 1, 2]\n pstate = params_to_pstate(params, module.indices_set_by_trainables)\n module.to_jax() # needed for call to module.jaxnodes\n\n Args:\n pstate: The state of the trainable parameters. pstate takes the form\n [{\n \"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]),\n \"val\": jnp.array([0.1, 0.2, 0.3])\n }, ...].\n voltage_solver: The voltage solver that is used. Since `jax.sparse` and\n `jaxley.xyz` require different formats of the axial conductances, this\n function will default to different building methods.\n\n Returns:\n A dictionary of all module parameters.\n \"\"\"\n params = {}\n for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n params[key] = self.base.jaxnodes[key]\n\n for channel in self.base.channels:\n for channel_params in channel.channel_params:\n params[channel_params] = self.base.jaxnodes[channel_params]\n\n for synapse_params in self.base.synapse_param_names:\n params[synapse_params] = self.base.jaxedges[synapse_params]\n\n # Override with those parameters set by `.make_trainable()`.\n for parameter in pstate:\n key = parameter[\"key\"]\n inds = parameter[\"indices\"]\n set_param = parameter[\"val\"]\n\n # This is needed since SynapseViews worked differently before.\n # This mimics the old behaviour and tranformes the new indices\n # to the old indices.\n # TODO FROM #447: Longterm this should be gotten rid of.\n # Instead edges should work similar to nodes (would also allow for\n # param sharing).\n synapse_inds = self.base.edges.groupby(\"type\").rank()[\"global_edge_index\"]\n synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()\n if key in self.base.synapse_param_names:\n inds = synapse_inds[inds]\n\n if key in params: # Only parameters, not initial states.\n # `inds` is of shape `(num_params, num_comps_per_param)`.\n # `set_param` is of shape `(num_params,)`\n # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n # `.set()` to work. This is done with `[:, None]`.\n params[key] = params[key].at[inds].set(set_param[:, None])\n\n # Compute conductance params and add them to the params dictionary.\n params[\"axial_conductances\"] = self.base._compute_axial_conductances(\n params=params\n )\n return params\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_states","title":"get_all_states(pstate, all_params, delta_t)
","text":"Get the full initial state of the module from jaxnodes and trainables.
Parameters:
Name Type Description Defaultpstate
List[Dict]
The state of the trainable parameters.
requiredall_params
All parameters of the module.
requireddelta_t
float
The time step.
requiredReturns:
Type DescriptionDict[str, ndarray]
A dictionary of all states of the module.
Source code injaxley/modules/base.py
@only_allow_module\ndef get_all_states(\n self, pstate: List[Dict], all_params, delta_t: float\n) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n Args:\n pstate: The state of the trainable parameters.\n all_params: All parameters of the module.\n delta_t: The time step.\n\n Returns:\n A dictionary of all states of the module.\n \"\"\"\n states = self.base._get_states_from_nodes_and_edges()\n\n # Override with the initial states set by `.make_trainable()`.\n for parameter in pstate:\n key = parameter[\"key\"]\n inds = parameter[\"indices\"]\n set_param = parameter[\"val\"]\n if key in states: # Only initial states, not parameters.\n # `inds` is of shape `(num_params, num_comps_per_param)`.\n # `set_param` is of shape `(num_params,)`\n # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n # `.set()` to work. This is done with `[:, None]`.\n states[key] = states[key].at[inds].set(set_param[:, None])\n\n # Add to the states the initial current through every channel.\n states, _ = self.base._channel_currents(\n states, delta_t, self.channels, self.nodes, all_params\n )\n\n # Add to the states the initial current through every synapse.\n states, _ = self.base._synapse_currents(\n states, self.synapses, all_params, delta_t, self.edges\n )\n return states\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_parameters","title":"get_parameters()
","text":"Get all trainable parameters.
The returned parameters should be passed to `jx.integrate(\u2026, params=params).
Returns:
Type DescriptionList[Dict[str, ndarray]]
A list of all trainable parameters in the form of [{\u201cgNa\u201d: jnp.array([0.1, 0.2, 0.3])}, \u2026].
Source code injaxley/modules/base.py
def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n \"\"\"Get all trainable parameters.\n\n The returned parameters should be passed to `jx.integrate(..., params=params).\n\n Returns:\n A list of all trainable parameters in the form of\n [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n \"\"\"\n return self.trainable_params\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.init_states","title":"init_states(delta_t=0.025)
","text":"Initialize all mechanisms in their steady state.
This considers the voltages and parameters of each compartment.
Parameters:
Name Type Description Defaultdelta_t
float
Passed on to channel.init_state()
.
0.025
Source code in jaxley/modules/base.py
@only_allow_module\ndef init_states(self, delta_t: float = 0.025):\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Initialize all mechanisms in their steady state.\n\n This considers the voltages and parameters of each compartment.\n\n Args:\n delta_t: Passed on to `channel.init_state()`.\n \"\"\"\n # Update states of the channels.\n channel_nodes = self.base.nodes\n states = self.base._get_states_from_nodes_and_edges()\n\n # We do not use any `pstate` for initializing. In principle, we could change\n # that by allowing an input `params` and `pstate` to this function.\n # `voltage_solver` could also be `jax.sparse` here, because both of them\n # build the channel parameters in the same way.\n params = self.base.get_all_parameters([], voltage_solver=\"jaxley.thomas\")\n\n for channel in self.base.channels:\n name = channel._name\n channel_indices = channel_nodes.loc[channel_nodes[name]][\n \"global_comp_index\"\n ].to_numpy()\n voltages = channel_nodes.loc[channel_indices, \"v\"].to_numpy()\n\n channel_param_names = list(channel.channel_params.keys())\n channel_state_names = list(channel.channel_states.keys())\n channel_states = query_channel_states_and_params(\n states, channel_state_names, channel_indices\n )\n channel_params = query_channel_states_and_params(\n params, channel_param_names, channel_indices\n )\n\n init_state = channel.init_state(\n channel_states, voltages, channel_params, delta_t\n )\n\n # `init_state` might not return all channel states. Only the ones that are\n # returned are updated here.\n for key, val in init_state.items():\n # Note that we are overriding `self.nodes` here, but `self.nodes` is\n # not used above to actually compute the current states (so there are\n # no issues with overriding states).\n self.nodes.loc[channel_indices, key] = val\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.insert","title":"insert(channel)
","text":"Insert a channel into the module.
Parameters:
Name Type Description Defaultchannel
Channel
The channel to insert.
required Source code injaxley/modules/base.py
def insert(self, channel: Channel):\n \"\"\"Insert a channel into the module.\n\n Args:\n channel: The channel to insert.\"\"\"\n name = channel._name\n\n # Channel does not yet exist in the `jx.Module` at all.\n if name not in [c._name for c in self.base.channels]:\n self.base.channels.append(channel)\n self.base.nodes[name] = (\n False # Previous columns do not have the new channel.\n )\n\n if channel.current_name not in self.base.membrane_current_names:\n self.base.membrane_current_names.append(channel.current_name)\n\n # Add a binary column that indicates if a channel is present.\n self.base.nodes.loc[self._nodes_in_view, name] = True\n\n # Loop over all new parameters, e.g. gNa, eNa.\n for key in channel.channel_params:\n self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]\n\n # Loop over all new parameters, e.g. gNa, eNa.\n for key in channel.channel_states:\n self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.loc","title":"loc(at)
","text":"Return a View of the module at the selected branch location(s).
Parameters:
Name Type Description Defaultat
Any
location along the branch.
requiredReturns:
Type DescriptionView
View of the module at the specified branch location.
Source code injaxley/modules/base.py
def loc(self, at: Any) -> View:\n \"\"\"Return a View of the module at the selected branch location(s).\n\n Args:\n at: location along the branch.\n\n Returns:\n View of the module at the specified branch location.\"\"\"\n global_comp_idxs = []\n for i in self._branches_in_view:\n ncomp = self.base.ncomp_per_branch[i]\n comp_locs = np.linspace(0, 1, ncomp)\n at = comp_locs if is_str_all(at) else self._reformat_index(at, dtype=float)\n comp_edges = np.linspace(0, 1 + 1e-10, ncomp + 1)\n idx = np.digitize(at, comp_edges) - 1 + self.base.cumsum_ncomp[i]\n global_comp_idxs.append(idx)\n global_comp_idxs = np.concatenate(global_comp_idxs)\n orig_scope = self._scope\n # global scope needed to select correct comps, for i.e. branches w. ncomp=[1,2]\n # loc(0.9) will correspond to different local branches (0 vs 1).\n view = self.scope(\"global\").comp(global_comp_idxs).scope(orig_scope)\n view._current_view = \"loc\"\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.make_trainable","title":"make_trainable(key, init_val=None, verbose=True)
","text":"Make a parameter trainable.
If a parameter is made trainable, it will be returned by get_parameters()
and should then be passed to jx.integrate(..., params=params)
.
Parameters:
Name Type Description Defaultkey
str
Name of the parameter to make trainable.
requiredinit_val
Optional[Union[float, list]]
Initial value of the parameter. If float
, the same value is used for every created parameter. If list
, the length of the list has to match the number of created parameters. If None
, the current parameter value is used and if parameter sharing is performed that the current parameter value is averaged over all shared parameters.
None
verbose
bool
Whether to print the number of parameters that are added and the total number of parameters.
True
Source code in jaxley/modules/base.py
def make_trainable(\n self,\n key: str,\n init_val: Optional[Union[float, list]] = None,\n verbose: bool = True,\n):\n \"\"\"Make a parameter trainable.\n\n If a parameter is made trainable, it will be returned by `get_parameters()`\n and should then be passed to `jx.integrate(..., params=params)`.\n\n Args:\n key: Name of the parameter to make trainable.\n init_val: Initial value of the parameter. If `float`, the same value is\n used for every created parameter. If `list`, the length of the list has\n to match the number of created parameters. If `None`, the current\n parameter value is used and if parameter sharing is performed that the\n current parameter value is averaged over all shared parameters.\n verbose: Whether to print the number of parameters that are added and the\n total number of parameters.\n \"\"\"\n assert (\n self.allow_make_trainable\n ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n ncomps_per_branch = (\n self.base.nodes[\"global_branch_index\"].value_counts().to_numpy()\n )\n assert np.all(\n ncomps_per_branch == ncomps_per_branch[0]\n ), \"Parameter sharing is not allowed for modules containing branches with different numbers of compartments.\"\n\n data = self.nodes if key in self.nodes.columns else None\n data = self.edges if key in self.edges.columns else data\n\n assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n not_nan = ~data[key].isna()\n data = data.loc[not_nan]\n assert (\n len(data) > 0\n ), \"No settable parameters found in the selected compartments.\"\n\n grouped_view = data.groupby(\"controlled_by_param\")\n # Because of this `x.index.values` we cannot support `make_trainable()` on\n # the module level for synapse parameters (but only for `SynapseView`).\n inds_of_comps = list(\n grouped_view.apply(lambda x: x.index.values, include_groups=False)\n )\n indices_per_param = jnp.stack(inds_of_comps)\n # Sorted inds are only used to infer the correct starting values.\n param_vals = jnp.asarray(\n [data.loc[inds, key].to_numpy() for inds in inds_of_comps]\n )\n\n # Set the value which the trainable parameter should take.\n num_created_parameters = len(indices_per_param)\n if init_val is not None:\n if isinstance(init_val, float):\n new_params = jnp.asarray([init_val] * num_created_parameters)\n elif isinstance(init_val, list):\n assert (\n len(init_val) == num_created_parameters\n ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n new_params = jnp.asarray(init_val)\n else:\n raise ValueError(\n f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n )\n else:\n new_params = jnp.mean(param_vals, axis=1)\n self.base.trainable_params.append({key: new_params})\n self.base.indices_set_by_trainables.append(indices_per_param)\n self.base.num_trainable_params += num_created_parameters\n if verbose:\n print(\n f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}\"\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.move","title":"move(x=0.0, y=0.0, z=0.0, update_nodes=False)
","text":"Move cells or networks by adding to their (x, y, z) coordinates.
This function is used only for visualization. It does not affect the simulation.
Parameters:
Name Type Description Defaultx
float
The amount to move in the x direction in um.
0.0
y
float
The amount to move in the y direction in um.
0.0
z
float
The amount to move in the z direction in um.
0.0
update_nodes
bool
Whether .nodes
should be updated or not. Setting this to False
largely speeds up moving, especially for big networks, but .nodes
or .show
will not show the new xyz coordinates.
False
Source code in jaxley/modules/base.py
def move(\n self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = False\n):\n \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n This function is used only for visualization. It does not affect the simulation.\n\n Args:\n x: The amount to move in the x direction in um.\n y: The amount to move in the y direction in um.\n z: The amount to move in the z direction in um.\n update_nodes: Whether `.nodes` should be updated or not. Setting this to\n `False` largely speeds up moving, especially for big networks, but\n `.nodes` or `.show` will not show the new xyz coordinates.\n \"\"\"\n for i in self._branches_in_view:\n self.base.xyzr[i][:, :3] += np.array([x, y, z])\n if update_nodes:\n self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.move_to","title":"move_to(x=0.0, y=0.0, z=0.0, update_nodes=False)
","text":"Move cells or networks to a location (x, y, z).
If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment\u2019s previous coordinate and the new float location.
If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.
Parameters:
Name Type Description Defaultupdate_nodes
bool
Whether .nodes
should be updated or not. Setting this to False
largely speeds up moving, especially for big networks, but .nodes
or .show
will not show the new xyz coordinates.
False
Source code in jaxley/modules/base.py
def move_to(\n self,\n x: Union[float, np.ndarray] = 0.0,\n y: Union[float, np.ndarray] = 0.0,\n z: Union[float, np.ndarray] = 0.0,\n update_nodes: bool = False,\n):\n \"\"\"Move cells or networks to a location (x, y, z).\n\n If x, y, and z are floats, then the first compartment of the first branch\n of the first cell is moved to that float coordinate, and everything else is\n shifted by the difference between that compartment's previous coordinate and\n the new float location.\n\n If x, y, and z are arrays, then they must each have a length equal to the number\n of cells being moved. Then the first compartment of the first branch of each\n cell is moved to the specified location.\n\n Args:\n update_nodes: Whether `.nodes` should be updated or not. Setting this to\n `False` largely speeds up moving, especially for big networks, but\n `.nodes` or `.show` will not show the new xyz coordinates.\n \"\"\"\n # Test if any coordinate values are NaN which would greatly affect moving\n if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n raise ValueError(\n \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n )\n\n # can only iterate over cells for networks\n # lambda makes sure that generator can be created multiple times\n base_is_net = self.base._current_view == \"network\"\n cells = lambda: (self.cells if base_is_net else [self])\n\n root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])\n root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells\n move_by = np.array([x, y, z]).T - root_xyz\n\n if len(move_by.shape) == 1:\n move_by = np.tile(move_by, (len(self._cells_in_view), 1))\n\n for cell, offset in zip(cells(), move_by):\n for idx in cell._branches_in_view:\n self.base.xyzr[idx][:, :3] += offset\n if update_nodes:\n self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.rotate","title":"rotate(degrees, rotation_axis='xy', update_nodes=False)
","text":"Rotate jaxley modules clockwise. Used only for visualization.
This function is used only for visualization. It does not affect the simulation.
Parameters:
Name Type Description Defaultdegrees
float
How many degrees to rotate the module by.
requiredrotation_axis
str
Either of {xy
| xz
| yz
}.
'xy'
Source code in jaxley/modules/base.py
def rotate(\n self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = False\n):\n \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n This function is used only for visualization. It does not affect the simulation.\n\n Args:\n degrees: How many degrees to rotate the module by.\n rotation_axis: Either of {`xy` | `xz` | `yz`}.\n \"\"\"\n degrees = degrees / 180 * np.pi\n if rotation_axis == \"xy\":\n dims = [0, 1]\n elif rotation_axis == \"xz\":\n dims = [0, 2]\n elif rotation_axis == \"yz\":\n dims = [1, 2]\n else:\n raise ValueError\n\n rotation_matrix = np.asarray(\n [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n )\n for i in self._branches_in_view:\n rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T\n self.base.xyzr[i][:, dims] = rot\n if update_nodes:\n self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.scope","title":"scope(scope)
","text":"Return a View of the module with the specified scope.
For example cell.scope(\"global\").branch(2).scope(\"local\").comp(1)
will return the 1st compartment of branch 2.
Parameters:
Name Type Description Defaultscope
str
either \u201cglobal\u201d or \u201clocal\u201d.
requiredReturns:
Type DescriptionView
View with the specified scope.
Source code injaxley/modules/base.py
def scope(self, scope: str) -> View:\n \"\"\"Return a View of the module with the specified scope.\n\n For example `cell.scope(\"global\").branch(2).scope(\"local\").comp(1)`\n will return the 1st compartment of branch 2.\n\n Args:\n scope: either \"global\" or \"local\".\n\n Returns:\n View with the specified scope.\"\"\"\n view = self.view\n view.set_scope(scope)\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.select","title":"select(nodes=None, edges=None, sorted=False)
","text":"Return View of the module filtered by specific node or edges indices.
Parameters:
Name Type Description Defaultnodes
ndarray
indices of nodes to view. If None, all nodes are viewed.
None
edges
ndarray
indices of edges to view. If None, all edges are viewed.
None
sorted
bool
if True, nodes and edges are sorted.
False
Returns:
Type DescriptionView
View for subset of selected nodes and/or edges.
Source code injaxley/modules/base.py
def select(\n self, nodes: np.ndarray = None, edges: np.ndarray = None, sorted: bool = False\n) -> View:\n \"\"\"Return View of the module filtered by specific node or edges indices.\n\n Args:\n nodes: indices of nodes to view. If None, all nodes are viewed.\n edges: indices of edges to view. If None, all edges are viewed.\n sorted: if True, nodes and edges are sorted.\n\n Returns:\n View for subset of selected nodes and/or edges.\"\"\"\n\n nodes = self._reformat_index(nodes) if nodes is not None else None\n nodes = self._nodes_in_view if is_str_all(nodes) else nodes\n nodes = np.sort(nodes) if sorted else nodes\n\n edges = self._reformat_index(edges) if edges is not None else None\n edges = self._edges_in_view if is_str_all(edges) else edges\n edges = np.sort(edges) if sorted else edges\n\n view = View(self, nodes, edges)\n view._set_controlled_by_param(\"filter\")\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set","title":"set(key, val)
","text":"Set parameter of module (or its view) to a new value.
Note that this function can not be called within jax.jit
or jax.grad
. Instead, it should be used set the parameters of the module before the simulation. Use .data_set()
to set parameters during jax.jit
or jax.grad
.
Parameters:
Name Type Description Defaultkey
str
The name of the parameter to set.
requiredval
Union[float, ndarray]
The value to set the parameter to. If it is jnp.ndarray
then it must be of shape (len(num_compartments))
.
jaxley/modules/base.py
def set(self, key: str, val: Union[float, jnp.ndarray]):\n \"\"\"Set parameter of module (or its view) to a new value.\n\n Note that this function can not be called within `jax.jit` or `jax.grad`.\n Instead, it should be used set the parameters of the module **before** the\n simulation. Use `.data_set()` to set parameters during `jax.jit` or\n `jax.grad`.\n\n Args:\n key: The name of the parameter to set.\n val: The value to set the parameter to. If it is `jnp.ndarray` then it\n must be of shape `(len(num_compartments))`.\n \"\"\"\n if key in self.nodes.columns:\n not_nan = ~self.nodes[key].isna().to_numpy()\n self.base.nodes.loc[self._nodes_in_view[not_nan], key] = val\n elif key in self.edges.columns:\n not_nan = ~self.edges[key].isna().to_numpy()\n self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n else:\n raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set_ncomp","title":"set_ncomp(ncomp, min_radius=None)
","text":"Set the number of compartments with which the branch is discretized.
Parameters:
Name Type Description Defaultncomp
int
The number of compartments that the branch should be discretized into.
requiredmin_radius
Optional[float]
Only used if the morphology was read from an SWC file. If passed the radius is capped to be at least this value.
None
Source code in jaxley/modules/base.py
def set_ncomp(\n self,\n ncomp: int,\n min_radius: Optional[float] = None,\n):\n \"\"\"Set the number of compartments with which the branch is discretized.\n\n Args:\n ncomp: The number of compartments that the branch should be discretized\n into.\n min_radius: Only used if the morphology was read from an SWC file. If passed\n the radius is capped to be at least this value.\n\n Raises:\n - When there are stimuli in any compartment in the module.\n - When there are recordings in any compartment in the module.\n - When the channels of the compartments are not the same within the branch\n that is modified.\n - When the lengths of the compartments are not the same within the branch\n that is modified.\n - Unless the morphology was read from an SWC file, when the radiuses of the\n compartments are not the same within the branch that is modified.\n \"\"\"\n assert len(self.base.externals) == 0, \"No stimuli allowed!\"\n assert len(self.base.recordings) == 0, \"No recordings allowed!\"\n assert len(self.base.trainable_params) == 0, \"No trainables allowed!\"\n\n assert self.base._module_type != \"network\", \"This is not allowed for networks.\"\n assert not (\n self.base._module_type == \"cell\"\n and len(self._branches_in_view) == len(self.base._branches_in_view)\n ), \"This is not allowed for cells.\"\n\n # Update all attributes that are affected by compartment structure.\n view = self.nodes.copy()\n all_nodes = self.base.nodes\n start_idx = self.nodes[\"global_comp_index\"].to_numpy()[0]\n ncomp_per_branch = self.base.ncomp_per_branch\n channel_names = [c._name for c in self.base.channels]\n channel_param_names = list(\n chain(*[c.channel_params for c in self.base.channels])\n )\n channel_state_names = list(\n chain(*[c.channel_states for c in self.base.channels])\n )\n radius_generating_fns = self.base._radius_generating_fns\n\n within_branch_radiuses = view[\"radius\"].to_numpy()\n compartment_lengths = view[\"length\"].to_numpy()\n num_previous_ncomp = len(within_branch_radiuses)\n branch_indices = pd.unique(view[\"global_branch_index\"])\n\n error_msg = lambda name: (\n f\"You previously modified the {name} of individual compartments, but \"\n f\"now you are modifying the number of compartments in this branch. \"\n f\"This is not allowed. First build the morphology with `set_ncomp()` and \"\n f\"then modify the radiuses and lengths of compartments.\"\n )\n\n if (\n ~np.all(within_branch_radiuses == within_branch_radiuses[0])\n and radius_generating_fns is None\n ):\n raise ValueError(error_msg(\"radius\"))\n\n for property_name in [\"length\", \"capacitance\", \"axial_resistivity\"]:\n compartment_properties = view[property_name].to_numpy()\n if ~np.all(compartment_properties == compartment_properties[0]):\n raise ValueError(error_msg(property_name))\n\n if not (self.nodes[channel_names].var() == 0.0).all():\n raise ValueError(\n \"Some channel exists only in some compartments of the branch which you\"\n \"are trying to modify. This is not allowed. First specify the number\"\n \"of compartments with `.set_ncomp()` and then insert the channels\"\n \"accordingly.\"\n )\n\n if not (\n self.nodes[channel_param_names + channel_state_names].var() == 0.0\n ).all():\n raise ValueError(\n \"Some channel has different parameters or states between the \"\n \"different compartments of the branch which you are trying to modify. \"\n \"This is not allowed. First specify the number of compartments with \"\n \"`.set_ncomp()` and then insert the channels accordingly.\"\n )\n\n # Add new rows as the average of all rows. Special case for the length is below.\n average_row = self.nodes.mean(skipna=False)\n average_row = average_row.to_frame().T\n view = pd.concat([*[average_row] * ncomp], axis=\"rows\")\n\n # Set the correct datatype after having performed an average which cast\n # everything to float.\n integer_cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n view[integer_cols] = view[integer_cols].astype(int)\n\n # Whether or not a channel exists in a compartment is a boolean.\n boolean_cols = channel_names\n view[boolean_cols] = view[boolean_cols].astype(bool)\n\n # Special treatment for the lengths and radiuses. These are not being set as\n # the average because we:\n # 1) Want to maintain the total length of a branch.\n # 2) Want to use the SWC inferred radius.\n #\n # Compute new compartment lengths.\n comp_lengths = np.sum(compartment_lengths) / ncomp\n view[\"length\"] = comp_lengths\n\n # Compute new compartment radiuses.\n if radius_generating_fns is not None:\n view[\"radius\"] = build_radiuses_from_xyzr(\n radius_fns=radius_generating_fns,\n branch_indices=branch_indices,\n min_radius=min_radius,\n ncomp=ncomp,\n )\n else:\n view[\"radius\"] = within_branch_radiuses[0] * np.ones(ncomp)\n\n # Update `.nodes`.\n # 1) Delete N rows starting from start_idx\n number_deleted = num_previous_ncomp\n all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))\n\n # 2) Insert M new rows at the same location\n df1 = all_nodes.iloc[:start_idx] # Rows before the insertion point\n df2 = all_nodes.iloc[start_idx:] # Rows after the insertion point\n\n # 3) Combine the parts: before, new rows, and after\n all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)\n\n # Override `comp_index` to just be a consecutive list.\n all_nodes[\"global_comp_index\"] = np.arange(len(all_nodes))\n\n # Update compartment structure arguments.\n ncomp_per_branch[branch_indices] = ncomp\n ncomp = int(np.max(ncomp_per_branch))\n cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n internal_node_inds = np.arange(cumsum_ncomp[-1])\n\n self.base.nodes = all_nodes\n self.base.ncomp_per_branch = ncomp_per_branch\n self.base.ncomp = ncomp\n self.base.cumsum_ncomp = cumsum_ncomp\n self.base._internal_node_inds = internal_node_inds\n\n # Update the morphology indexing (e.g., `.comp_edges`).\n self.base._initialize()\n self.base._init_view()\n self.base._update_local_indices()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set_scope","title":"set_scope(scope)
","text":"Toggle between \u201cglobal\u201d or \u201clocal\u201d scope.
Determines if global or local indices are used for viewing the module.
Parameters:
Name Type Description Defaultscope
str
either \u201cglobal\u201d or \u201clocal\u201d.
required Source code injaxley/modules/base.py
def set_scope(self, scope: str):\n \"\"\"Toggle between \"global\" or \"local\" scope.\n\n Determines if global or local indices are used for viewing the module.\n\n Args:\n scope: either \"global\" or \"local\".\"\"\"\n assert scope in [\"global\", \"local\"], \"Invalid scope.\"\n self._scope = scope\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.show","title":"show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)
","text":"Print detailed information about the Module or a view of it.
Parameters:
Name Type Description Defaultparam_names
Optional[Union[str, List[str]]]
The names of the parameters to show. If None
, all parameters are shown.
None
indices
bool
Whether to show the indices of the compartments.
True
params
bool
Whether to show the parameters of the compartments.
True
states
bool
Whether to show the states of the compartments.
True
channel_names
Optional[List[str]]
The names of the channels to show. If None
, all channels are shown.
None
Returns:
Type DescriptionDataFrame
A pd.DataFrame
with the requested information.
jaxley/modules/base.py
def show(\n self,\n param_names: Optional[Union[str, List[str]]] = None,\n *,\n indices: bool = True,\n params: bool = True,\n states: bool = True,\n channel_names: Optional[List[str]] = None,\n) -> pd.DataFrame:\n \"\"\"Print detailed information about the Module or a view of it.\n\n Args:\n param_names: The names of the parameters to show. If `None`, all parameters\n are shown.\n indices: Whether to show the indices of the compartments.\n params: Whether to show the parameters of the compartments.\n states: Whether to show the states of the compartments.\n channel_names: The names of the channels to show. If `None`, all channels are\n shown.\n\n Returns:\n A `pd.DataFrame` with the requested information.\n \"\"\"\n nodes = self.nodes.copy() # prevents this from being edited\n\n cols = []\n inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n scopes = [\"local\", \"global\"]\n inds = [f\"{s}_{i}\" for i in inds for s in scopes] if indices else []\n cols += inds\n cols += [ch._name for ch in self.channels] if channel_names else []\n cols += (\n sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n )\n cols += (\n sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n )\n\n if not param_names is None:\n cols = (\n inds + [c for c in cols if c in param_names]\n if params\n else list(param_names)\n )\n\n return nodes[cols]\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.step","title":"step(u, delta_t, external_inds, externals, params, solver='bwd_euler', voltage_solver='jaxley.stone')
","text":"One step of solving the Ordinary Differential Equation.
This function is called inside of integrate
and increments the state of the module by one time step. Calls _step_channels
and _step_synapse
to update the states of the channels and synapses using fwd_euler.
Parameters:
Name Type Description Defaultu
Dict[str, ndarray]
The state of the module. voltages = u[\u201cv\u201d]
requireddelta_t
float
The time step.
requiredexternal_inds
Dict[str, ndarray]
The indices of the external inputs.
requiredexternals
Dict[str, ndarray]
The external inputs.
requiredparams
Dict[str, ndarray]
The parameters of the module.
requiredsolver
str
The solver to use for the voltages. Either of [\u201cbwd_euler\u201d, \u201cfwd_euler\u201d, \u201ccrank_nicolson\u201d].
'bwd_euler'
voltage_solver
str
The tridiagonal solver used to diagonalize the coefficient matrix of the ODE system. Either of [\u201cjaxley.thomas\u201d, \u201cjaxley.stone\u201d].
'jaxley.stone'
Returns:
Type DescriptionDict[str, ndarray]
The updated state of the module.
Source code injaxley/modules/base.py
@only_allow_module\ndef step(\n self,\n u: Dict[str, jnp.ndarray],\n delta_t: float,\n external_inds: Dict[str, jnp.ndarray],\n externals: Dict[str, jnp.ndarray],\n params: Dict[str, jnp.ndarray],\n solver: str = \"bwd_euler\",\n voltage_solver: str = \"jaxley.stone\",\n) -> Dict[str, jnp.ndarray]:\n \"\"\"One step of solving the Ordinary Differential Equation.\n\n This function is called inside of `integrate` and increments the state of the\n module by one time step. Calls `_step_channels` and `_step_synapse` to update\n the states of the channels and synapses using fwd_euler.\n\n Args:\n u: The state of the module. voltages = u[\"v\"]\n delta_t: The time step.\n external_inds: The indices of the external inputs.\n externals: The external inputs.\n params: The parameters of the module.\n solver: The solver to use for the voltages. Either of [\"bwd_euler\",\n \"fwd_euler\", \"crank_nicolson\"].\n voltage_solver: The tridiagonal solver used to diagonalize the\n coefficient matrix of the ODE system. Either of [\"jaxley.thomas\",\n \"jaxley.stone\"].\n\n Returns:\n The updated state of the module.\n \"\"\"\n\n # Extract the voltages\n voltages = u[\"v\"]\n\n # Extract the external inputs\n if \"i\" in externals.keys():\n i_current = externals[\"i\"]\n i_inds = external_inds[\"i\"]\n i_ext = self._get_external_input(\n voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n )\n else:\n i_ext = 0.0\n\n # Step of the channels.\n u, (v_terms, const_terms) = self._step_channels(\n u, delta_t, self.channels, self.nodes, params\n )\n\n # Step of the synapse.\n u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n u,\n self.synapses,\n params,\n delta_t,\n self.edges,\n )\n\n # Clamp for channels and synapses.\n for key in externals.keys():\n if key not in [\"i\", \"v\"]:\n u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n # Voltage steps.\n cm = params[\"capacitance\"] # Abbreviation.\n\n # Arguments used by all solvers.\n solver_kwargs = {\n \"voltages\": voltages,\n \"voltage_terms\": (v_terms + syn_v_terms) / cm,\n \"constant_terms\": (const_terms + i_ext + syn_const_terms) / cm,\n \"axial_conductances\": params[\"axial_conductances\"],\n \"internal_node_inds\": self._internal_node_inds,\n }\n\n # Add solver specific arguments.\n if voltage_solver == \"jax.sparse\":\n solver_kwargs.update(\n {\n \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n \"data_inds\": self._data_inds,\n \"indices\": self._indices_jax_spsolve,\n \"indptr\": self._indptr_jax_spsolve,\n \"n_nodes\": self._n_nodes,\n }\n )\n # Only for `bwd_euler` and `cranck-nicolson`.\n step_voltage_implicit = step_voltage_implicit_with_jax_spsolve\n else:\n # Our custom sparse solver requires a different format of all conductance\n # values to perform triangulation and backsubstution optimally.\n #\n # Currently, the forward Euler solver also uses this format. However,\n # this is only for historical reasons and we are planning to change this in\n # the future.\n solver_kwargs.update(\n {\n \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n \"sources\": np.asarray(self._comp_edges[\"source\"].to_list()),\n \"types\": np.asarray(self._comp_edges[\"type\"].to_list()),\n \"ncomp_per_branch\": self.ncomp_per_branch,\n \"par_inds\": self._par_inds,\n \"child_inds\": self._child_inds,\n \"nbranches\": self.total_nbranches,\n \"solver\": voltage_solver,\n \"idx\": self._solve_indexer,\n \"debug_states\": self.debug_states,\n }\n )\n # Only for `bwd_euler` and `cranck-nicolson`.\n step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve\n\n if solver == \"bwd_euler\":\n u[\"v\"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)\n elif solver == \"crank_nicolson\":\n # Crank-Nicolson advances by half a step of backward and half a step of\n # forward Euler.\n half_step_delta_t = delta_t / 2\n half_step_voltages = step_voltage_implicit(\n **solver_kwargs, delta_t=half_step_delta_t\n )\n # The forward Euler step in Crank-Nicolson can be performed easily as\n # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.\n u[\"v\"] = 2 * half_step_voltages - voltages\n elif solver == \"fwd_euler\":\n u[\"v\"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)\n else:\n raise ValueError(\n f\"You specified `solver={solver}`. The only allowed solvers are \"\n \"['bwd_euler', 'fwd_euler', 'crank_nicolson'].\"\n )\n\n # Clamp for voltages.\n if \"v\" in externals.keys():\n u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n return u\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.stimulate","title":"stimulate(current=None, verbose=True)
","text":"Insert a stimulus into the compartment.
current must be a 1d array or have batch dimension of size (num_compartments, )
or (1, )
. If 1d, the same stimulus is added to all compartments.
This function cannot be run during jax.jit
and jax.grad
. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate()
.
Parameters:
Name Type Description Defaultcurrent
Optional[ndarray]
Current in nA
.
None
Source code in jaxley/modules/base.py
def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n \"\"\"Insert a stimulus into the compartment.\n\n current must be a 1d array or have batch dimension of size `(num_compartments, )`\n or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n it should only be used for static stimuli (i.e., stimuli that do not depend\n on the data and that should not be learned). For stimuli that depend on data\n (or that should be learned), please use `data_stimulate()`.\n\n Args:\n current: Current in `nA`.\n \"\"\"\n self._external_input(\"i\", current, verbose=verbose)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.to_jax","title":"to_jax()
","text":"Move .nodes
to .jaxnodes
.
Before the actual simulation is run (via jx.integrate
), all parameters of the jx.Module
are stored in .nodes
(a pd.DataFrame
). However, for simulation, these parameters have to be moved to be jnp.ndarrays
such that they can be processed on GPU/TPU and such that the simulation can be differentiated. .to_jax()
copies the .nodes
to .jaxnodes
.
jaxley/modules/base.py
@only_allow_module\ndef to_jax(self):\n # TODO FROM #447: Make this work for View?\n \"\"\"Move `.nodes` to `.jaxnodes`.\n\n Before the actual simulation is run (via `jx.integrate`), all parameters of\n the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n they can be processed on GPU/TPU and such that the simulation can be\n differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n \"\"\"\n self.base.jaxnodes = {}\n for key, value in self.base.nodes.to_dict(orient=\"list\").items():\n inds = jnp.arange(len(value))\n self.base.jaxnodes[key] = jnp.asarray(value)[inds]\n\n # `jaxedges` contains only parameters (no indices).\n # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n # we allow parameter sharing.\n self.base.jaxedges = {}\n edges = self.base.edges.to_dict(orient=\"list\")\n for i, synapse in enumerate(self.base.synapses):\n condition = np.asarray(edges[\"type_ind\"]) == i\n for key in synapse.synapse_params:\n self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n for key in synapse.synapse_states:\n self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.vis","title":"vis(ax=None, col='k', dims=(0, 1), type='line', morph_plot_kwargs={})
","text":"Visualize the module.
Modules can be visualized on one of the cardinal planes (xy, xz, yz) or even in 3D.
Several options are available: - line
: All points from the traced morphology (xyzr
), are connected with a line plot. - scatter
: All traced points, are plotted as scatter points. - comp
: Plots the compartmentalized morphology, including radius and shape. (shows the true compartment lengths per default, but this can be changed via the morph_plot_kwargs
, for details see jaxley.utils.plot_utils.plot_comps
). - morph
: Reconstructs the 3D shape of the traced morphology. For details see jaxley.utils.plot_utils.plot_morph
. Warning: For 3D plots and morphologies with many traced points this can be very slow.
Parameters:
Name Type Description Defaultax
Optional[Axes]
An axis into which to plot.
None
col
str
The color for all branches.
'k'
dims
Tuple[int]
Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.
(0, 1)
type
str
The type of plot. One of [\u201cline\u201d, \u201cscatter\u201d, \u201ccomp\u201d, \u201cmorph\u201d].
'line'
morph_plot_kwargs
Dict
Keyword arguments passed to the plotting function.
{}
Source code in jaxley/modules/base.py
def vis(\n self,\n ax: Optional[Axes] = None,\n col: str = \"k\",\n dims: Tuple[int] = (0, 1),\n type: str = \"line\",\n morph_plot_kwargs: Dict = {},\n) -> Axes:\n \"\"\"Visualize the module.\n\n Modules can be visualized on one of the cardinal planes (xy, xz, yz) or\n even in 3D.\n\n Several options are available:\n - `line`: All points from the traced morphology (`xyzr`), are connected\n with a line plot.\n - `scatter`: All traced points, are plotted as scatter points.\n - `comp`: Plots the compartmentalized morphology, including radius\n and shape. (shows the true compartment lengths per default, but this can\n be changed via the `morph_plot_kwargs`, for details see\n `jaxley.utils.plot_utils.plot_comps`).\n - `morph`: Reconstructs the 3D shape of the traced morphology. For details see\n `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies\n with many traced points this can be very slow.\n\n Args:\n ax: An axis into which to plot.\n col: The color for all branches.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two of them.\n type: The type of plot. One of [\"line\", \"scatter\", \"comp\", \"morph\"].\n morph_plot_kwargs: Keyword arguments passed to the plotting function.\n \"\"\"\n if \"comp\" in type.lower():\n return plot_comps(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)\n if \"morph\" in type.lower():\n return plot_morph(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)\n\n assert not np.any(\n [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]\n ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n ax = plot_graph(\n self.xyzr,\n dims=dims,\n col=col,\n ax=ax,\n type=type,\n morph_plot_kwargs=morph_plot_kwargs,\n )\n\n return ax\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.write_trainables","title":"write_trainables(trainable_params)
","text":"Write the trainables into .nodes
and .edges
.
This allows to, e.g., visualize trained networks with .vis()
.
Parameters:
Name Type Description Defaulttrainable_params
List[Dict[str, ndarray]]
The trainable parameters returned by get_parameters()
.
jaxley/modules/base.py
def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):\n \"\"\"Write the trainables into `.nodes` and `.edges`.\n\n This allows to, e.g., visualize trained networks with `.vis()`.\n\n Args:\n trainable_params: The trainable parameters returned by `get_parameters()`.\n \"\"\"\n # We do not support views. Why? `jaxedges` does not have any NaN\n # elements, whereas edges does. Because of this, we already need special\n # treatment to make this function work, and it would be an even bigger hassle\n # if we wanted to support this.\n assert self.__class__.__name__ in [\n \"Compartment\",\n \"Branch\",\n \"Cell\",\n \"Network\",\n ], \"Only supports modules.\"\n\n # We could also implement this without casting the module to jax.\n # However, I think it allows us to reuse as much code as possible and it avoids\n # any kind of issues with indexing or parameter sharing (as this is fully\n # taken care of by `get_all_parameters()`).\n self.base.to_jax()\n pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)\n all_params = self.base.get_all_parameters(pstate, voltage_solver=\"jaxley.stone\")\n\n # The value for `delta_t` does not matter here because it is only used to\n # compute the initial current. However, the initial current cannot be made\n # trainable and so its value never gets used below.\n all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)\n\n # Loop only over the keys in `pstate` to avoid unnecessary computation.\n for parameter in pstate:\n key = parameter[\"key\"]\n if key in self.base.nodes.columns:\n vals_to_set = all_params if key in all_params.keys() else all_states\n self.base.nodes[key] = vals_to_set[key]\n\n # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n # we allow parameter sharing.\n edges = self.base.edges.to_dict(orient=\"list\")\n for i, synapse in enumerate(self.base.synapses):\n condition = np.asarray(edges[\"type_ind\"]) == i\n for key in list(synapse.synapse_params.keys()):\n self.base.edges.loc[condition, key] = all_params[key]\n for key in list(synapse.synapse_states.keys()):\n self.base.edges.loc[condition, key] = all_states[key]\n
"},{"location":"reference/modules/#compartment","title":"Compartment","text":" Bases: Module
Compartment class.
This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.
Source code injaxley/modules/compartment.py
class Compartment(Module):\n \"\"\"Compartment class.\n\n This class defines a single compartment that can be simulated by itself or\n connected up into branches. It is the basic building block of a neuron model.\n \"\"\"\n\n compartment_params: Dict = {\n \"length\": 10.0, # um\n \"radius\": 1.0, # um\n \"axial_resistivity\": 5_000.0, # ohm cm\n \"capacitance\": 1.0, # uF/cm^2\n }\n compartment_states: Dict = {\"v\": -70.0}\n\n def __init__(self):\n super().__init__()\n\n self.ncomp = 1\n self.ncomp_per_branch = np.asarray([1])\n self.total_nbranches = 1\n self.nbranches_per_cell = [1]\n self._cumsum_nbranches = np.asarray([0, 1])\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n # Setting up the `nodes` for indexing.\n self.nodes = pd.DataFrame(\n dict(global_cell_index=[0], global_branch_index=[0], global_comp_index=[0])\n )\n self._append_params_and_states(self.compartment_params, self.compartment_states)\n self._update_local_indices()\n self._init_view()\n\n # Synapses.\n self.branch_edges = pd.DataFrame(\n dict(parent_branch_index=[], child_branch_index=[])\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n self._internal_node_inds = jnp.asarray([0])\n\n # Initialize the module.\n self._initialize()\n\n # Coordinates.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n def _init_morph_jaxley_spsolve(self):\n self._solve_indexer = JaxleySolveIndexer(\n cumsum_ncomp=self.cumsum_ncomp,\n branchpoint_group_inds=np.asarray([]).astype(int),\n children_in_level=[],\n parents_in_level=[],\n root_inds=np.asarray([0]),\n remapped_node_indices=self._internal_node_inds,\n )\n\n def _init_morph_jax_spsolve(self):\n \"\"\"Initialize morphology for the jax sparse voltage solver.\n\n Explanation of `self._comp_eges['type']`:\n `type == 0`: compartment <--> compartment (within branch)\n `type == 1`: branchpoint --> parent-compartment\n `type == 2`: branchpoint --> child-compartment\n `type == 3`: parent-compartment --> branchpoint\n `type == 4`: child-compartment --> branchpoint\n \"\"\"\n self._comp_edges = pd.DataFrame().from_dict(\n {\"source\": [], \"sink\": [], \"type\": []}\n )\n n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n self._n_nodes = n_nodes\n self._data_inds = data_inds\n self._indices_jax_spsolve = indices\n self._indptr_jax_spsolve = indptr\n
"},{"location":"reference/modules/#branch","title":"Branch","text":" Bases: Module
Branch class.
This class defines a single branch that can be simulated by itself or connected to build a cell. A branch is linear segment of several compartments and can be connected to no, one or more other branches at each end to build more intricate cell morphologies.
Source code injaxley/modules/branch.py
class Branch(Module):\n \"\"\"Branch class.\n\n This class defines a single branch that can be simulated by itself or\n connected to build a cell. A branch is linear segment of several compartments\n and can be connected to no, one or more other branches at each end to build more\n intricate cell morphologies.\n \"\"\"\n\n branch_params: Dict = {}\n branch_states: Dict = {}\n\n @deprecated_kwargs(\"0.6.0\", [\"nseg\"])\n def __init__(\n self,\n compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n ncomp: Optional[int] = None,\n nseg: Optional[int] = None,\n ):\n \"\"\"\n Args:\n compartments: A single compartment or a list of compartments that make up the\n branch.\n ncomp: Number of segments to divide the branch into. If `compartments` is an\n a single compartment, than the compartment is repeated `ncomp` times to\n create the branch.\n \"\"\"\n # Warnings and errors that deal with the change from `nseg` to `ncomp` change\n # in Jaxley v0.5.0.\n if ncomp is not None and nseg is not None:\n raise ValueError(\"You passed `ncomp` and `nseg`. Please pass only `ncomp`.\")\n if ncomp is None and nseg is not None:\n ncomp = nseg\n\n super().__init__()\n assert (\n isinstance(compartments, (Compartment, List)) or compartments is None\n ), \"Only Compartment or List[Compartment] is allowed.\"\n if isinstance(compartments, Compartment):\n assert (\n ncomp is not None\n ), \"If `compartments` is not a list then you have to set `ncomp`.\"\n compartments = Compartment() if compartments is None else compartments\n ncomp = 1 if ncomp is None else ncomp\n\n if isinstance(compartments, Compartment):\n compartment_list = [compartments] * ncomp\n else:\n compartment_list = compartments\n\n self.ncomp = len(compartment_list)\n self.ncomp_per_branch = np.asarray([self.ncomp])\n self.total_nbranches = 1\n self.nbranches_per_cell = [1]\n self._cumsum_nbranches = jnp.asarray([0, 1])\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n # Indexing.\n self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n self._append_params_and_states(self.branch_params, self.branch_states)\n self.nodes[\"global_comp_index\"] = np.arange(self.ncomp).tolist()\n self.nodes[\"global_branch_index\"] = [0] * self.ncomp\n self.nodes[\"global_cell_index\"] = [0] * self.ncomp\n self._update_local_indices()\n self._init_view()\n\n # Channels.\n self._gather_channels_from_constituents(compartment_list)\n\n self.branch_edges = pd.DataFrame(\n dict(parent_branch_index=[], child_branch_index=[])\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n self._internal_node_inds = jnp.arange(self.ncomp)\n\n self._initialize()\n\n # Coordinates.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n def _init_morph_jaxley_spsolve(self):\n self._solve_indexer = JaxleySolveIndexer(\n cumsum_ncomp=self.cumsum_ncomp,\n branchpoint_group_inds=np.asarray([]).astype(int),\n remapped_node_indices=self._internal_node_inds,\n children_in_level=[],\n parents_in_level=[],\n root_inds=np.asarray([0]),\n )\n\n def _init_morph_jax_spsolve(self):\n \"\"\"Initialize morphology for the jax sparse voltage solver.\n\n Explanation of `self._comp_eges['type']`:\n `type == 0`: compartment <--> compartment (within branch)\n `type == 1`: branchpoint --> parent-compartment\n `type == 2`: branchpoint --> child-compartment\n `type == 3`: parent-compartment --> branchpoint\n `type == 4`: child-compartment --> branchpoint\n \"\"\"\n self._comp_edges = pd.DataFrame().from_dict(\n {\n \"source\": list(range(self.ncomp - 1)) + list(range(1, self.ncomp)),\n \"sink\": list(range(1, self.ncomp)) + list(range(self.ncomp - 1)),\n }\n )\n self._comp_edges[\"type\"] = 0\n n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n self._n_nodes = n_nodes\n self._data_inds = data_inds\n self._indices_jax_spsolve = indices\n self._indptr_jax_spsolve = indptr\n\n def __len__(self) -> int:\n return self.ncomp\n
"},{"location":"reference/modules/#jaxley.modules.branch.Branch.__init__","title":"__init__(compartments=None, ncomp=None, nseg=None)
","text":"Parameters:
Name Type Description Defaultcompartments
Optional[Union[Compartment, List[Compartment]]]
A single compartment or a list of compartments that make up the branch.
None
ncomp
Optional[int]
Number of segments to divide the branch into. If compartments
is an a single compartment, than the compartment is repeated ncomp
times to create the branch.
None
Source code in jaxley/modules/branch.py
@deprecated_kwargs(\"0.6.0\", [\"nseg\"])\ndef __init__(\n self,\n compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n ncomp: Optional[int] = None,\n nseg: Optional[int] = None,\n):\n \"\"\"\n Args:\n compartments: A single compartment or a list of compartments that make up the\n branch.\n ncomp: Number of segments to divide the branch into. If `compartments` is an\n a single compartment, than the compartment is repeated `ncomp` times to\n create the branch.\n \"\"\"\n # Warnings and errors that deal with the change from `nseg` to `ncomp` change\n # in Jaxley v0.5.0.\n if ncomp is not None and nseg is not None:\n raise ValueError(\"You passed `ncomp` and `nseg`. Please pass only `ncomp`.\")\n if ncomp is None and nseg is not None:\n ncomp = nseg\n\n super().__init__()\n assert (\n isinstance(compartments, (Compartment, List)) or compartments is None\n ), \"Only Compartment or List[Compartment] is allowed.\"\n if isinstance(compartments, Compartment):\n assert (\n ncomp is not None\n ), \"If `compartments` is not a list then you have to set `ncomp`.\"\n compartments = Compartment() if compartments is None else compartments\n ncomp = 1 if ncomp is None else ncomp\n\n if isinstance(compartments, Compartment):\n compartment_list = [compartments] * ncomp\n else:\n compartment_list = compartments\n\n self.ncomp = len(compartment_list)\n self.ncomp_per_branch = np.asarray([self.ncomp])\n self.total_nbranches = 1\n self.nbranches_per_cell = [1]\n self._cumsum_nbranches = jnp.asarray([0, 1])\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n # Indexing.\n self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n self._append_params_and_states(self.branch_params, self.branch_states)\n self.nodes[\"global_comp_index\"] = np.arange(self.ncomp).tolist()\n self.nodes[\"global_branch_index\"] = [0] * self.ncomp\n self.nodes[\"global_cell_index\"] = [0] * self.ncomp\n self._update_local_indices()\n self._init_view()\n\n # Channels.\n self._gather_channels_from_constituents(compartment_list)\n\n self.branch_edges = pd.DataFrame(\n dict(parent_branch_index=[], child_branch_index=[])\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n self._internal_node_inds = jnp.arange(self.ncomp)\n\n self._initialize()\n\n # Coordinates.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n
"},{"location":"reference/modules/#cell","title":"Cell","text":" Bases: Module
Cell class.
This class defines a single cell that can be simulated by itself or connected with synapses to build a network. A cell is made up of several branches and supports intricate cell morphologies.
Source code injaxley/modules/cell.py
class Cell(Module):\n \"\"\"Cell class.\n\n This class defines a single cell that can be simulated by itself or\n connected with synapses to build a network. A cell is made up of several branches\n and supports intricate cell morphologies.\n \"\"\"\n\n cell_params: Dict = {}\n cell_states: Dict = {}\n\n def __init__(\n self,\n branches: Optional[Union[Branch, List[Branch]]] = None,\n parents: Optional[List[int]] = None,\n xyzr: Optional[List[np.ndarray]] = None,\n ):\n \"\"\"Initialize a cell.\n\n Args:\n branches: A single branch or a list of branches that make up the cell.\n If a single branch is provided, then the branch is repeated `len(parents)`\n times to create the cell.\n parents: The parent branch index for each branch. The first branch has no\n parent and is therefore set to -1.\n xyzr: For every branch, the x, y, and z coordinates and the radius at the\n traced coordinates. Note that this is the full tracing (from SWC), not\n the stick representation coordinates.\n \"\"\"\n super().__init__()\n assert (\n isinstance(branches, (Branch, List)) or branches is None\n ), \"Only Branch or List[Branch] is allowed.\"\n if branches is not None:\n assert (\n parents is not None\n ), \"If `branches` is not a list then you have to set `parents`.\"\n if isinstance(branches, List):\n assert len(parents) == len(\n branches\n ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n branches = Branch() if branches is None else branches\n parents = [-1] if parents is None else parents\n\n if isinstance(branches, Branch):\n branch_list = [branches for _ in range(len(parents))]\n else:\n branch_list = branches\n\n if xyzr is not None:\n assert len(xyzr) == len(parents)\n self.xyzr = xyzr\n else:\n # For every branch (`len(parents)`), we have a start and end point (`2`) and\n # a (x,y,z,r) coordinate for each of them (`4`).\n # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n # (potentially learned) length of every compartment, we only populate\n # self.xyzr at `.vis()`.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n self.total_nbranches = len(branch_list)\n self.nbranches_per_cell = [len(branch_list)]\n self.comb_parents = jnp.asarray(parents)\n self.comb_children = compute_children_indices(self.comb_parents)\n self._cumsum_nbranches = np.asarray([0, len(branch_list)])\n\n # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`\n # is run.\n self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])\n self.ncomp = int(np.max(self.ncomp_per_branch))\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n\n # Build nodes. Has to be changed when `.set_ncomp()` is run.\n self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n self.nodes[\"global_branch_index\"] = np.repeat(\n np.arange(self.total_nbranches), self.ncomp_per_branch\n ).tolist()\n self.nodes[\"global_cell_index\"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()\n self._update_local_indices()\n self._init_view()\n\n # Appending general parameters (radius, length, r_a, cm) and channel parameters,\n # as well as the states (v, and channel states).\n self._append_params_and_states(self.cell_params, self.cell_states)\n\n # Channels.\n self._gather_channels_from_constituents(branch_list)\n\n self.branch_edges = pd.DataFrame(\n dict(\n parent_branch_index=self.comb_parents[1:],\n child_branch_index=np.arange(1, self.total_nbranches),\n )\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n\n self._initialize()\n\n def _init_morph_jaxley_spsolve(self):\n \"\"\"Initialize morphology for the custom sparse solver.\n\n Running this function is only required for custom Jaxley solvers, i.e., for\n `voltage_solver={'jaxley.stone', 'jaxley.thomas'}`. However, because at\n `.__init__()` (when the function is run), we do not yet know which solver the\n user will use. Therefore, we always run this function at `.__init__()`.\n \"\"\"\n children_and_parents = compute_morphology_indices_in_levels(\n len(self._par_inds),\n self._child_belongs_to_branchpoint,\n self._par_inds,\n self._child_inds,\n )\n branchpoint_group_inds = build_branchpoint_group_inds(\n len(self._par_inds),\n self._child_belongs_to_branchpoint,\n self.cumsum_ncomp[-1],\n )\n parents = self.comb_parents\n children_inds = children_and_parents[\"children\"]\n parents_inds = children_and_parents[\"parents\"]\n\n levels = compute_levels(parents)\n children_in_level = compute_children_in_level(levels, children_inds)\n parents_in_level = compute_parents_in_level(\n levels, self._par_inds, parents_inds\n )\n levels_and_ncomp = pd.DataFrame().from_dict(\n {\n \"levels\": levels,\n \"ncomps\": self.ncomp_per_branch,\n }\n )\n levels_and_ncomp[\"max_ncomp_in_level\"] = levels_and_ncomp.groupby(\"levels\")[\n \"ncomps\"\n ].transform(\"max\")\n padded_cumsum_ncomp = cumsum_leading_zero(\n levels_and_ncomp[\"max_ncomp_in_level\"].to_numpy()\n )\n\n # Generate mapping to deal with the masking which allows using the custom\n # sparse solver to deal with different ncomp per branch.\n remapped_node_indices = remap_index_to_masked(\n self._internal_node_inds,\n self.nodes,\n padded_cumsum_ncomp,\n self.ncomp_per_branch,\n )\n self._solve_indexer = JaxleySolveIndexer(\n cumsum_ncomp=padded_cumsum_ncomp,\n branchpoint_group_inds=branchpoint_group_inds,\n children_in_level=children_in_level,\n parents_in_level=parents_in_level,\n root_inds=np.asarray([0]),\n remapped_node_indices=remapped_node_indices,\n )\n\n def _init_morph_jax_spsolve(self):\n \"\"\"For morphology indexing with the `jax.sparse` voltage volver.\n\n Explanation of `self._comp_eges['type']`:\n `type == 0`: compartment <--> compartment (within branch)\n `type == 1`: branchpoint --> parent-compartment\n `type == 2`: branchpoint --> child-compartment\n `type == 3`: parent-compartment --> branchpoint\n `type == 4`: child-compartment --> branchpoint\n\n Running this function is only required for generic sparse solvers, i.e., for\n `voltage_solver='jax.sparse'`.\n \"\"\"\n\n # Edges between compartments within the branches.\n self._comp_edges = pd.concat(\n [\n pd.DataFrame()\n .from_dict(\n {\n \"source\": list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp))\n + list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp)),\n \"sink\": list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp))\n + list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp)),\n }\n )\n .astype(int)\n for ncomp, cumsum_ncomp in zip(self.ncomp_per_branch, self.cumsum_ncomp)\n ]\n )\n self._comp_edges[\"type\"] = 0\n\n # Edges from branchpoints to compartments.\n branchpoint_to_parent_edges = pd.DataFrame().from_dict(\n {\n \"source\": np.arange(len(self._par_inds)) + self.cumsum_ncomp[-1],\n \"sink\": self.cumsum_ncomp[self._par_inds + 1] - 1,\n \"type\": 1,\n }\n )\n branchpoint_to_child_edges = pd.DataFrame().from_dict(\n {\n \"source\": self._child_belongs_to_branchpoint + self.cumsum_ncomp[-1],\n \"sink\": self.cumsum_ncomp[self._child_inds],\n \"type\": 2,\n }\n )\n self._comp_edges = pd.concat(\n [\n self._comp_edges,\n branchpoint_to_parent_edges,\n branchpoint_to_child_edges,\n ],\n ignore_index=True,\n )\n\n # Edges from compartments to branchpoints.\n parent_to_branchpoint_edges = branchpoint_to_parent_edges.rename(\n columns={\"sink\": \"source\", \"source\": \"sink\"}\n )\n parent_to_branchpoint_edges[\"type\"] = 3\n child_to_branchpoint_edges = branchpoint_to_child_edges.rename(\n columns={\"sink\": \"source\", \"source\": \"sink\"}\n )\n child_to_branchpoint_edges[\"type\"] = 4\n\n self._comp_edges = pd.concat(\n [\n self._comp_edges,\n parent_to_branchpoint_edges,\n child_to_branchpoint_edges,\n ],\n ignore_index=True,\n )\n\n n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n self._n_nodes = n_nodes\n self._data_inds = data_inds\n self._indices_jax_spsolve = indices\n self._indptr_jax_spsolve = indptr\n
"},{"location":"reference/modules/#jaxley.modules.cell.Cell.__init__","title":"__init__(branches=None, parents=None, xyzr=None)
","text":"Initialize a cell.
Parameters:
Name Type Description Defaultbranches
Optional[Union[Branch, List[Branch]]]
A single branch or a list of branches that make up the cell. If a single branch is provided, then the branch is repeated len(parents)
times to create the cell.
None
parents
Optional[List[int]]
The parent branch index for each branch. The first branch has no parent and is therefore set to -1.
None
xyzr
Optional[List[ndarray]]
For every branch, the x, y, and z coordinates and the radius at the traced coordinates. Note that this is the full tracing (from SWC), not the stick representation coordinates.
None
Source code in jaxley/modules/cell.py
def __init__(\n self,\n branches: Optional[Union[Branch, List[Branch]]] = None,\n parents: Optional[List[int]] = None,\n xyzr: Optional[List[np.ndarray]] = None,\n):\n \"\"\"Initialize a cell.\n\n Args:\n branches: A single branch or a list of branches that make up the cell.\n If a single branch is provided, then the branch is repeated `len(parents)`\n times to create the cell.\n parents: The parent branch index for each branch. The first branch has no\n parent and is therefore set to -1.\n xyzr: For every branch, the x, y, and z coordinates and the radius at the\n traced coordinates. Note that this is the full tracing (from SWC), not\n the stick representation coordinates.\n \"\"\"\n super().__init__()\n assert (\n isinstance(branches, (Branch, List)) or branches is None\n ), \"Only Branch or List[Branch] is allowed.\"\n if branches is not None:\n assert (\n parents is not None\n ), \"If `branches` is not a list then you have to set `parents`.\"\n if isinstance(branches, List):\n assert len(parents) == len(\n branches\n ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n branches = Branch() if branches is None else branches\n parents = [-1] if parents is None else parents\n\n if isinstance(branches, Branch):\n branch_list = [branches for _ in range(len(parents))]\n else:\n branch_list = branches\n\n if xyzr is not None:\n assert len(xyzr) == len(parents)\n self.xyzr = xyzr\n else:\n # For every branch (`len(parents)`), we have a start and end point (`2`) and\n # a (x,y,z,r) coordinate for each of them (`4`).\n # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n # (potentially learned) length of every compartment, we only populate\n # self.xyzr at `.vis()`.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n self.total_nbranches = len(branch_list)\n self.nbranches_per_cell = [len(branch_list)]\n self.comb_parents = jnp.asarray(parents)\n self.comb_children = compute_children_indices(self.comb_parents)\n self._cumsum_nbranches = np.asarray([0, len(branch_list)])\n\n # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`\n # is run.\n self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])\n self.ncomp = int(np.max(self.ncomp_per_branch))\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n\n # Build nodes. Has to be changed when `.set_ncomp()` is run.\n self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n self.nodes[\"global_branch_index\"] = np.repeat(\n np.arange(self.total_nbranches), self.ncomp_per_branch\n ).tolist()\n self.nodes[\"global_cell_index\"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()\n self._update_local_indices()\n self._init_view()\n\n # Appending general parameters (radius, length, r_a, cm) and channel parameters,\n # as well as the states (v, and channel states).\n self._append_params_and_states(self.cell_params, self.cell_states)\n\n # Channels.\n self._gather_channels_from_constituents(branch_list)\n\n self.branch_edges = pd.DataFrame(\n dict(\n parent_branch_index=self.comb_parents[1:],\n child_branch_index=np.arange(1, self.total_nbranches),\n )\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n\n self._initialize()\n
"},{"location":"reference/modules/#network","title":"Network","text":" Bases: Module
Network class.
This class defines a network of cells that can be connected with synapses.
Source code injaxley/modules/network.py
class Network(Module):\n \"\"\"Network class.\n\n This class defines a network of cells that can be connected with synapses.\n \"\"\"\n\n network_params: Dict = {}\n network_states: Dict = {}\n\n def __init__(\n self,\n cells: List[Cell],\n ):\n \"\"\"Initialize network of cells and synapses.\n\n Args:\n cells: A list of cells that make up the network.\n \"\"\"\n super().__init__()\n for cell in cells:\n self.xyzr += deepcopy(cell.xyzr)\n\n self._cells_list = cells\n self.ncomp_per_branch = np.concatenate(\n [cell.ncomp_per_branch for cell in cells]\n )\n self.ncomp = int(np.max(self.ncomp_per_branch))\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n self._append_params_and_states(self.network_params, self.network_states)\n\n self.nbranches_per_cell = [cell.total_nbranches for cell in cells]\n self.total_nbranches = sum(self.nbranches_per_cell)\n self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)\n\n self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n self.nodes[\"global_branch_index\"] = np.repeat(\n np.arange(self.total_nbranches), self.ncomp_per_branch\n ).tolist()\n self.nodes[\"global_cell_index\"] = list(\n itertools.chain(\n *[[i] * int(cell.cumsum_ncomp[-1]) for i, cell in enumerate(cells)]\n )\n )\n self._update_local_indices()\n self._init_view()\n\n parents = [cell.comb_parents for cell in cells]\n self.comb_parents = jnp.concatenate(\n [p.at[1:].add(self._cumsum_nbranches[i]) for i, p in enumerate(parents)]\n )\n\n # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n # branch, apart from those branches which do not have a parent (i.e.\n # -1 in parents). For every branch, tracks the global index of that branch\n # (`child_branch_index`) and the global index of its parent\n # (`parent_branch_index`).\n self.branch_edges = pd.DataFrame(\n dict(\n parent_branch_index=self.comb_parents[self.comb_parents != -1],\n child_branch_index=np.where(self.comb_parents != -1)[0],\n )\n )\n\n # For morphology indexing of both `jax.sparse` and the custom `jaxley` solvers.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n\n # `nbranchpoints` in each cell == cell._par_inds (because `par_inds` are unique).\n nbranchpoints = jnp.asarray([len(cell._par_inds) for cell in cells])\n self._cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints)\n\n # Channels.\n self._gather_channels_from_constituents(cells)\n\n self._initialize()\n del self._cells_list\n\n def __repr__(self):\n return f\"{type(self).__name__} with {len(self.channels)} different channels and {len(self.synapses)} synapses. Use `.nodes` or `.edges` for details.\"\n\n def _init_morph_jaxley_spsolve(self):\n branchpoint_group_inds = build_branchpoint_group_inds(\n len(self._par_inds),\n self._child_belongs_to_branchpoint,\n self.cumsum_ncomp[-1],\n )\n children_in_level = merge_cells(\n self._cumsum_nbranches,\n self._cumsum_nbranchpoints_per_cell,\n [cell._solve_indexer.children_in_level for cell in self._cells_list],\n exclude_first=False,\n )\n parents_in_level = merge_cells(\n self._cumsum_nbranches,\n self._cumsum_nbranchpoints_per_cell,\n [cell._solve_indexer.parents_in_level for cell in self._cells_list],\n exclude_first=False,\n )\n padded_cumsum_ncomp = cumsum_leading_zero(\n np.concatenate(\n [np.diff(cell._solve_indexer.cumsum_ncomp) for cell in self._cells_list]\n )\n )\n\n # Generate mapping to dealing with the masking which allows using the custom\n # sparse solver to deal with different ncomp per branch.\n remapped_node_indices = remap_index_to_masked(\n self._internal_node_inds,\n self.nodes,\n padded_cumsum_ncomp,\n self.ncomp_per_branch,\n )\n self._solve_indexer = JaxleySolveIndexer(\n cumsum_ncomp=padded_cumsum_ncomp,\n branchpoint_group_inds=branchpoint_group_inds,\n children_in_level=children_in_level,\n parents_in_level=parents_in_level,\n root_inds=self._cumsum_nbranches[:-1],\n remapped_node_indices=remapped_node_indices,\n )\n\n def _init_morph_jax_spsolve(self):\n \"\"\"Initialize the morphology for networks.\n\n The reason that this function is a bit involved for a `Network` is that Jaxley\n considers branchpoint nodes to be at the very end of __all__ nodes (i.e. the\n branchpoints of the first cell are even after the compartments of the second\n cell. The reason for this is that, otherwise, `cumsum_ncomp` becomes tricky).\n\n To achieve this, we first loop over all compartments and append them, and then\n loop over all branchpoints and append those. The code for building the indices\n from the `comp_edges` is identical to `jx.Cell`.\n\n Explanation of `self._comp_eges['type']`:\n `type == 0`: compartment <--> compartment (within branch)\n `type == 1`: branchpoint --> parent-compartment\n `type == 2`: branchpoint --> child-compartment\n `type == 3`: parent-compartment --> branchpoint\n `type == 4`: child-compartment --> branchpoint\n \"\"\"\n self._cumsum_ncomp_per_cell = cumsum_leading_zero(\n jnp.asarray([cell.cumsum_ncomp[-1] for cell in self.cells])\n )\n self._comp_edges = pd.DataFrame()\n\n # Add all the internal nodes.\n for offset, cell in zip(self._cumsum_ncomp_per_cell, self._cells_list):\n condition = cell._comp_edges[\"type\"].to_numpy() == 0\n rows = cell._comp_edges[condition]\n self._comp_edges = pd.concat(\n [self._comp_edges, [offset, offset, 0] + rows], ignore_index=True\n )\n\n # All branchpoint-to-compartment nodes.\n start_branchpoints = self.cumsum_ncomp[-1] # Index of the first branchpoint.\n for offset, offset_branchpoints, cell in zip(\n self._cumsum_ncomp_per_cell,\n self._cumsum_nbranchpoints_per_cell,\n self._cells_list,\n ):\n offset_within_cell = cell.cumsum_ncomp[-1]\n condition = cell._comp_edges[\"type\"].isin([1, 2])\n rows = cell._comp_edges[condition]\n self._comp_edges = pd.concat(\n [\n self._comp_edges,\n [\n start_branchpoints - offset_within_cell + offset_branchpoints,\n offset,\n 0,\n ]\n + rows,\n ],\n ignore_index=True,\n )\n\n # All compartment-to-branchpoint nodes.\n for offset, offset_branchpoints, cell in zip(\n self._cumsum_ncomp_per_cell,\n self._cumsum_nbranchpoints_per_cell,\n self._cells_list,\n ):\n offset_within_cell = cell.cumsum_ncomp[-1]\n condition = cell._comp_edges[\"type\"].isin([3, 4])\n rows = cell._comp_edges[condition]\n self._comp_edges = pd.concat(\n [\n self._comp_edges,\n [\n offset,\n start_branchpoints - offset_within_cell + offset_branchpoints,\n 0,\n ]\n + rows,\n ],\n ignore_index=True,\n )\n\n # Convert comp_edges to the index format required for `jax.sparse` solvers.\n n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n self._n_nodes = n_nodes\n self._data_inds = data_inds\n self._indices_jax_spsolve = indices\n self._indptr_jax_spsolve = indptr\n\n def _step_synapse(\n self,\n states: Dict,\n syn_channels: List,\n params: Dict,\n delta_t: float,\n edges: pd.DataFrame,\n ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"Perform one step of the synapses and obtain their currents.\"\"\"\n states = self._step_synapse_state(states, syn_channels, params, delta_t, edges)\n states, current_terms = self._synapse_currents(\n states, syn_channels, params, delta_t, edges\n )\n return states, current_terms\n\n def _step_synapse_state(\n self,\n states: Dict,\n syn_channels: List,\n params: Dict,\n delta_t: float,\n edges: pd.DataFrame,\n ) -> Dict:\n voltages = states[\"v\"]\n\n grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n pre_syn_inds = grouped_syns[\"pre_global_comp_index\"].apply(list)\n post_syn_inds = grouped_syns[\"post_global_comp_index\"].apply(list)\n synapse_names = list(grouped_syns.indices.keys())\n\n for i, synapse_type in enumerate(syn_channels):\n assert (\n synapse_names[i] == synapse_type._name\n ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n synapse_param_names = list(synapse_type.synapse_params.keys())\n synapse_state_names = list(synapse_type.synapse_states.keys())\n\n synapse_params = {}\n for p in synapse_param_names:\n synapse_params[p] = params[p]\n synapse_states = {}\n for s in synapse_state_names:\n synapse_states[s] = states[s]\n\n pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n # State updates.\n states_updated = synapse_type.update_states(\n synapse_states,\n delta_t,\n voltages[pre_inds],\n voltages[post_inds],\n synapse_params,\n )\n\n # Rebuild state.\n for key, val in states_updated.items():\n states[key] = val\n\n return states\n\n def _synapse_currents(\n self,\n states: Dict,\n syn_channels: List,\n params: Dict,\n delta_t: float,\n edges: pd.DataFrame,\n ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n voltages = states[\"v\"]\n\n grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n pre_syn_inds = grouped_syns[\"pre_global_comp_index\"].apply(list)\n post_syn_inds = grouped_syns[\"post_global_comp_index\"].apply(list)\n synapse_names = list(grouped_syns.indices.keys())\n\n syn_voltage_terms = jnp.zeros_like(voltages)\n syn_constant_terms = jnp.zeros_like(voltages)\n # Run with two different voltages that are `diff` apart to infer the slope and\n # offset.\n diff = 1e-3\n for i, synapse_type in enumerate(syn_channels):\n assert (\n synapse_names[i] == synapse_type._name\n ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n synapse_param_names = list(synapse_type.synapse_params.keys())\n synapse_state_names = list(synapse_type.synapse_states.keys())\n\n synapse_params = {}\n for p in synapse_param_names:\n synapse_params[p] = params[p]\n synapse_states = {}\n for s in synapse_state_names:\n synapse_states[s] = states[s]\n\n # Get pre and post indexes of the current synapse type.\n pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n # Compute slope and offset of the current through every synapse.\n pre_v_and_perturbed = jnp.stack(\n [voltages[pre_inds], voltages[pre_inds] + diff]\n )\n post_v_and_perturbed = jnp.stack(\n [voltages[post_inds], voltages[post_inds] + diff]\n )\n synapse_currents = vmap(\n synapse_type.compute_current, in_axes=(None, 0, 0, None)\n )(\n synapse_states,\n pre_v_and_perturbed,\n post_v_and_perturbed,\n synapse_params,\n )\n synapse_currents_dist = convert_point_process_to_distributed(\n synapse_currents,\n params[\"radius\"][post_inds],\n params[\"length\"][post_inds],\n )\n\n # Split into voltage and constant terms.\n voltage_term = (synapse_currents_dist[1] - synapse_currents_dist[0]) / diff\n constant_term = (\n synapse_currents_dist[0] - voltage_term * voltages[post_inds]\n )\n\n # Gather slope and offset for every postsynaptic compartment.\n gathered_syn_currents = gather_synapes(\n len(voltages),\n post_inds,\n voltage_term,\n constant_term,\n )\n syn_voltage_terms += gathered_syn_currents[0]\n syn_constant_terms -= gathered_syn_currents[1]\n\n # Add the synaptic currents through every compartment as state.\n # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are\n # compartments in the network.\n # `[0]` because we only use the non-perturbed voltage.\n states[f\"{synapse_type._name}_current\"] = synapse_currents[0]\n\n return states, (syn_voltage_terms, syn_constant_terms)\n\n def vis(\n self,\n detail: str = \"full\",\n ax: Optional[Axes] = None,\n col: str = \"k\",\n synapse_col: str = \"b\",\n dims: Tuple[int] = (0, 1),\n type: str = \"line\",\n layers: Optional[List] = None,\n morph_plot_kwargs: Dict = {},\n synapse_plot_kwargs: Dict = {},\n synapse_scatter_kwargs: Dict = {},\n networkx_options: Dict = {},\n layer_kwargs: Dict = {},\n ) -> Axes:\n \"\"\"Visualize the module.\n\n Args:\n detail: Either of [point, full]. `point` visualizes every neuron in the\n network as a dot (and it uses `networkx` to obtain cell positions).\n `full` plots the full morphology of every neuron. It requires that\n `compute_xyz()` has been run and allows for indivual neurons to be\n moved with `.move()`.\n col: The color in which cells are plotted. Only takes effect if\n `detail='full'`.\n type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n synapse_col: The color in which synapses are plotted. Only takes effect if\n `detail='full'`.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two of them.\n layers: Allows to plot the network in layers. Should provide the number of\n neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input\n neurons, 10 hidden layer neurons, and 1 output neuron.\n morph_plot_kwargs: Keyword arguments passed to the plotting function for\n cell morphologies. Only takes effect for `detail='full'`.\n synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n syanpses. Only takes effect for `detail='full'`.\n synapse_scatter_kwargs: Keyword arguments passed to the scatter function\n for the end point of synapses. Only takes effect for `detail='full'`.\n networkx_options: Options passed to `networkx.draw()`. Only takes effect if\n `detail='point'`.\n layer_kwargs: Only used if `layers` is specified and if `detail='full'`.\n Can have the following entries: `within_layer_offset` (float),\n `between_layer_offset` (float), `vertical_layers` (bool).\n \"\"\"\n if detail == \"point\":\n graph = self._build_graph(layers)\n\n if layers is not None:\n pos = nx.multipartite_layout(graph, subset_key=\"layer\")\n nx.draw(graph, pos, with_labels=True, **networkx_options)\n else:\n nx.draw(graph, with_labels=True, **networkx_options)\n elif detail == \"full\":\n if layers is not None:\n # Assemble cells in the network into layers.\n global_counter = 0\n layers_config = {\n \"within_layer_offset\": 500.0,\n \"between_layer_offset\": 1500.0,\n \"vertical_layers\": False,\n }\n layers_config.update(layer_kwargs)\n for layer_ind, num_in_layer in enumerate(layers):\n for ind_within_layer in range(num_in_layer):\n if layers_config[\"vertical_layers\"]:\n x_offset = (\n ind_within_layer - (num_in_layer - 1) / 2\n ) * layers_config[\"within_layer_offset\"]\n y_offset = (len(layers) - 1 - layer_ind) * layers_config[\n \"between_layer_offset\"\n ]\n else:\n x_offset = layer_ind * layers_config[\"between_layer_offset\"]\n y_offset = (\n ind_within_layer - (num_in_layer - 1) / 2\n ) * layers_config[\"within_layer_offset\"]\n\n self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)\n global_counter += 1\n ax = super().vis(\n dims=dims,\n col=col,\n ax=ax,\n type=type,\n morph_plot_kwargs=morph_plot_kwargs,\n )\n\n pre_locs = self.edges[\"pre_locs\"].to_numpy()\n post_locs = self.edges[\"post_locs\"].to_numpy()\n pre_comp = self.edges[\"pre_global_comp_index\"].to_numpy()\n nodes = self.nodes.set_index(\"global_comp_index\")\n pre_branch = nodes.loc[pre_comp, \"global_branch_index\"].to_numpy()\n post_comp = self.edges[\"post_global_comp_index\"].to_numpy()\n post_branch = nodes.loc[post_comp, \"global_branch_index\"].to_numpy()\n\n dims_np = np.asarray(dims)\n\n for pre_loc, post_loc, pre_b, post_b in zip(\n pre_locs, post_locs, pre_branch, post_branch\n ):\n pre_coord = self.xyzr[pre_b]\n if len(pre_coord) == 2:\n # If only start and end point of a branch are traced, perform a\n # linear interpolation to get the synpase location.\n pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc\n else:\n # If densely traced, use intermediate trace values for synapse loc.\n middle_ind = int((len(pre_coord) - 1) * pre_loc)\n pre_coord = pre_coord[middle_ind]\n\n post_coord = self.xyzr[post_b]\n if len(post_coord) == 2:\n # If only start and end point of a branch are traced, perform a\n # linear interpolation to get the synpase location.\n post_coord = (\n post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc\n )\n else:\n # If densely traced, use intermediate trace values for synapse loc.\n middle_ind = int((len(post_coord) - 1) * post_loc)\n post_coord = post_coord[middle_ind]\n\n coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T\n ax.plot(\n coords[0],\n coords[1],\n c=synapse_col,\n **synapse_plot_kwargs,\n )\n ax.scatter(\n post_coord[dims_np[0]],\n post_coord[dims_np[1]],\n c=synapse_col,\n **synapse_scatter_kwargs,\n )\n else:\n raise ValueError(\"detail must be in {full, point}.\")\n\n return ax\n\n def _build_graph(self, layers: Optional[List] = None, **options):\n graph = nx.DiGraph()\n\n def build_extents(*subset_sizes):\n return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes))\n\n if layers is not None:\n extents = build_extents(*layers)\n layers = [range(start, end) for start, end in extents]\n for i, layer in enumerate(layers):\n graph.add_nodes_from(layer, layer=i)\n else:\n graph.add_nodes_from(range(len(self._cells_in_view)))\n\n pre_comp = self.edges[\"pre_global_comp_index\"].to_numpy()\n nodes = self.nodes.set_index(\"global_comp_index\")\n pre_cell = nodes.loc[pre_comp, \"global_cell_index\"].to_numpy()\n post_comp = self.edges[\"post_global_comp_index\"].to_numpy()\n post_cell = nodes.loc[post_comp, \"global_cell_index\"].to_numpy()\n\n inds = np.stack([pre_cell, post_cell]).T\n graph.add_edges_from(inds)\n\n return graph\n\n def _infer_synapse_type_ind(self, synapse_name):\n syn_names = self.base.synapse_names\n is_new_type = False if synapse_name in syn_names else True\n type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name)\n return type_ind, is_new_type\n\n def _update_synapse_state_names(self, synapse_type):\n # (Potentially) update variables that track meta information about synapses.\n self.base.synapse_names.append(synapse_type._name)\n self.base.synapse_param_names += list(synapse_type.synapse_params.keys())\n self.base.synapse_state_names += list(synapse_type.synapse_states.keys())\n self.base.synapses.append(synapse_type)\n\n def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):\n # Add synapse types to the module and infer their unique identifier.\n synapse_name = synapse_type._name\n type_ind, is_new = self._infer_synapse_type_ind(synapse_name)\n if is_new: # synapse is not known\n self._update_synapse_state_names(synapse_type)\n\n index = len(self.base.edges)\n indices = [idx for idx in range(index, index + len(pre_nodes))]\n global_edge_index = pd.DataFrame({\"global_edge_index\": indices})\n post_loc = loc_of_index(\n post_nodes[\"global_comp_index\"].to_numpy(),\n post_nodes[\"global_branch_index\"].to_numpy(),\n self.ncomp_per_branch,\n )\n pre_loc = loc_of_index(\n pre_nodes[\"global_comp_index\"].to_numpy(),\n pre_nodes[\"global_branch_index\"].to_numpy(),\n self.ncomp_per_branch,\n )\n\n # Define new synapses. Each row is one synapse.\n pre_nodes = pre_nodes[[\"global_comp_index\"]]\n pre_nodes.columns = [\"pre_global_comp_index\"]\n post_nodes = post_nodes[[\"global_comp_index\"]]\n post_nodes.columns = [\"post_global_comp_index\"]\n new_rows = pd.concat(\n [\n global_edge_index,\n pre_nodes.reset_index(drop=True),\n post_nodes.reset_index(drop=True),\n ],\n axis=1,\n )\n new_rows[\"type\"] = synapse_name\n new_rows[\"type_ind\"] = type_ind\n new_rows[\"pre_locs\"] = pre_loc\n new_rows[\"post_locs\"] = post_loc\n self.base.edges = concat_and_ignore_empty(\n [self.base.edges, new_rows], ignore_index=True, axis=0\n )\n self._add_params_to_edges(synapse_type, indices)\n self.base.edges[\"controlled_by_param\"] = 0\n self._edges_in_view = self.edges.index.to_numpy()\n\n def _add_params_to_edges(self, synapse_type, indices):\n # Add parameters and states to the `.edges` table.\n for key, param_val in synapse_type.synapse_params.items():\n self.base.edges.loc[indices, key] = param_val\n\n # Update synaptic state array.\n for key, state_val in synapse_type.synapse_states.items():\n self.base.edges.loc[indices, key] = state_val\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.__init__","title":"__init__(cells)
","text":"Initialize network of cells and synapses.
Parameters:
Name Type Description Defaultcells
List[Cell]
A list of cells that make up the network.
required Source code injaxley/modules/network.py
def __init__(\n self,\n cells: List[Cell],\n):\n \"\"\"Initialize network of cells and synapses.\n\n Args:\n cells: A list of cells that make up the network.\n \"\"\"\n super().__init__()\n for cell in cells:\n self.xyzr += deepcopy(cell.xyzr)\n\n self._cells_list = cells\n self.ncomp_per_branch = np.concatenate(\n [cell.ncomp_per_branch for cell in cells]\n )\n self.ncomp = int(np.max(self.ncomp_per_branch))\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n self._append_params_and_states(self.network_params, self.network_states)\n\n self.nbranches_per_cell = [cell.total_nbranches for cell in cells]\n self.total_nbranches = sum(self.nbranches_per_cell)\n self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)\n\n self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n self.nodes[\"global_branch_index\"] = np.repeat(\n np.arange(self.total_nbranches), self.ncomp_per_branch\n ).tolist()\n self.nodes[\"global_cell_index\"] = list(\n itertools.chain(\n *[[i] * int(cell.cumsum_ncomp[-1]) for i, cell in enumerate(cells)]\n )\n )\n self._update_local_indices()\n self._init_view()\n\n parents = [cell.comb_parents for cell in cells]\n self.comb_parents = jnp.concatenate(\n [p.at[1:].add(self._cumsum_nbranches[i]) for i, p in enumerate(parents)]\n )\n\n # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n # branch, apart from those branches which do not have a parent (i.e.\n # -1 in parents). For every branch, tracks the global index of that branch\n # (`child_branch_index`) and the global index of its parent\n # (`parent_branch_index`).\n self.branch_edges = pd.DataFrame(\n dict(\n parent_branch_index=self.comb_parents[self.comb_parents != -1],\n child_branch_index=np.where(self.comb_parents != -1)[0],\n )\n )\n\n # For morphology indexing of both `jax.sparse` and the custom `jaxley` solvers.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n\n # `nbranchpoints` in each cell == cell._par_inds (because `par_inds` are unique).\n nbranchpoints = jnp.asarray([len(cell._par_inds) for cell in cells])\n self._cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints)\n\n # Channels.\n self._gather_channels_from_constituents(cells)\n\n self._initialize()\n del self._cells_list\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.vis","title":"vis(detail='full', ax=None, col='k', synapse_col='b', dims=(0, 1), type='line', layers=None, morph_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, networkx_options={}, layer_kwargs={})
","text":"Visualize the module.
Parameters:
Name Type Description Defaultdetail
str
Either of [point, full]. point
visualizes every neuron in the network as a dot (and it uses networkx
to obtain cell positions). full
plots the full morphology of every neuron. It requires that compute_xyz()
has been run and allows for indivual neurons to be moved with .move()
.
'full'
col
str
The color in which cells are plotted. Only takes effect if detail='full'
.
'k'
type
str
Either line
or scatter
. Only takes effect if detail='full'
.
'line'
synapse_col
str
The color in which synapses are plotted. Only takes effect if detail='full'
.
'b'
dims
Tuple[int]
Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.
(0, 1)
layers
Optional[List]
Allows to plot the network in layers. Should provide the number of neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input neurons, 10 hidden layer neurons, and 1 output neuron.
None
morph_plot_kwargs
Dict
Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail='full'
.
{}
synapse_plot_kwargs
Dict
Keyword arguments passed to the plotting function for syanpses. Only takes effect for detail='full'
.
{}
synapse_scatter_kwargs
Dict
Keyword arguments passed to the scatter function for the end point of synapses. Only takes effect for detail='full'
.
{}
networkx_options
Dict
Options passed to networkx.draw()
. Only takes effect if detail='point'
.
{}
layer_kwargs
Dict
Only used if layers
is specified and if detail='full'
. Can have the following entries: within_layer_offset
(float), between_layer_offset
(float), vertical_layers
(bool).
{}
Source code in jaxley/modules/network.py
def vis(\n self,\n detail: str = \"full\",\n ax: Optional[Axes] = None,\n col: str = \"k\",\n synapse_col: str = \"b\",\n dims: Tuple[int] = (0, 1),\n type: str = \"line\",\n layers: Optional[List] = None,\n morph_plot_kwargs: Dict = {},\n synapse_plot_kwargs: Dict = {},\n synapse_scatter_kwargs: Dict = {},\n networkx_options: Dict = {},\n layer_kwargs: Dict = {},\n) -> Axes:\n \"\"\"Visualize the module.\n\n Args:\n detail: Either of [point, full]. `point` visualizes every neuron in the\n network as a dot (and it uses `networkx` to obtain cell positions).\n `full` plots the full morphology of every neuron. It requires that\n `compute_xyz()` has been run and allows for indivual neurons to be\n moved with `.move()`.\n col: The color in which cells are plotted. Only takes effect if\n `detail='full'`.\n type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n synapse_col: The color in which synapses are plotted. Only takes effect if\n `detail='full'`.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two of them.\n layers: Allows to plot the network in layers. Should provide the number of\n neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input\n neurons, 10 hidden layer neurons, and 1 output neuron.\n morph_plot_kwargs: Keyword arguments passed to the plotting function for\n cell morphologies. Only takes effect for `detail='full'`.\n synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n syanpses. Only takes effect for `detail='full'`.\n synapse_scatter_kwargs: Keyword arguments passed to the scatter function\n for the end point of synapses. Only takes effect for `detail='full'`.\n networkx_options: Options passed to `networkx.draw()`. Only takes effect if\n `detail='point'`.\n layer_kwargs: Only used if `layers` is specified and if `detail='full'`.\n Can have the following entries: `within_layer_offset` (float),\n `between_layer_offset` (float), `vertical_layers` (bool).\n \"\"\"\n if detail == \"point\":\n graph = self._build_graph(layers)\n\n if layers is not None:\n pos = nx.multipartite_layout(graph, subset_key=\"layer\")\n nx.draw(graph, pos, with_labels=True, **networkx_options)\n else:\n nx.draw(graph, with_labels=True, **networkx_options)\n elif detail == \"full\":\n if layers is not None:\n # Assemble cells in the network into layers.\n global_counter = 0\n layers_config = {\n \"within_layer_offset\": 500.0,\n \"between_layer_offset\": 1500.0,\n \"vertical_layers\": False,\n }\n layers_config.update(layer_kwargs)\n for layer_ind, num_in_layer in enumerate(layers):\n for ind_within_layer in range(num_in_layer):\n if layers_config[\"vertical_layers\"]:\n x_offset = (\n ind_within_layer - (num_in_layer - 1) / 2\n ) * layers_config[\"within_layer_offset\"]\n y_offset = (len(layers) - 1 - layer_ind) * layers_config[\n \"between_layer_offset\"\n ]\n else:\n x_offset = layer_ind * layers_config[\"between_layer_offset\"]\n y_offset = (\n ind_within_layer - (num_in_layer - 1) / 2\n ) * layers_config[\"within_layer_offset\"]\n\n self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)\n global_counter += 1\n ax = super().vis(\n dims=dims,\n col=col,\n ax=ax,\n type=type,\n morph_plot_kwargs=morph_plot_kwargs,\n )\n\n pre_locs = self.edges[\"pre_locs\"].to_numpy()\n post_locs = self.edges[\"post_locs\"].to_numpy()\n pre_comp = self.edges[\"pre_global_comp_index\"].to_numpy()\n nodes = self.nodes.set_index(\"global_comp_index\")\n pre_branch = nodes.loc[pre_comp, \"global_branch_index\"].to_numpy()\n post_comp = self.edges[\"post_global_comp_index\"].to_numpy()\n post_branch = nodes.loc[post_comp, \"global_branch_index\"].to_numpy()\n\n dims_np = np.asarray(dims)\n\n for pre_loc, post_loc, pre_b, post_b in zip(\n pre_locs, post_locs, pre_branch, post_branch\n ):\n pre_coord = self.xyzr[pre_b]\n if len(pre_coord) == 2:\n # If only start and end point of a branch are traced, perform a\n # linear interpolation to get the synpase location.\n pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc\n else:\n # If densely traced, use intermediate trace values for synapse loc.\n middle_ind = int((len(pre_coord) - 1) * pre_loc)\n pre_coord = pre_coord[middle_ind]\n\n post_coord = self.xyzr[post_b]\n if len(post_coord) == 2:\n # If only start and end point of a branch are traced, perform a\n # linear interpolation to get the synpase location.\n post_coord = (\n post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc\n )\n else:\n # If densely traced, use intermediate trace values for synapse loc.\n middle_ind = int((len(post_coord) - 1) * post_loc)\n post_coord = post_coord[middle_ind]\n\n coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T\n ax.plot(\n coords[0],\n coords[1],\n c=synapse_col,\n **synapse_plot_kwargs,\n )\n ax.scatter(\n post_coord[dims_np[0]],\n post_coord[dims_np[1]],\n c=synapse_col,\n **synapse_scatter_kwargs,\n )\n else:\n raise ValueError(\"detail must be in {full, point}.\")\n\n return ax\n
"},{"location":"reference/optimize/","title":"Optimization","text":""},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer","title":"TypeOptimizer
","text":"optax
wrapper which allows different argument values for different params.
jaxley/optimize/optimizer.py
class TypeOptimizer:\n \"\"\"`optax` wrapper which allows different argument values for different params.\"\"\"\n\n def __init__(\n self,\n optimizer: Callable,\n optimizer_args: Dict[str, Any],\n opt_params: List[Dict[str, jnp.ndarray]],\n ):\n \"\"\"Create the optimizers.\n\n This requires access to `opt_params` in order to know how many optimizers\n should be created. It creates `len(opt_params)` optimizers.\n\n Example usage:\n ```\n lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n opt_state = optimizer.init(opt_params)\n ```\n\n ```\n optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n optimizer = TypeOptimizer(\n lambda args: optax.sgd(args[0], momentum=args[1]),\n optimizer_args,\n opt_params\n )\n opt_state = optimizer.init(opt_params)\n ```\n\n Args:\n optimizer: A Callable that takes the learning rate and returns the\n `optax.optimizer` which should be used.\n optimizer_args: The arguments for different kinds of parameters.\n Each item of the dictionary will be passed to the `Callable` passed to\n `optimizer`.\n opt_params: The parameters to be optimized. The exact values are not used,\n only the number of elements in the list and the key of each dict.\n \"\"\"\n self.base_optimizer = optimizer\n\n self.optimizers = []\n for params in opt_params:\n names = list(params.keys())\n assert len(names) == 1, \"Multiple parameters were added at once.\"\n name = names[0]\n optimizer = self.base_optimizer(optimizer_args[name])\n self.optimizers.append({name: optimizer})\n\n def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n opt_states = []\n for params, optimizer in zip(opt_params, self.optimizers):\n name = list(optimizer.keys())[0]\n opt_state = optimizer[name].init(params)\n opt_states.append(opt_state)\n return opt_states\n\n def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n all_updates = []\n new_opt_states = []\n for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n name = list(opt.keys())[0]\n updates, new_opt_state = opt[name].update(grad, state)\n all_updates.append(updates)\n new_opt_states.append(new_opt_state)\n return all_updates, new_opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.__init__","title":"__init__(optimizer, optimizer_args, opt_params)
","text":"Create the optimizers.
This requires access to opt_params
in order to know how many optimizers should be created. It creates len(opt_params)
optimizers.
Example usage:
lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\noptimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\nopt_state = optimizer.init(opt_params)\n
optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\noptimizer = TypeOptimizer(\n lambda args: optax.sgd(args[0], momentum=args[1]),\n optimizer_args,\n opt_params\n)\nopt_state = optimizer.init(opt_params)\n
Parameters:
Name Type Description Defaultoptimizer
Callable
A Callable that takes the learning rate and returns the optax.optimizer
which should be used.
optimizer_args
Dict[str, Any]
The arguments for different kinds of parameters. Each item of the dictionary will be passed to the Callable
passed to optimizer
.
opt_params
List[Dict[str, ndarray]]
The parameters to be optimized. The exact values are not used, only the number of elements in the list and the key of each dict.
required Source code injaxley/optimize/optimizer.py
def __init__(\n self,\n optimizer: Callable,\n optimizer_args: Dict[str, Any],\n opt_params: List[Dict[str, jnp.ndarray]],\n):\n \"\"\"Create the optimizers.\n\n This requires access to `opt_params` in order to know how many optimizers\n should be created. It creates `len(opt_params)` optimizers.\n\n Example usage:\n ```\n lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n opt_state = optimizer.init(opt_params)\n ```\n\n ```\n optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n optimizer = TypeOptimizer(\n lambda args: optax.sgd(args[0], momentum=args[1]),\n optimizer_args,\n opt_params\n )\n opt_state = optimizer.init(opt_params)\n ```\n\n Args:\n optimizer: A Callable that takes the learning rate and returns the\n `optax.optimizer` which should be used.\n optimizer_args: The arguments for different kinds of parameters.\n Each item of the dictionary will be passed to the `Callable` passed to\n `optimizer`.\n opt_params: The parameters to be optimized. The exact values are not used,\n only the number of elements in the list and the key of each dict.\n \"\"\"\n self.base_optimizer = optimizer\n\n self.optimizers = []\n for params in opt_params:\n names = list(params.keys())\n assert len(names) == 1, \"Multiple parameters were added at once.\"\n name = names[0]\n optimizer = self.base_optimizer(optimizer_args[name])\n self.optimizers.append({name: optimizer})\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.init","title":"init(opt_params)
","text":"Initialize the optimizers. Equivalent to optax.optimizers.init()
.
jaxley/optimize/optimizer.py
def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n opt_states = []\n for params, optimizer in zip(opt_params, self.optimizers):\n name = list(optimizer.keys())[0]\n opt_state = optimizer[name].init(params)\n opt_states.append(opt_state)\n return opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.update","title":"update(gradient, opt_state)
","text":"Update the optimizers. Equivalent to optax.optimizers.update()
.
jaxley/optimize/optimizer.py
def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n all_updates = []\n new_opt_states = []\n for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n name = list(opt.keys())[0]\n updates, new_opt_state = opt[name].update(grad, state)\n all_updates.append(updates)\n new_opt_states.append(new_opt_state)\n return all_updates, new_opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.AffineTransform","title":"AffineTransform
","text":" Bases: Transform
jaxley/optimize/transforms.py
class AffineTransform(Transform):\n def __init__(self, scale: ArrayLike, shift: ArrayLike):\n \"\"\"This transform rescales and shifts the input.\n\n Args:\n scale (ArrayLike): Scaling factor.\n shift (ArrayLike): Additive shift.\n\n Raises:\n ValueError: Scale needs to be larger than 0\n \"\"\"\n if jnp.allclose(scale, 0):\n raise ValueError(\"a cannot be zero, must be invertible\")\n self.a = scale\n self.b = shift\n\n def forward(self, x: ArrayLike) -> Array:\n return self.a * x + self.b\n\n def inverse(self, x: ArrayLike) -> Array:\n return (x - self.b) / self.a\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.AffineTransform.__init__","title":"__init__(scale, shift)
","text":"This transform rescales and shifts the input.
Parameters:
Name Type Description Defaultscale
ArrayLike
Scaling factor.
requiredshift
ArrayLike
Additive shift.
requiredRaises:
Type DescriptionValueError
Scale needs to be larger than 0
Source code injaxley/optimize/transforms.py
def __init__(self, scale: ArrayLike, shift: ArrayLike):\n \"\"\"This transform rescales and shifts the input.\n\n Args:\n scale (ArrayLike): Scaling factor.\n shift (ArrayLike): Additive shift.\n\n Raises:\n ValueError: Scale needs to be larger than 0\n \"\"\"\n if jnp.allclose(scale, 0):\n raise ValueError(\"a cannot be zero, must be invertible\")\n self.a = scale\n self.b = shift\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ChainTransform","title":"ChainTransform
","text":" Bases: Transform
Chaining together multiple transformations
Source code injaxley/optimize/transforms.py
class ChainTransform(Transform):\n \"\"\"Chaining together multiple transformations\"\"\"\n\n def __init__(self, transforms: Sequence[Transform]) -> None:\n \"\"\"A chain of transformations\n\n Args:\n transforms (Sequence[Transform]): Transforms to apply\n \"\"\"\n super().__init__()\n self.transforms = transforms\n\n def forward(self, x: ArrayLike) -> Array:\n for transform in self.transforms:\n x = transform(x)\n return x\n\n def inverse(self, y: ArrayLike) -> Array:\n for transform in reversed(self.transforms):\n y = transform.inverse(y)\n return y\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ChainTransform.__init__","title":"__init__(transforms)
","text":"A chain of transformations
Parameters:
Name Type Description Defaulttransforms
Sequence[Transform]
Transforms to apply
required Source code injaxley/optimize/transforms.py
def __init__(self, transforms: Sequence[Transform]) -> None:\n \"\"\"A chain of transformations\n\n Args:\n transforms (Sequence[Transform]): Transforms to apply\n \"\"\"\n super().__init__()\n self.transforms = transforms\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.CustomTransform","title":"CustomTransform
","text":" Bases: Transform
Custom transformation
Source code injaxley/optimize/transforms.py
class CustomTransform(Transform):\n \"\"\"Custom transformation\"\"\"\n\n def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:\n \"\"\"A custom transformation using a user-defined froward and\n inverse function\n\n Args:\n forward_fn (Callable): Forward transformation\n inverse_fn (Callable): Inverse transformation\n \"\"\"\n super().__init__()\n self.forward_fn = forward_fn\n self.inverse_fn = inverse_fn\n\n def forward(self, x: ArrayLike) -> Array:\n return self.forward_fn(x)\n\n def inverse(self, y: ArrayLike) -> Array:\n return self.inverse_fn(y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.CustomTransform.__init__","title":"__init__(forward_fn, inverse_fn)
","text":"A custom transformation using a user-defined froward and inverse function
Parameters:
Name Type Description Defaultforward_fn
Callable
Forward transformation
requiredinverse_fn
Callable
Inverse transformation
required Source code injaxley/optimize/transforms.py
def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:\n \"\"\"A custom transformation using a user-defined froward and\n inverse function\n\n Args:\n forward_fn (Callable): Forward transformation\n inverse_fn (Callable): Inverse transformation\n \"\"\"\n super().__init__()\n self.forward_fn = forward_fn\n self.inverse_fn = inverse_fn\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.MaskedTransform","title":"MaskedTransform
","text":" Bases: Transform
jaxley/optimize/transforms.py
class MaskedTransform(Transform):\n def __init__(self, mask: ArrayLike, transform: Transform) -> None:\n \"\"\"A masked transformation\n\n Args:\n mask (ArrayLike): Which elements to transform\n transform (Transform): Transformation to apply\n \"\"\"\n super().__init__()\n self.mask = mask\n self.transform = transform\n\n def forward(self, x: ArrayLike) -> Array:\n return jnp.where(self.mask, self.transform.forward(x), x)\n\n def inverse(self, y: ArrayLike) -> Array:\n return jnp.where(self.mask, self.transform.inverse(y), y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.MaskedTransform.__init__","title":"__init__(mask, transform)
","text":"A masked transformation
Parameters:
Name Type Description Defaultmask
ArrayLike
Which elements to transform
requiredtransform
Transform
Transformation to apply
required Source code injaxley/optimize/transforms.py
def __init__(self, mask: ArrayLike, transform: Transform) -> None:\n \"\"\"A masked transformation\n\n Args:\n mask (ArrayLike): Which elements to transform\n transform (Transform): Transformation to apply\n \"\"\"\n super().__init__()\n self.mask = mask\n self.transform = transform\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.NegSoftplusTransform","title":"NegSoftplusTransform
","text":" Bases: SoftplusTransform
Negative softplus transformation.
Source code injaxley/optimize/transforms.py
class NegSoftplusTransform(SoftplusTransform):\n \"\"\"Negative softplus transformation.\"\"\"\n\n def __init__(self, upper: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval (-inf, upper].\n\n Args:\n upper (ArrayLike): Upper bound of the interval.\n \"\"\"\n super().__init__(upper)\n\n def forward(self, x: ArrayLike) -> Array:\n return -super().forward(-x)\n\n def inverse(self, y: ArrayLike) -> Array:\n return -super().inverse(-y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.NegSoftplusTransform.__init__","title":"__init__(upper)
","text":"This transform maps any value bijectively to the interval (-inf, upper].
Parameters:
Name Type Description Defaultupper
ArrayLike
Upper bound of the interval.
required Source code injaxley/optimize/transforms.py
def __init__(self, upper: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval (-inf, upper].\n\n Args:\n upper (ArrayLike): Upper bound of the interval.\n \"\"\"\n super().__init__(upper)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform","title":"ParamTransform
","text":"Parameter transformation utility.
This class is used to transform parameters usually from an unconstrained space to a constrained space and back (bacause most biophysical parameter are bounded). The user can specify a PyTree of transforms that are applied to the parameters.
Attributes:
Name Type Descriptiontf_dict
A PyTree of transforms for each parameter.
Source code injaxley/optimize/transforms.py
class ParamTransform:\n \"\"\"Parameter transformation utility.\n\n This class is used to transform parameters usually from an unconstrained space to a constrained space\n and back (bacause most biophysical parameter are bounded). The user can specify a PyTree of transforms\n that are applied to the parameters.\n\n Attributes:\n tf_dict: A PyTree of transforms for each parameter.\n\n \"\"\"\n\n def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:\n \"\"\"Creates a new ParamTransform object.\n\n Args:\n tf_dict: A PyTree of transforms for each parameter.\n \"\"\"\n\n self.tf_dict = tf_dict\n\n def forward(\n self, params: List[Dict[str, ArrayLike]] | ArrayLike\n ) -> Dict[str, Array]:\n \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n Args:\n params: A list of dictionaries (or any PyTree) with unconstrained parameters.\n\n Returns:\n A list of dictionaries (or any PyTree) with transformed parameters.\n\n \"\"\"\n\n return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)\n\n def inverse(\n self, params: List[Dict[str, ArrayLike]] | ArrayLike\n ) -> Dict[str, Array]:\n \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n Args:\n params: A list of dictionaries (or any PyTree) with transformed parameters.\n\n Returns:\n A list of dictionaries (or any PyTree) with unconstrained parameters.\n \"\"\"\n\n return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.__init__","title":"__init__(tf_dict)
","text":"Creates a new ParamTransform object.
Parameters:
Name Type Description Defaulttf_dict
List[Dict[str, Transform]] | Transform
A PyTree of transforms for each parameter.
required Source code injaxley/optimize/transforms.py
def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:\n \"\"\"Creates a new ParamTransform object.\n\n Args:\n tf_dict: A PyTree of transforms for each parameter.\n \"\"\"\n\n self.tf_dict = tf_dict\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.forward","title":"forward(params)
","text":"Pushes unconstrained parameters through a tf such that they fit the interval.
Parameters:
Name Type Description Defaultparams
List[Dict[str, ArrayLike]] | ArrayLike
A list of dictionaries (or any PyTree) with unconstrained parameters.
requiredReturns:
Type DescriptionDict[str, Array]
A list of dictionaries (or any PyTree) with transformed parameters.
Source code injaxley/optimize/transforms.py
def forward(\n self, params: List[Dict[str, ArrayLike]] | ArrayLike\n) -> Dict[str, Array]:\n \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n Args:\n params: A list of dictionaries (or any PyTree) with unconstrained parameters.\n\n Returns:\n A list of dictionaries (or any PyTree) with transformed parameters.\n\n \"\"\"\n\n return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.inverse","title":"inverse(params)
","text":"Takes parameters from within the interval and makes them unconstrained.
Parameters:
Name Type Description Defaultparams
List[Dict[str, ArrayLike]] | ArrayLike
A list of dictionaries (or any PyTree) with transformed parameters.
requiredReturns:
Type DescriptionDict[str, Array]
A list of dictionaries (or any PyTree) with unconstrained parameters.
Source code injaxley/optimize/transforms.py
def inverse(\n self, params: List[Dict[str, ArrayLike]] | ArrayLike\n) -> Dict[str, Array]:\n \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n Args:\n params: A list of dictionaries (or any PyTree) with transformed parameters.\n\n Returns:\n A list of dictionaries (or any PyTree) with unconstrained parameters.\n \"\"\"\n\n return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SigmoidTransform","title":"SigmoidTransform
","text":" Bases: Transform
Sigmoid transformation.
Source code injaxley/optimize/transforms.py
class SigmoidTransform(Transform):\n \"\"\"Sigmoid transformation.\"\"\"\n\n def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval [lower, upper].\n\n Args:\n lower (ArrayLike): Lower bound of the interval.\n upper (ArrayLike): Upper bound of the interval.\n \"\"\"\n super().__init__()\n self.lower = lower\n self.width = upper - lower\n\n def forward(self, x: ArrayLike) -> Array:\n y = 1.0 / (1.0 + save_exp(-x))\n return self.lower + self.width * y\n\n def inverse(self, y: ArrayLike) -> Array:\n x = (y - self.lower) / self.width\n x = -jnp.log((1.0 / x) - 1.0)\n return x\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SigmoidTransform.__init__","title":"__init__(lower, upper)
","text":"This transform maps any value bijectively to the interval [lower, upper].
Parameters:
Name Type Description Defaultlower
ArrayLike
Lower bound of the interval.
requiredupper
ArrayLike
Upper bound of the interval.
required Source code injaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval [lower, upper].\n\n Args:\n lower (ArrayLike): Lower bound of the interval.\n upper (ArrayLike): Upper bound of the interval.\n \"\"\"\n super().__init__()\n self.lower = lower\n self.width = upper - lower\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SoftplusTransform","title":"SoftplusTransform
","text":" Bases: Transform
Softplus transformation.
Source code injaxley/optimize/transforms.py
class SoftplusTransform(Transform):\n \"\"\"Softplus transformation.\"\"\"\n\n def __init__(self, lower: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval [lower, inf).\n\n Args:\n lower (ArrayLike): Lower bound of the interval.\n \"\"\"\n super().__init__()\n self.lower = lower\n\n def forward(self, x: ArrayLike) -> Array:\n return jnp.log1p(save_exp(x)) + self.lower\n\n def inverse(self, y: ArrayLike) -> Array:\n return jnp.log(save_exp(y - self.lower) - 1.0)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SoftplusTransform.__init__","title":"__init__(lower)
","text":"This transform maps any value bijectively to the interval [lower, inf).
Parameters:
Name Type Description Defaultlower
ArrayLike
Lower bound of the interval.
required Source code injaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval [lower, inf).\n\n Args:\n lower (ArrayLike): Lower bound of the interval.\n \"\"\"\n super().__init__()\n self.lower = lower\n
"},{"location":"reference/utils/","title":"Utils","text":""},{"location":"reference/utils/#jaxley.utils.cell_utils.build_radiuses_from_xyzr","title":"build_radiuses_from_xyzr(radius_fns, branch_indices, min_radius, ncomp)
","text":"Return the radiuses of branches given SWC file xyzr.
Returns an array of shape (num_branches, ncomp)
.
Parameters:
Name Type Description Defaultradius_fns
List[Callable]
Functions which, given compartment locations return the radius.
requiredbranch_indices
List[int]
The indices of the branches for which to return the radiuses.
requiredmin_radius
Optional[float]
If passed, the radiuses are clipped to be at least as large.
requiredncomp
int
The number of compartments that every branch is discretized into.
required Source code injaxley/utils/cell_utils.py
def build_radiuses_from_xyzr(\n radius_fns: List[Callable],\n branch_indices: List[int],\n min_radius: Optional[float],\n ncomp: int,\n) -> jnp.ndarray:\n \"\"\"Return the radiuses of branches given SWC file xyzr.\n\n Returns an array of shape `(num_branches, ncomp)`.\n\n Args:\n radius_fns: Functions which, given compartment locations return the radius.\n branch_indices: The indices of the branches for which to return the radiuses.\n min_radius: If passed, the radiuses are clipped to be at least as large.\n ncomp: The number of compartments that every branch is discretized into.\n \"\"\"\n # Compartment locations are at the center of the internal nodes.\n non_split = 1 / ncomp\n range_ = np.linspace(non_split / 2, 1 - non_split / 2, ncomp)\n\n # Build radiuses.\n radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])\n radiuses_each = radiuses.ravel(order=\"C\")\n if min_radius is None:\n assert np.all(\n radiuses_each > 0.0\n ), \"Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`.\"\n else:\n radiuses_each[radiuses_each < min_radius] = min_radius\n\n return radiuses_each\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_axial_conductances","title":"compute_axial_conductances(comp_edges, params)
","text":"Given comp_edges
, radius, length, r_a, cm, compute the axial conductances.
Note that the resulting axial conductances will already by divided by the capacitance cm
.
jaxley/utils/cell_utils.py
def compute_axial_conductances(\n comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray]\n) -> jnp.ndarray:\n \"\"\"Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances.\n\n Note that the resulting axial conductances will already by divided by the\n capacitance `cm`.\n \"\"\"\n # `Compartment-to-compartment` (c2c) axial coupling conductances.\n condition = comp_edges[\"type\"].to_numpy() == 0\n source_comp_inds = np.asarray(comp_edges[condition][\"source\"].to_list())\n sink_comp_inds = np.asarray(comp_edges[condition][\"sink\"].to_list())\n\n if len(sink_comp_inds) > 0:\n conds_c2c = (\n vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(\n params[\"radius\"][sink_comp_inds],\n params[\"radius\"][source_comp_inds],\n params[\"axial_resistivity\"][sink_comp_inds],\n params[\"axial_resistivity\"][source_comp_inds],\n params[\"length\"][sink_comp_inds],\n params[\"length\"][source_comp_inds],\n )\n / params[\"capacitance\"][sink_comp_inds]\n )\n else:\n conds_c2c = jnp.asarray([])\n\n # `branchpoint-to-compartment` (bp2c) axial coupling conductances.\n condition = comp_edges[\"type\"].isin([1, 2])\n sink_comp_inds = np.asarray(comp_edges[condition][\"sink\"].to_list())\n\n if len(sink_comp_inds) > 0:\n conds_bp2c = (\n vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))(\n params[\"radius\"][sink_comp_inds],\n params[\"axial_resistivity\"][sink_comp_inds],\n params[\"length\"][sink_comp_inds],\n )\n / params[\"capacitance\"][sink_comp_inds]\n )\n else:\n conds_bp2c = jnp.asarray([])\n\n # `compartment-to-branchpoint` (c2bp) axial coupling conductances.\n condition = comp_edges[\"type\"].isin([3, 4])\n source_comp_inds = np.asarray(comp_edges[condition][\"source\"].to_list())\n\n if len(source_comp_inds) > 0:\n conds_c2bp = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n params[\"radius\"][source_comp_inds],\n params[\"axial_resistivity\"][source_comp_inds],\n params[\"length\"][source_comp_inds],\n )\n # For numerical stability. These values are very small, but their scale\n # does not matter.\n conds_c2bp *= 1_000\n else:\n conds_c2bp = jnp.asarray([])\n\n # All axial coupling conductances.\n return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp])\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_and_parents","title":"compute_children_and_parents(branch_edges)
","text":"Build indices used during `._init_morph_custom_spsolve().
Source code injaxley/utils/cell_utils.py
def compute_children_and_parents(\n branch_edges: pd.DataFrame,\n) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]:\n \"\"\"Build indices used during `._init_morph_custom_spsolve().\"\"\"\n par_inds = branch_edges[\"parent_branch_index\"].to_numpy()\n child_inds = branch_edges[\"child_branch_index\"].to_numpy()\n child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n par_inds = np.unique(par_inds)\n return par_inds, child_inds, child_belongs_to_branchpoint\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_indices","title":"compute_children_indices(parents)
","text":"Return all children indices of every branch.
Example:
parents = [-1, 0, 0]\ncompute_children_indices(parents) -> [[1, 2], [], []]\n
Source code in jaxley/utils/cell_utils.py
def compute_children_indices(parents) -> List[jnp.ndarray]:\n \"\"\"Return all children indices of every branch.\n\n Example:\n ```\n parents = [-1, 0, 0]\n compute_children_indices(parents) -> [[1, 2], [], []]\n ```\n \"\"\"\n num_branches = len(parents)\n child_indices = []\n for b in range(num_branches):\n child_indices.append(np.where(parents == b)[0])\n return child_indices\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond","title":"compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)
","text":"Return the coupling conductance between two compartments.
Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models
.
radius
: um r_a
: ohm cm length_single_compartment
: um coupling_conds
: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2
jaxley/utils/cell_utils.py
def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):\n \"\"\"Return the coupling conductance between two compartments.\n\n Equations taken from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.\n\n `radius`: um\n `r_a`: ohm cm\n `length_single_compartment`: um\n `coupling_conds`: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2\n \"\"\"\n # Multiply by 10**7 to convert (S / cm / um) -> (mS / cm^2).\n return rad1 * rad2**2 / (r_a1 * rad2**2 * l1 + r_a2 * rad1**2 * l2) / l1 * 10**7\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond_branchpoint","title":"compute_coupling_cond_branchpoint(rad, r_a, l)
","text":"Return the coupling conductance between one compartment and a comp with l=0.
From https://en.wikipedia.org/wiki/Compartmental_neuron_models
If one compartment has l=0.0 then the equations simplify.
R_long = \\sum_i r_a * L_i/2 / crosssection_i
with crosssection = pi * r**2
For a single compartment with L>0, this turns into: R_long = r_a * L/2 / crosssection
Then, g_long = crosssection * 2 / L / r_a
Then, the effective conductance is g_long / zylinder_area. So: g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L g = r / r_a / L**2
Source code injaxley/utils/cell_utils.py
def compute_coupling_cond_branchpoint(rad, r_a, l):\n r\"\"\"Return the coupling conductance between one compartment and a comp with l=0.\n\n From https://en.wikipedia.org/wiki/Compartmental_neuron_models\n\n If one compartment has l=0.0 then the equations simplify.\n\n R_long = \\sum_i r_a * L_i/2 / crosssection_i\n\n with crosssection = pi * r**2\n\n For a single compartment with L>0, this turns into:\n R_long = r_a * L/2 / crosssection\n\n Then, g_long = crosssection * 2 / L / r_a\n\n Then, the effective conductance is g_long / zylinder_area. So:\n g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L\n g = r / r_a / L**2\n \"\"\"\n return rad / r_a / l**2 * 10**7 # Convert (S / cm / um) -> (mS / cm^2)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_impact_on_node","title":"compute_impact_on_node(rad, r_a, l)
","text":"Compute the weight with which a compartment influences its node.
In order to satisfy Kirchhoffs current law, the current at a branch point must be proportional to the crosssection of the compartment. We only require proportionality here because the branch point equation reads: g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0
Because R_long = r_a * L/2 / crosssection, we get g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a
This equation can be multiplied by any constant.
Source code injaxley/utils/cell_utils.py
def compute_impact_on_node(rad, r_a, l):\n r\"\"\"Compute the weight with which a compartment influences its node.\n\n In order to satisfy Kirchhoffs current law, the current at a branch point must be\n proportional to the crosssection of the compartment. We only require proportionality\n here because the branch point equation reads:\n `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`\n\n Because R_long = r_a * L/2 / crosssection, we get\n g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a\n\n This equation can be multiplied by any constant.\"\"\"\n return rad**2 / r_a / l\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_morphology_indices_in_levels","title":"compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)
","text":"Return (row, col) to build the sparse matrix defining the voltage eqs.
This is run at init
, not during runtime.
jaxley/utils/cell_utils.py
def compute_morphology_indices_in_levels(\n num_branchpoints,\n child_belongs_to_branchpoint,\n par_inds,\n child_inds,\n):\n \"\"\"Return (row, col) to build the sparse matrix defining the voltage eqs.\n\n This is run at `init`, not during runtime.\n \"\"\"\n branchpoint_inds_parents = jnp.arange(num_branchpoints)\n branchpoint_inds_children = child_belongs_to_branchpoint\n branch_inds_parents = par_inds\n branch_inds_children = child_inds\n\n children = jnp.stack([branch_inds_children, branchpoint_inds_children])\n parents = jnp.stack([branch_inds_parents, branchpoint_inds_parents])\n\n return {\"children\": children.T, \"parents\": parents.T}\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.convert_point_process_to_distributed","title":"convert_point_process_to_distributed(current, radius, length)
","text":"Convert current point process (nA) to distributed current (uA/cm2).
This function gets called for synapses and for external stimuli.
Parameters:
Name Type Description Defaultcurrent
ndarray
Current in nA
.
radius
ndarray
Compartment radius in um
.
length
ndarray
Compartment length in um
.
Current in uA/cm2
.
jaxley/utils/cell_utils.py
def convert_point_process_to_distributed(\n current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray\n) -> jnp.ndarray:\n \"\"\"Convert current point process (nA) to distributed current (uA/cm2).\n\n This function gets called for synapses and for external stimuli.\n\n Args:\n current: Current in `nA`.\n radius: Compartment radius in `um`.\n length: Compartment length in `um`.\n\n Return:\n Current in `uA/cm2`.\n \"\"\"\n area = 2 * pi * radius * length\n current /= area # nA / um^2\n return current * 100_000 # Convert (nA / um^2) to (uA / cm^2)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.equal_segments","title":"equal_segments(branch_property, ncomp_per_branch)
","text":"Generates segments where some property is the same in each segment.
Parameters:
Name Type Description Defaultbranch_property
list
List of values of the property in each branch. Should have len(branch_property) == num_branches
.
jaxley/utils/cell_utils.py
def equal_segments(branch_property: list, ncomp_per_branch: int):\n \"\"\"Generates segments where some property is the same in each segment.\n\n Args:\n branch_property: List of values of the property in each branch. Should have\n `len(branch_property) == num_branches`.\n \"\"\"\n assert isinstance(branch_property, list), \"branch_property must be a list.\"\n return jnp.asarray([branch_property] * ncomp_per_branch).T\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.get_num_neighbours","title":"get_num_neighbours(num_children, ncomp_per_branch, num_branches)
","text":"Number of neighbours of each compartment.
Source code injaxley/utils/cell_utils.py
def get_num_neighbours(\n num_children: jnp.ndarray,\n ncomp_per_branch: int,\n num_branches: int,\n):\n \"\"\"\n Number of neighbours of each compartment.\n \"\"\"\n num_neighbours = 2 * jnp.ones((num_branches * ncomp_per_branch))\n num_neighbours = num_neighbours.at[ncomp_per_branch - 1].set(1.0)\n num_neighbours = num_neighbours.at[jnp.arange(num_branches) * ncomp_per_branch].set(\n num_children + 1.0\n )\n return num_neighbours\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.group_and_sum","title":"group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)
","text":"Group values by whether they have the same integer and sum values within group.
This is used to construct the last diagonals at the branch points.
Written by ChatGPT.
Source code injaxley/utils/cell_utils.py
def group_and_sum(\n values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int\n) -> jnp.ndarray:\n \"\"\"Group values by whether they have the same integer and sum values within group.\n\n This is used to construct the last diagonals at the branch points.\n\n Written by ChatGPT.\n \"\"\"\n # Initialize an array to hold the sum of each group\n group_sums = jnp.zeros(num_branchpoints)\n\n # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n # `len(inds) == 0` is the case for branches and compartments.\n if num_branchpoints > 0:\n group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)\n\n return group_sums\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.interpolate_xyzr","title":"interpolate_xyzr(loc, coords)
","text":"Perform a linear interpolation between xyz-coordinates.
Parameters:
Name Type Description Defaultloc
float
The location in [0,1] along the branch.
requiredcoords
ndarray
Array containing the reconstructed xyzr points of the branch.
required ReturnInterpolated xyz coordinate at loc
, shape `(3,).
jaxley/utils/cell_utils.py
def interpolate_xyzr(loc: float, coords: np.ndarray):\n \"\"\"Perform a linear interpolation between xyz-coordinates.\n\n Args:\n loc: The location in [0,1] along the branch.\n coords: Array containing the reconstructed xyzr points of the branch.\n\n Return:\n Interpolated xyz coordinate at `loc`, shape `(3,).\n \"\"\"\n dl = np.sqrt(np.sum(np.diff(coords[:, :3], axis=0) ** 2, axis=1))\n pathlens = np.insert(np.cumsum(dl), 0, 0) # cummulative length of sections\n norm_pathlens = pathlens / np.maximum(1e-8, pathlens[-1]) # norm lengths to [0,1].\n\n return v_interp(loc, norm_pathlens, coords)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.linear_segments","title":"linear_segments(initial_val, endpoint_vals, parents, ncomp_per_branch)
","text":"Generates segments where some property is linearly interpolated.
Parameters:
Name Type Description Defaultinitial_val
float
The value at the tip of the soma.
requiredendpoint_vals
list
The value at the endpoints of each branch.
required Source code injaxley/utils/cell_utils.py
def linear_segments(\n initial_val: float, endpoint_vals: list, parents: jnp.ndarray, ncomp_per_branch: int\n):\n \"\"\"Generates segments where some property is linearly interpolated.\n\n Args:\n initial_val: The value at the tip of the soma.\n endpoint_vals: The value at the endpoints of each branch.\n \"\"\"\n branch_property = endpoint_vals + [initial_val]\n num_branches = len(parents)\n # Compute radiuses by linear interpolation.\n endpoint_radiuses = jnp.asarray(branch_property)\n\n def compute_rad(branch_ind, loc):\n start = endpoint_radiuses[parents[branch_ind]]\n end = endpoint_radiuses[branch_ind]\n return (end - start) * loc + start\n\n branch_inds_of_each_comp = jnp.tile(jnp.arange(num_branches), ncomp_per_branch)\n locs_of_each_comp = jnp.linspace(1, 0, ncomp_per_branch).repeat(num_branches)\n rad_of_each_comp = compute_rad(branch_inds_of_each_comp, locs_of_each_comp)\n\n return jnp.reshape(rad_of_each_comp, (ncomp_per_branch, num_branches)).T\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.loc_of_index","title":"loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch)
","text":"Return location corresponding to global compartment index.
Source code injaxley/utils/cell_utils.py
def loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch):\n \"\"\"Return location corresponding to global compartment index.\"\"\"\n cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n index = global_comp_index - cumsum_ncomp[global_branch_index]\n ncomp = ncomp_per_branch[global_branch_index]\n return (0.5 + index) / ncomp\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.local_index_of_loc","title":"local_index_of_loc(loc, global_branch_ind, ncomp_per_branch)
","text":"Returns the local index of a comp given a loc [0, 1] and the index of a branch.
This is used because we specify locations such as synapses as a value between 0 and 1. We have to convert this onto a discrete segment here.
Parameters:
Name Type Description Defaultbranch_ind
Index of the branch.
requiredloc
float
Location (in [0, 1]) along that branch.
requiredncomp_per_branch
int
Number of segments of each branch.
requiredReturns:
Type Descriptionint
The local index of the compartment.
Source code injaxley/utils/cell_utils.py
def local_index_of_loc(\n loc: float, global_branch_ind: int, ncomp_per_branch: int\n) -> int:\n \"\"\"Returns the local index of a comp given a loc [0, 1] and the index of a branch.\n\n This is used because we specify locations such as synapses as a value between 0 and\n 1. We have to convert this onto a discrete segment here.\n\n Args:\n branch_ind: Index of the branch.\n loc: Location (in [0, 1]) along that branch.\n ncomp_per_branch: Number of segments of each branch.\n\n Returns:\n The local index of the compartment.\n \"\"\"\n ncomp = ncomp_per_branch[global_branch_ind] # only for convenience.\n possible_locs = np.linspace(0.5 / ncomp, 1 - 0.5 / ncomp, ncomp)\n ind_along_branch = np.argmin(np.abs(possible_locs - loc))\n return ind_along_branch\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.merge_cells","title":"merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)
","text":"Build full list of which branches are solved in which iteration.
From the branching pattern of single cells, this \u201cmerges\u201d them into a single ordering of branches.
Parameters:
Name Type Description Defaultcumsum_num_branches
List[int]
cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this will should be a list containing [0, 10, 25, 30]
.
arrs
List[List[ndarray]]
A list of a list of arrays that should be merged.
requiredexclude_first
bool
If True
, the first element of each list in arrs
will remain unchanged. Useful if a -1
(which indicates \u201cno parent\u201d) entry should not be changed.
True
Returns:
Type Descriptionndarray
A list of arrays which contain the branch indices that are computed at each
ndarray
level (i.e., iteration).
Source code injaxley/utils/cell_utils.py
def merge_cells(\n cumsum_num_branches: List[int],\n cumsum_num_branchpoints: List[int],\n arrs: List[List[np.ndarray]],\n exclude_first: bool = True,\n) -> np.ndarray:\n \"\"\"\n Build full list of which branches are solved in which iteration.\n\n From the branching pattern of single cells, this \"merges\" them into a single\n ordering of branches.\n\n Args:\n cumsum_num_branches: cumulative number of branches. E.g., for three cells with\n 10, 15, and 5 branches respectively, this will should be a list containing\n `[0, 10, 25, 30]`.\n arrs: A list of a list of arrays that should be merged.\n exclude_first: If `True`, the first element of each list in `arrs` will remain\n unchanged. Useful if a `-1` (which indicates \"no parent\") entry should not\n be changed.\n\n Returns:\n A list of arrays which contain the branch indices that are computed at each\n level (i.e., iteration).\n \"\"\"\n ps = []\n for i, att in enumerate(arrs):\n p = att\n if exclude_first:\n raise NotImplementedError\n p = [p[0]] + [p_in_level + cumsum_num_branches[i] for p_in_level in p[1:]]\n else:\n p = [\n p_in_level\n + np.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]])\n for p_in_level in p\n ]\n ps.append(p)\n\n max_len = max([len(att) for att in arrs])\n combined_parents_in_level = []\n for i in range(max_len):\n current_ps = []\n for p in ps:\n if len(p) > i:\n current_ps.append(p[i])\n combined_parents_in_level.append(np.concatenate(current_ps))\n\n return combined_parents_in_level\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.params_to_pstate","title":"params_to_pstate(params, indices_set_by_trainables)
","text":"Make outputs get_parameters()
conform with outputs of .data_set()
.
make_trainable()
followed by params=get_parameters()
does not return indices because these indices would also be differentiated by jax.grad
(as soon as the params
are passed to def simulate(params)
. Therefore, in jx.integrate
, we run the function to add indices to the dict. The outputs of params_to_pstate
are of the same shape as the outputs of .data_set()
.
jaxley/utils/cell_utils.py
def params_to_pstate(\n params: List[Dict[str, jnp.ndarray]],\n indices_set_by_trainables: List[jnp.ndarray],\n):\n \"\"\"Make outputs `get_parameters()` conform with outputs of `.data_set()`.\n\n `make_trainable()` followed by `params=get_parameters()` does not return indices\n because these indices would also be differentiated by `jax.grad` (as soon as\n the `params` are passed to `def simulate(params)`. Therefore, in `jx.integrate`,\n we run the function to add indices to the dict. The outputs of `params_to_pstate`\n are of the same shape as the outputs of `.data_set()`.\"\"\"\n return [\n {\"key\": list(p.keys())[0], \"val\": list(p.values())[0], \"indices\": i}\n for p, i in zip(params, indices_set_by_trainables)\n ]\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.query_channel_states_and_params","title":"query_channel_states_and_params(d, keys, idcs)
","text":"Get dict with subset of keys and values from d.
This is used to restrict a dict where every item contains all states to only the ones that are relevant for the channel. E.g.
states = {'eCa': Array([ 0., 0., nan]}
will be states = {'eCa': Array([ 0., 0.]}
Only loops over necessary keys, as opposed to looping over d.items()
.
jaxley/utils/cell_utils.py
def query_channel_states_and_params(d, keys, idcs):\n \"\"\"Get dict with subset of keys and values from d.\n\n This is used to restrict a dict where every item contains __all__ states to only\n the ones that are relevant for the channel. E.g.\n\n ```states = {'eCa': Array([ 0., 0., nan]}```\n\n will be\n ```states = {'eCa': Array([ 0., 0.]}```\n\n Only loops over necessary keys, as opposed to looping over `d.items()`.\"\"\"\n return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.remap_to_consecutive","title":"remap_to_consecutive(arr)
","text":"Maps an array of integers to an array of consecutive integers.
E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]
jaxley/utils/cell_utils.py
def remap_to_consecutive(arr):\n \"\"\"Maps an array of integers to an array of consecutive integers.\n\n E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`\n \"\"\"\n _, inverse_indices = jnp.unique(arr, return_inverse=True)\n return inverse_indices\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.compute_rotation_matrix","title":"compute_rotation_matrix(axis, angle)
","text":"Return the rotation matrix associated with counterclockwise rotation about the given axis by the given angle.
Can be used to rotate a coordinate vector by multiplying it with the rotation matrix.
Parameters:
Name Type Description Defaultaxis
ndarray
The axis of rotation.
requiredangle
float
The angle of rotation in radians.
requiredReturns:
Type Descriptionndarray
A 3x3 rotation matrix.
Source code injaxley/utils/plot_utils.py
def compute_rotation_matrix(axis: ndarray, angle: float) -> ndarray:\n \"\"\"\n Return the rotation matrix associated with counterclockwise rotation about\n the given axis by the given angle.\n\n Can be used to rotate a coordinate vector by multiplying it with the rotation\n matrix.\n\n Args:\n axis: The axis of rotation.\n angle: The angle of rotation in radians.\n\n Returns:\n A 3x3 rotation matrix.\n \"\"\"\n axis = axis / np.sqrt(np.dot(axis, axis))\n a = np.cos(angle / 2.0)\n b, c, d = -axis * np.sin(angle / 2.0)\n aa, bb, cc, dd = a * a, b * b, c * c, d * d\n bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d\n return np.array(\n [\n [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],\n [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],\n [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],\n ]\n )\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_cone_frustum_mesh","title":"create_cone_frustum_mesh(length, radius_bottom, radius_top, bottom_dome=False, top_dome=False, resolution=100)
","text":"Generates mesh points for a cone frustum, with optional domes at either end.
This is used to render the traced morphology in 3D (and to project it to 2D) as part of plot_morph
. Sections between two traced coordinates with two different radii can be represented by a cone frustum. Additionally, the ends of the frustum can be capped with hemispheres to ensure that two neighbouring frustums are connected smoothly (like ball joints).
Parameters:
Name Type Description Defaultlength
float
The length of the frustum.
requiredradius_bottom
float
The radius of the bottom of the frustum.
requiredradius_top
float
The radius of the top of the frustum.
requiredbottom_dome
bool
If True, a dome is added to the bottom of the frustum. The dome is a hemisphere with radius radius_bottom
.
False
top_dome
bool
If True, a dome is added to the top of the frustum. The dome is a hemisphere with radius radius_top
.
False
resolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type Descriptionndarray
An array of mesh points.
Source code injaxley/utils/plot_utils.py
def create_cone_frustum_mesh(\n length: float,\n radius_bottom: float,\n radius_top: float,\n bottom_dome: bool = False,\n top_dome: bool = False,\n resolution: int = 100,\n) -> ndarray:\n \"\"\"Generates mesh points for a cone frustum, with optional domes at either end.\n\n This is used to render the traced morphology in 3D (and to project it to 2D)\n as part of `plot_morph`. Sections between two traced coordinates with two\n different radii can be represented by a cone frustum. Additionally, the ends\n of the frustum can be capped with hemispheres to ensure that two neighbouring\n frustums are connected smoothly (like ball joints).\n\n Args:\n length: The length of the frustum.\n radius_bottom: The radius of the bottom of the frustum.\n radius_top: The radius of the top of the frustum.\n bottom_dome: If True, a dome is added to the bottom of the frustum.\n The dome is a hemisphere with radius `radius_bottom`.\n top_dome: If True, a dome is added to the top of the frustum.\n The dome is a hemisphere with radius `radius_top`.\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n An array of mesh points.\n \"\"\"\n\n t = np.linspace(0, 2 * np.pi, resolution)\n\n # Determine the total height including domes\n total_height = length\n total_height += radius_bottom if bottom_dome else 0\n total_height += radius_top if top_dome else 0\n\n z = np.linspace(0, total_height, resolution)\n t_grid, z_coords = np.meshgrid(t, z)\n\n # Initialize arrays\n x_coords = np.zeros_like(t_grid)\n y_coords = np.zeros_like(t_grid)\n r_coords = np.zeros_like(t_grid)\n\n # Bottom hemisphere\n if bottom_dome:\n dome_mask = z_coords < radius_bottom\n arg = 1 - z_coords[dome_mask] / radius_bottom\n arg[np.isclose(arg, 1, atol=1e-6, rtol=1e-6)] = 1\n arg[np.isclose(arg, -1, atol=1e-6, rtol=1e-6)] = -1\n phi = np.arccos(1 - z_coords[dome_mask] / radius_bottom)\n r_coords[dome_mask] = radius_bottom * np.sin(phi)\n z_coords[dome_mask] = z_coords[dome_mask]\n\n # Frustum\n frustum_start = radius_bottom if bottom_dome else 0\n frustum_end = total_height - (radius_top if top_dome else 0)\n frustum_mask = (z_coords >= frustum_start) & (z_coords <= frustum_end)\n z_frustum = z_coords[frustum_mask] - frustum_start\n r_coords[frustum_mask] = radius_bottom + (radius_top - radius_bottom) * (\n z_frustum / length\n )\n\n # Top hemisphere\n if top_dome:\n dome_mask = z_coords > (total_height - radius_top)\n arg = (z_coords[dome_mask] - (total_height - radius_top)) / radius_top\n arg[np.isclose(arg, 1, atol=1e-6, rtol=1e-6)] = 1\n arg[np.isclose(arg, -1, atol=1e-6, rtol=1e-6)] = -1\n phi = np.arccos(arg)\n r_coords[dome_mask] = radius_top * np.sin(phi)\n\n x_coords = r_coords * np.cos(t_grid)\n y_coords = r_coords * np.sin(t_grid)\n\n return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_cylinder_mesh","title":"create_cylinder_mesh(length, radius, resolution=100)
","text":"Generates mesh points for a cylinder.
This is used to render cylindrical compartments in 3D (and to project it to 2D) as part of plot_comps
.
Parameters:
Name Type Description Defaultlength
float
The length of the cylinder.
requiredradius
float
The radius of the cylinder.
requiredresolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type Descriptionndarray
An array of mesh points.
Source code injaxley/utils/plot_utils.py
def create_cylinder_mesh(\n length: float, radius: float, resolution: int = 100\n) -> ndarray:\n \"\"\"Generates mesh points for a cylinder.\n\n This is used to render cylindrical compartments in 3D (and to project it to 2D)\n as part of `plot_comps`.\n\n Args:\n length: The length of the cylinder.\n radius: The radius of the cylinder.\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n An array of mesh points.\n \"\"\"\n # Define cylinder\n t = np.linspace(0, 2 * np.pi, resolution)\n z_coords = np.linspace(-length / 2, length / 2, resolution)\n t_grid, z_coords = np.meshgrid(t, z_coords)\n\n x_coords = radius * np.cos(t_grid)\n y_coords = radius * np.sin(t_grid)\n return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_sphere_mesh","title":"create_sphere_mesh(radius, resolution=100)
","text":"Generates mesh points for a sphere.
This is used to render spherical compartments in 3D (and to project it to 2D) as part of plot_comps
.
Parameters:
Name Type Description Defaultradius
float
The radius of the sphere.
requiredresolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type Descriptionndarray
An array of mesh points.
Source code injaxley/utils/plot_utils.py
def create_sphere_mesh(radius: float, resolution: int = 100) -> np.ndarray:\n \"\"\"Generates mesh points for a sphere.\n\n This is used to render spherical compartments in 3D (and to project it to 2D)\n as part of `plot_comps`.\n\n Args:\n radius: The radius of the sphere.\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n An array of mesh points.\n \"\"\"\n phi = np.linspace(0, np.pi, resolution)\n theta = np.linspace(0, 2 * np.pi, resolution)\n\n # Create a 2D meshgrid for phi and theta\n phi_coords, theta_coords = np.meshgrid(phi, theta)\n\n # Convert spherical coordinates to Cartesian coordinates\n x_coords = radius * np.sin(phi_coords) * np.cos(theta_coords)\n y_coords = radius * np.sin(phi_coords) * np.sin(theta_coords)\n z_coords = radius * np.cos(phi_coords)\n\n return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.extract_outline","title":"extract_outline(points)
","text":"Get the outline of a 2D/3D shape.
Extracts the subset of points which form the convex hull, i.e. the outline of the input points.
Parameters:
Name Type Description Defaultpoints
ndarray
An array of points / corrdinates.
requiredReturns:
Type Descriptionndarray
An array of points which form the convex hull.
Source code injaxley/utils/plot_utils.py
def extract_outline(points: ndarray) -> ndarray:\n \"\"\"Get the outline of a 2D/3D shape.\n\n Extracts the subset of points which form the convex hull, i.e. the outline of\n the input points.\n\n Args:\n points: An array of points / corrdinates.\n\n Returns:\n An array of points which form the convex hull.\n \"\"\"\n hull = ConvexHull(points)\n hull_points = points[hull.vertices]\n return hull_points\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_comps","title":"plot_comps(module_or_view, dims=(0, 1), col='k', ax=None, comp_plot_kwargs={}, true_comp_length=True, resolution=100)
","text":"Plot compartmentalized neural morphology.
Plots the projection of the cylindrical compartments.
Parameters:
Name Type Description Defaultmodule_or_view
Union[Module, View]
The module or view to plot.
requireddims
Tuple[int]
The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.
(0, 1)
col
str
The color for all compartments
'k'
ax
Optional[Axes]
The matplotlib axis to plot on.
None
comp_plot_kwargs
Dict
The plot kwargs for plt.fill.
{}
true_comp_length
bool
If True, the length of the compartment is used, i.e. the length of the traced neurite. This means for zig-zagging neurites the cylinders will be longer than the straight-line distance between the start and end point of the neurite. This can lead to overlapping and miss-aligned cylinders. Setting this False will use the straight-line distance instead for nicer plots.
True
resolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type DescriptionAxes
Plot of the compartmentalized morphology.
Source code injaxley/utils/plot_utils.py
def plot_comps(\n module_or_view: Union[\"jx.Module\", \"jx.View\"],\n dims: Tuple[int] = (0, 1),\n col: str = \"k\",\n ax: Optional[Axes] = None,\n comp_plot_kwargs: Dict = {},\n true_comp_length: bool = True,\n resolution: int = 100,\n) -> Axes:\n \"\"\"Plot compartmentalized neural morphology.\n\n Plots the projection of the cylindrical compartments.\n\n Args:\n module_or_view: The module or view to plot.\n dims: The dimensions to plot / to project the cylinder onto,\n i.e. [0,1] xy-plane or [0,1,2] for 3D.\n col: The color for all compartments\n ax: The matplotlib axis to plot on.\n comp_plot_kwargs: The plot kwargs for plt.fill.\n true_comp_length: If True, the length of the compartment is used, i.e. the\n length of the traced neurite. This means for zig-zagging neurites the\n cylinders will be longer than the straight-line distance between the\n start and end point of the neurite. This can lead to overlapping and\n miss-aligned cylinders. Setting this False will use the straight-line\n distance instead for nicer plots.\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n Plot of the compartmentalized morphology.\n \"\"\"\n if ax is None:\n fig = plt.figure(figsize=(3, 3))\n ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n assert not np.any(\n np.isnan(module_or_view.xyzr[0][:, :3])\n ), \"missing xyz coordinates.\"\n if \"x\" not in module_or_view.nodes.columns:\n module_or_view.compute_compartment_centers()\n\n for idx, xyzr in zip(module_or_view._branches_in_view, module_or_view.xyzr):\n locs = xyzr[:, :3]\n if locs.shape[0] == 1: # assume spherical comp\n radius = xyzr[:, -1]\n center = xyzr[0, :3]\n if len(dims) == 3:\n xyz = create_sphere_mesh(radius, resolution)\n ax = plot_mesh(\n xyz,\n np.array([0, 0, 1]),\n center,\n np.array(dims),\n ax,\n color=col,\n **comp_plot_kwargs,\n )\n else:\n ax.add_artist(plt.Circle(locs[0, dims], radius, color=col))\n else:\n lens = np.sqrt(np.nansum(np.diff(locs, axis=0) ** 2, axis=1))\n lens = np.cumsum([0] + lens.tolist())\n comp_ends = v_interp(\n np.linspace(0, lens[-1], module_or_view.ncomp + 1), lens, locs\n ).T\n axes = np.diff(comp_ends, axis=0)\n cylinder_lens = np.sqrt(np.sum(axes**2, axis=1))\n\n branch_df = module_or_view.nodes[\n module_or_view.nodes[\"global_branch_index\"] == idx\n ]\n for l, axis, (i, comp) in zip(cylinder_lens, axes, branch_df.iterrows()):\n center = comp[[\"x\", \"y\", \"z\"]]\n radius = comp[\"radius\"]\n length = comp[\"length\"] if true_comp_length else l\n xyz = create_cylinder_mesh(length, radius, resolution)\n ax = plot_mesh(\n xyz,\n axis,\n center,\n np.array(dims),\n ax,\n color=col,\n **comp_plot_kwargs,\n )\n return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_graph","title":"plot_graph(xyzr, dims=(0, 1), col='k', ax=None, type='line', morph_plot_kwargs={})
","text":"Plot morphology.
Parameters:
Name Type Description Defaultxyzr
ndarray
The coordinates of the morphology.
requireddims
Tuple[int]
Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two or three of them.
(0, 1)
col
str
The color for all branches.
'k'
ax
Optional[Axes]
The matplotlib axis to plot on.
None
type
str
Either line
or scatter
.
'line'
morph_plot_kwargs
Dict
The plot kwargs for plt.plot or plt.scatter.
{}
Source code in jaxley/utils/plot_utils.py
def plot_graph(\n xyzr: ndarray,\n dims: Tuple[int] = (0, 1),\n col: str = \"k\",\n ax: Optional[Axes] = None,\n type: str = \"line\",\n morph_plot_kwargs: Dict = {},\n) -> Axes:\n \"\"\"Plot morphology.\n\n Args:\n xyzr: The coordinates of the morphology.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two or three of them.\n col: The color for all branches.\n ax: The matplotlib axis to plot on.\n type: Either `line` or `scatter`.\n morph_plot_kwargs: The plot kwargs for plt.plot or plt.scatter.\n \"\"\"\n\n if ax is None:\n fig = plt.figure(figsize=(3, 3))\n ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n for coords_of_branch in xyzr:\n points = coords_of_branch[:, dims].T\n\n if \"line\" in type.lower():\n _ = ax.plot(*points, color=col, **morph_plot_kwargs)\n elif \"scatter\" in type.lower():\n _ = ax.scatter(*points, color=col, **morph_plot_kwargs)\n else:\n raise NotImplementedError\n\n return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_mesh","title":"plot_mesh(mesh_points, orientation, center, dims, ax=None, **kwargs)
","text":"Plot the 2D projection of a volume mesh on a cardinal plane.
Project the projection of a cylinder that is oriented in 3D space. - Create cylinder mesh - rotate cylinder mesh to orient it lengthwise along a given orientation vector. - move its center - project onto plane - compute outline of projected mesh. - fill area inside the outline
Parameters:
Name Type Description Defaultmesh_points
ndarray
coordinates of the xyz mesh that define the volume
requiredorientation
ndarray
orientation vector. The cylinder will be oriented along this vector.
requiredcenter
ndarray
The x,y,z coordinates of the center of the cylinder.
requireddims
Tuple[int]
The dimensions to plot / to project the cylinder onto,
requiredax
Axes
The matplotlib axis to plot on.
None
Returns:
Type DescriptionAxes
Plot of the cylinder projection.
Source code injaxley/utils/plot_utils.py
def plot_mesh(\n mesh_points: ndarray,\n orientation: ndarray,\n center: ndarray,\n dims: Tuple[int],\n ax: Axes = None,\n **kwargs,\n) -> Axes:\n \"\"\"Plot the 2D projection of a volume mesh on a cardinal plane.\n\n Project the projection of a cylinder that is oriented in 3D space.\n - Create cylinder mesh\n - rotate cylinder mesh to orient it lengthwise along a given orientation vector.\n - move its center\n - project onto plane\n - compute outline of projected mesh.\n - fill area inside the outline\n\n Args:\n mesh_points: coordinates of the xyz mesh that define the volume\n orientation: orientation vector. The cylinder will be oriented along this vector.\n center: The x,y,z coordinates of the center of the cylinder.\n dims: The dimensions to plot / to project the cylinder onto,\n i.e. [0,1] xy-plane or [0,1,2] for 3D.\n ax: The matplotlib axis to plot on.\n\n Returns:\n Plot of the cylinder projection.\n \"\"\"\n if ax is None:\n fig = plt.figure(figsize=(3, 3))\n ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n # Normalize axis vector\n orientation = np.array(orientation)\n orientation = orientation / np.linalg.norm(orientation)\n\n # Create a rotation matrix to align the cylinder with the given axis\n z_axis = np.array([0, 0, 1])\n rotation_axis = np.cross(z_axis, orientation)\n rotation_angle = np.arccos(np.dot(z_axis, orientation))\n\n if np.allclose(rotation_axis, 0):\n rotation_matrix = np.eye(3)\n else:\n rotation_matrix = compute_rotation_matrix(rotation_axis, rotation_angle)\n\n # Rotate mesh\n x_mesh, y_mesh, z_mesh = mesh_points\n rotated_mesh_points = np.dot(\n rotation_matrix,\n np.array([x_mesh.flatten(), y_mesh.flatten(), z_mesh.flatten()]),\n )\n rotated_mesh_points = rotated_mesh_points.reshape(3, -1)\n\n # project onto plane and move\n rotated_mesh_points = rotated_mesh_points[dims]\n rotated_mesh_points += np.array(center)[dims, np.newaxis]\n\n if len(dims) < 3:\n # get outline of cylinder mesh\n mesh_outline = extract_outline(rotated_mesh_points.T).T\n ax.fill(*mesh_outline.reshape(mesh_outline.shape[0], -1), **kwargs)\n else:\n # plot 3d mesh\n ax.plot_surface(*rotated_mesh_points.reshape(*mesh_points.shape), **kwargs)\n return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_morph","title":"plot_morph(module_or_view, dims=(0, 1), col='k', ax=None, resolution=100, morph_plot_kwargs={})
","text":"Plot the detailed morphology.
Plots the traced morphology it was traced. That means at every point that was traced a disc of radius r
is plotted. The outline of the discs are then connected to form the morphology. This means every trace segement can be represented by a cone frustum. To prevent breaks in the morphology, each segement is connected with a ball joint.
Parameters:
Name Type Description Defaultmodule_or_view
Union[Module, View]
The module or view to plot.
requireddims
Tuple[int]
The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.
(0, 1)
col
str
The color for all branches
'k'
ax
Optional[Axes]
The matplotlib axis to plot on.
None
morph_plot_kwargs
Dict
The plot kwargs for plt.fill.
{}
resolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type DescriptionAxes
Plot of the detailed morphology.
Source code injaxley/utils/plot_utils.py
def plot_morph(\n module_or_view: Union[\"jx.Module\", \"jx.View\"],\n dims: Tuple[int] = (0, 1),\n col: str = \"k\",\n ax: Optional[Axes] = None,\n resolution: int = 100,\n morph_plot_kwargs: Dict = {},\n) -> Axes:\n \"\"\"Plot the detailed morphology.\n\n Plots the traced morphology it was traced. That means at every point that was\n traced a disc of radius `r` is plotted. The outline of the discs are then\n connected to form the morphology. This means every trace segement can be\n represented by a cone frustum. To prevent breaks in the morphology, each\n segement is connected with a ball joint.\n\n Args:\n module_or_view: The module or view to plot.\n dims: The dimensions to plot / to project the cylinder onto,\n i.e. [0,1] xy-plane or [0,1,2] for 3D.\n col: The color for all branches\n ax: The matplotlib axis to plot on.\n morph_plot_kwargs: The plot kwargs for plt.fill.\n\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n Plot of the detailed morphology.\"\"\"\n if ax is None:\n fig = plt.figure(figsize=(3, 3))\n ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n if len(dims) == 3:\n warn(\n \"rendering large morphologies in 3D can take a while. Consider projecting to 2D instead.\"\n )\n\n assert not np.any(\n np.isnan(module_or_view.xyzr[0][:, :3])\n ), \"missing xyz coordinates.\"\n\n for xyzr in module_or_view.xyzr:\n if len(xyzr) > 1:\n for xyzr1, xyzr2 in zip(xyzr[1:, :], xyzr[:-1, :]):\n dxyz = xyzr2[:3] - xyzr1[:3]\n length = np.sqrt(np.sum(dxyz**2))\n points = create_cone_frustum_mesh(\n length,\n xyzr1[-1],\n xyzr2[-1],\n bottom_dome=True,\n top_dome=True,\n resolution=resolution,\n )\n plot_mesh(\n points,\n dxyz,\n xyzr1[:3],\n np.array(dims),\n color=col,\n ax=ax,\n **morph_plot_kwargs,\n )\n else:\n points = create_cone_frustum_mesh(\n 0,\n xyzr[:, -1],\n xyzr[:, -1],\n bottom_dome=True,\n top_dome=True,\n resolution=resolution,\n )\n plot_mesh(\n points,\n np.ones(3),\n xyzr[0, :3],\n dims=np.array(dims),\n color=col,\n ax=ax,\n **morph_plot_kwargs,\n )\n\n return ax\n
"},{"location":"reference/utils/#jaxley.utils.jax_utils.nested_checkpoint_scan","title":"nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=jax.lax.scan, checkpoint_fn=jax.checkpoint)
","text":"A version of lax.scan that supports recursive gradient checkpointing.
Code taken from: https://github.com/google/jax/issues/2139
The interface of nested_checkpoint_scan
exactly matches lax.scan, except for the required nested_lengths
argument.
The key feature of nested_checkpoint_scan
is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1
times.
nested_checkpoint_scan
reduces to lax.scan
when nested_lengths
has a single element.
Parameters:
Name Type Description Defaultf
Callable[[Carry, Dict[str, ndarray]], Tuple[Carry, Output]]
function to scan over.
requiredinit
Carry
initial value.
requiredxs
Dict[str, ndarray]
scanned over values.
requiredlength
Optional[int]
leading length of all dimensions
None
nested_lengths
Sequence[int]
required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs
.
scan_fn
function matching the API of lax.scan
scan
checkpoint_fn
Callable[[Func], Func]
function matching the API of jax.checkpoint.
checkpoint
Source code in jaxley/utils/jax_utils.py
def nested_checkpoint_scan(\n f: Callable[[Carry, Dict[str, jnp.ndarray]], Tuple[Carry, Output]],\n init: Carry,\n xs: Dict[str, jnp.ndarray],\n length: Optional[int] = None,\n *,\n nested_lengths: Sequence[int],\n scan_fn=jax.lax.scan,\n checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,\n):\n \"\"\"A version of lax.scan that supports recursive gradient checkpointing.\n\n Code taken from: https://github.com/google/jax/issues/2139\n\n The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for\n the required `nested_lengths` argument.\n\n The key feature of `nested_checkpoint_scan` is that gradient calculations\n require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested\n scans, which it achieves by re-evaluating the forward pass\n `len(nested_lengths) - 1` times.\n\n `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a\n single element.\n\n Args:\n f: function to scan over.\n init: initial value.\n xs: scanned over values.\n length: leading length of all dimensions\n nested_lengths: required list of lengths to scan over for each level of\n checkpointing. The product of nested_lengths must match length (if\n provided) and the size of the leading axis for all arrays in ``xs``.\n scan_fn: function matching the API of lax.scan\n checkpoint_fn: function matching the API of jax.checkpoint.\n \"\"\"\n if length is not None and length != math.prod(nested_lengths):\n raise ValueError(f\"inconsistent {length=} and {nested_lengths=}\")\n\n def nested_reshape(x):\n x = jnp.asarray(x)\n new_shape = tuple(nested_lengths) + x.shape[1:]\n return x.reshape(new_shape)\n\n sub_xs = jax.tree_util.tree_map(nested_reshape, xs)\n return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn, checkpoint_fn)\n
"},{"location":"reference/utils/#jaxley.utils.syn_utils.gather_synapes","title":"gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)
","text":"Compute current at the post synapse.
All this does it that it sums the synaptic currents that come into a particular compartment. It returns an array of as many elements as there are compartments.
Source code injaxley/utils/syn_utils.py
def gather_synapes(\n number_of_compartments: jnp.ndarray,\n post_syn_comp_inds: np.ndarray,\n current_each_synapse_voltage_term: jnp.ndarray,\n current_each_synapse_constant_term: jnp.ndarray,\n) -> Tuple[jnp.ndarray, jnp.ndarray]:\n \"\"\"Compute current at the post synapse.\n\n All this does it that it sums the synaptic currents that come into a particular\n compartment. It returns an array of as many elements as there are compartments.\n \"\"\"\n incoming_currents_voltages = jnp.zeros((number_of_compartments,))\n incoming_currents_contant = jnp.zeros((number_of_compartments,))\n\n dnums = ScatterDimensionNumbers(\n update_window_dims=(),\n inserted_window_dims=(0,),\n scatter_dims_to_operand_dims=(0,),\n )\n incoming_currents_voltages = scatter_add(\n incoming_currents_voltages,\n post_syn_comp_inds[:, None],\n current_each_synapse_voltage_term,\n dnums,\n )\n incoming_currents_contant = scatter_add(\n incoming_currents_contant,\n post_syn_comp_inds[:, None],\n current_each_synapse_constant_term,\n dnums,\n )\n return incoming_currents_voltages, incoming_currents_contant\n
"},{"location":"tutorial/00_jaxley_api/","title":"Key concepts in Jaxley","text":"In this tutorial, we will introduce you to the basic concepts of Jaxley. You will learn about:
Here is a code snippet which you will learn to understand in this tutorial:
import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n\n# Assembling different Modules into a Network\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=1)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell]*3)\n\n# Navigating and inspecting the Modules using Views\ncell0 = net.cell(0)\ncell0.nodes\n\n# How to group together parts of Modules\nnet.cell(1).add_to_group(\"cell1\")\n\n# inserting channels in the membrane\nwith net.cell(0) as cell0:\n cell0.insert(Na())\n cell0.insert(K())\n\n# connecting two cells using a Synapse\npre_comp = cell0.branch(1).comp(0)\npost_comp = net.cell1.branch(0).comp(0)\n\nconnect(pre_comp, post_comp)\n
First, we import the relevant libraries:
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\nimport matplotlib.pyplot as plt\nimport numpy as np\n
"},{"location":"tutorial/00_jaxley_api/#modules","title":"Modules","text":"In Jaxley, we heavily rely on the concept of Modules to build biophyiscal models of neural systems at various scales. Jaxley implements four types of Modules: - Compartment
- Branch
- Cell
- Network
Modules can be connected together to build increasingly detailed and complex models. Compartment
-> Branch
-> Cell
-> Network
.
Compartment
s are the atoms of biophysical models in Jaxley. All mechanisms and synaptic connections live on the level of Compartment
s and can already be simulated using jx.integrate
on their own. Everything you do in Jaxley starts with a Compartment
.
comp = jx.Compartment() # single compartment model.\n
Mutliple Compartments
can be connected together to form longer, linear cables, which we call Branch
es and are equivalent to sections in NEURON
.
ncomp = 4\nbranch = jx.Branch([comp] * ncomp)\n
In order to construct cell morphologies in Jaxley, multiple Branches
can to be connected together as a Cell
:
# -1 indicates that the first branch has no parent branch.\n# The other two branches both have the 0-eth branch as their parent.\nparents = [-1, 0, 0]\ncell = jx.Cell([branch] * len(parents), parents)\n
Finally, several Cell
s can be grouped together to form a Network
, which can than be connected together using Synpase
s.
ncells = 2\nnet = jx.Network([cell]*ncells)\n\nnet.shape # shows you the num_cells, num_branches, num_comps\n
(2, 6, 24)\n
Every module tracks information about its current state and parameters in two Dataframes called nodes
and edges
. nodes
contains all the information that we associate with compartments in the model (each row corresponds to one compartment) and edges
tracks all the information relevant to synapses.
This means that you can easily keep track of the current state of your Module
and how it changes at all times.
net.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0 12 1 0 0 10.0 1.0 5000.0 1.0 -70.0 1 3 12 0 13 1 0 1 10.0 1.0 5000.0 1.0 -70.0 1 3 13 0 14 1 0 2 10.0 1.0 5000.0 1.0 -70.0 1 3 14 0 15 1 0 3 10.0 1.0 5000.0 1.0 -70.0 1 3 15 0 16 1 1 0 10.0 1.0 5000.0 1.0 -70.0 1 4 16 0 17 1 1 1 10.0 1.0 5000.0 1.0 -70.0 1 4 17 0 18 1 1 2 10.0 1.0 5000.0 1.0 -70.0 1 4 18 0 19 1 1 3 10.0 1.0 5000.0 1.0 -70.0 1 4 19 0 20 1 2 0 10.0 1.0 5000.0 1.0 -70.0 1 5 20 0 21 1 2 1 10.0 1.0 5000.0 1.0 -70.0 1 5 21 0 22 1 2 2 10.0 1.0 5000.0 1.0 -70.0 1 5 22 0 23 1 2 3 10.0 1.0 5000.0 1.0 -70.0 1 5 23 0 net.edges.head() # this is currently empty since we have not made any connections yet\n
global_edge_index global_pre_comp_index global_post_comp_index pre_locs post_locs type type_ind"},{"location":"tutorial/00_jaxley_api/#views","title":"Views","text":"Since these Module
s can become very complex, Jaxley utilizes so called View
s to make working with Module
s easy and intuitive.
The simplest way to navigate Modules is by navigating them via the hierachy that we introduced above. A View
is what you get when you index into the module. For example, for a Network
:
net.cell(0)\n
View with 0 different channels. Use `.nodes` for details.\n
Views behave very similarly to Module
s, i.e. the cell(0)
(the 0th cell of the network) behaves like the cell
we instantiated earlier. As such, cell(0)
also has a nodes
attribute, which keeps track of it\u2019s part of the network:
net.cell(0).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0 Let\u2019s use View
s to visualize only parts of the Network
. Before we do that, we create x, y, and z coordinates for the Network
:
# Compute xyz coordinates of the cells.\nnet.compute_xyz()\n\n# Move cells (since they are placed on top of each other by default).\nnet.cell(0).move(y=30)\n
We can now visualize the entire net
(i.e., the entire Module
) with the .vis()
method\u2026
# We can use the vis function to visualize Modules.\nfig, ax = plt.subplots(1, 1, figsize=(3,3))\nnet.vis(ax=ax)\n
<Axes: >\n
\u2026but we can also create a View
to visualize only parts of the net
:
# ... and Views\nfig, ax = plt.subplots(1,1, figsize=(3,3))\nnet.cell(0).vis(ax=ax, col=\"blue\") # View of the 0th cell of the network\nnet.cell(1).vis(ax=ax, col=\"red\") # View of the 1st cell of the network\n\nnet.cell(0).branch(0).vis(ax=ax, col=\"green\") # View of the 1st branch of the 0th cell of the network\nnet.cell(1).branch(1).comp(1).vis(ax=ax, col=\"black\", type=\"scatter\") # View of the 0th comp of the 1st branch of the 0th cell of the network\n
<Axes: >\n
"},{"location":"tutorial/00_jaxley_api/#how-to-create-views","title":"How to create View
s","text":"Above, we used net.cell(0)
to generate a View
of the 0-eth cell. Jaxley
supports many ways of performing such indexing:
# several types of indices are supported (lists, ranges, ...)\nnet.cell([0,1]).branch(\"all\").comp(0) # View of all 0th comps of all branches of cell 0 and 1\n\nbranch.loc(0.1) # Equivalent to `NEURON`s `loc`. Assumes branches are continous from 0-1.\n\nnet[0,0,0] # Modules/Views can also be lazily indexed\n\ncell0 = net.cell(0) # Views can be assigned to variables and only track the parts of the Module they belong to\ncell0.branch(1).comp(0) # Views can be continuely indexed\n
View with 0 different channels. Use `.nodes` for details.\n
cell0.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v x y z global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 5.000000 30.000000 0.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 15.000000 30.000000 0.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 25.000000 30.000000 0.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 35.000000 30.000000 0.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 44.850713 28.787322 0.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 54.552138 26.361966 0.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 64.253563 23.936609 0.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 73.954988 21.511253 0.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 44.850713 31.212678 0.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 54.552138 33.638034 0.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 64.253563 36.063391 0.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 73.954988 38.488747 0.0 0 2 11 0 net.shape\n
(2, 6, 24)\n
Note: In case you need even more flexibility in how you select parts of a Module, Jaxley provides a select
method, to give full control over the exact parts of the nodes
and edges
that are part of a View
. On examples of how this can be used, see the tutorial on advanced indexing.
You can also iterate over networks, cells, and branches:
# We set the radiuses to random values...\nradiuses = np.random.rand((24))\nnet.set(\"radius\", radiuses)\n\n# ...and then we set the length to 100.0 um if the radius is >0.5.\nfor cell in net:\n for branch in cell:\n for comp in branch:\n if comp.nodes.iloc[0][\"radius\"] > 0.5:\n comp.set(\"length\", 100.0)\n\n# Show the first five compartments:\nnet.nodes[[\"radius\", \"length\"]][:5]\n
radius length 0 0.763057 100.0 1 0.334882 10.0 2 0.805696 100.0 3 0.717921 100.0 4 0.079569 10.0 Finally, you can also use View
s in a context manager:
with net.cell(0).branch(0) as branch0:\n branch0.set(\"radius\", 2.0)\n branch0.set(\"length\", 2.5)\n\n# Show the first five compartments.\nnet.nodes[[\"radius\", \"length\"]][:5]\n
radius length 0 2.000000 2.5 1 2.000000 2.5 2 2.000000 2.5 3 2.000000 2.5 4 0.079569 10.0"},{"location":"tutorial/00_jaxley_api/#channels","title":"Channels","text":"The Module
s that we have created above will not do anything interesting, since by default Jaxley initializes them without any mechanisms in the membrane. To change this, we have to insert channels into the membrane. For this purpose Jaxley
implements Channel
s that can be inserted into any compartment using the insert
method of a Module
or a View
:
# insert a Leak channel into all compartments in the Module.\nnet.insert(Leak())\nnet.nodes.head() # Channel parameters are now also added to `nodes`.\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param x y z Leak Leak_gLeak Leak_eLeak 0 0 0 0 2.5 2.000000 5000.0 1.0 -70.0 0 0 0 0 5.000000 30.000000 0.0 True 0.0001 -70.0 1 0 0 1 2.5 2.000000 5000.0 1.0 -70.0 0 0 1 0 15.000000 30.000000 0.0 True 0.0001 -70.0 2 0 0 2 2.5 2.000000 5000.0 1.0 -70.0 0 0 2 0 25.000000 30.000000 0.0 True 0.0001 -70.0 3 0 0 3 2.5 2.000000 5000.0 1.0 -70.0 0 0 3 0 35.000000 30.000000 0.0 True 0.0001 -70.0 4 0 1 0 10.0 0.079569 5000.0 1.0 -70.0 0 1 4 0 44.850713 28.787322 0.0 True 0.0001 -70.0 This is also were View
s come in handy, as it allows to easily target the insertion of channels to specific compartments.
# inserting several channels into parts of the network\nwith net.cell(0) as cell0:\n cell0.insert(Na())\n cell0.insert(K())\n\n# # The above is equivalent to:\n# net.cell(0).insert(Na())\n# net.cell(0).insert(K())\n\n# K and Na channels were only insert into cell 0\nnet.cell(\"all\").branch(0).comp(0).nodes[[\"global_cell_index\", \"Na\", \"K\", \"Leak\"]]\n
global_cell_index Na K Leak 0 0 True True True 12 1 False False True"},{"location":"tutorial/00_jaxley_api/#synapses","title":"Synapses","text":"To connect different cells together, Jaxley implements a connect
method, that can be used to couple 2 compartments together using a Synapse
. Synapses in Jaxley work only on the compartment level, that means to be able to connect two cells, you need to specify the exact compartments on a given cell to make the connections between. Below is an example of this:
# connecting two cells using a Synapse\npre_comp = cell0.branch(1).comp(0)\npost_comp = net.cell(1).branch(0).comp(0)\n\nconnect(pre_comp, post_comp, IonotropicSynapse())\n\nnet.edges\n
global_edge_index global_pre_comp_index global_post_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 4 12 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0 As you can see above, now the edges
dataframe is also updated with the information of the newly added synapse.
Congrats! You should now have an intuitive understand of how to use Jaxley\u2019s API to construct, navigate and manipulate neuron models.
"},{"location":"tutorial/01_morph_neurons/","title":"Basics of Jaxley","text":"In this tutorial, you will learn how to:
Here is a code snippet which you will learn to understand in this tutorial:
import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nimport matplotlib.pyplot as plt\n\n\n# Build the cell.\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])\n\n# Insert channels.\ncell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n\n# Change parameters.\ncell.set(\"axial_resistivity\", 200.0)\n\n# Visualize the morphology.\ncell.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\ncell.vis(ax=ax)\n\n# Stimulate.\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.branch(0).loc(0.0).stimulate(current)\n\n# Record.\ncell.branch(0).loc(0.0).record(\"v\")\n\n# Simulate and plot.\nv = jx.integrate(cell, delta_t=0.025)\nplt.plot(v.T)\n
First, we import the relevant libraries:
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n
We will now build our first cell in Jaxley
. You have two options to do this: you can either build a cell bottom-up by defining the morphology yourselve, or you can load cells from SWC files.
To define a cell from scratch you first have to define a single compartment and then assemble those compartments into a branch:
comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\n
Next, we can assemble branches into a cell. To do so, we have to define for each branch what its parent branch is. A -1
entry means that this branch does not have a parent.
parents = jnp.asarray([-1, 0, 0, 1, 1])\ncell = jx.Cell(branch, parents=parents)\n
To learn more about Compartment
s, Branch
es, and Cell
s, see this tutorial.
Alternatively, you could also load cells from SWC with
cell = jx.read_swc(fname, ncomp=4)
Details on handling SWC files can be found in this tutorial.
"},{"location":"tutorial/01_morph_neurons/#visualize-the-cells","title":"Visualize the cells","text":"Cells can be visualized as follows:
cell.compute_xyz() # Only needed for visualization.\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, col=\"k\")\n
"},{"location":"tutorial/01_morph_neurons/#insert-mechanisms","title":"Insert mechanisms","text":"Currently, the cell does not contain any kind of ion channel (not even a leak
). We can fix this by inserting a leak channel into the entire cell, and by inserting sodium and potassium into the zero-eth branch.
cell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n
Once the cell is created, we can inspect its .nodes
attribute which lists all properties of the cell:
cell.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index ... Na Na_gNa eNa vt Na_m Na_h K K_gK eK K_n 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 ... True 0.05 50.0 -60.0 0.2 0.2 True 0.005 -90.0 0.2 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 ... True 0.05 50.0 -60.0 0.2 0.2 True 0.005 -90.0 0.2 2 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 3 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 4 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 5 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 6 0 3 0 10.0 1.0 5000.0 1.0 -70.0 0 3 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 7 0 3 1 10.0 1.0 5000.0 1.0 -70.0 0 3 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 8 0 4 0 10.0 1.0 5000.0 1.0 -70.0 0 4 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 9 0 4 1 10.0 1.0 5000.0 1.0 -70.0 0 4 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 10 rows \u00d7 25 columns
Note that Jaxley
uses the same units as the NEURON
simulator, which are listed here.
You can also inspect just parts of the cell
, for example its 1st branch:
cell.branch(1).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Leak Leak_gLeak ... Na_m Na_h K K_gK eK K_n global_cell_index global_branch_index global_comp_index controlled_by_param 2 0 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 2 1 3 0 0 1 10.0 1.0 5000.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 3 1 2 rows \u00d7 25 columns
The easiest way to know which branch is the 1st branch (or, e.g., the zero-eth compartment of the 1st branch) is to plot it in a different color:
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, col=\"k\")\n_ = cell.branch(1).vis(ax=ax, col=\"r\")\n_ = cell.branch(1).comp(1).vis(ax=ax, col=\"b\")\n
More background and features on indexing as cell.branch(0)
is in this tutorial.
You can change properties of the cell with the .set()
method:
cell.branch(1).set(\"axial_resistivity\", 200.0)\n
And we can again inspect the .nodes
to make sure that the axial resistivity indeed changed:
cell.branch(1).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Leak Leak_gLeak ... Na_m Na_h K K_gK eK K_n global_cell_index global_branch_index global_comp_index controlled_by_param 2 0 0 0 10.0 1.0 200.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 2 1 3 0 0 1 10.0 1.0 200.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 3 1 2 rows \u00d7 25 columns
In a similar way, you can modify channel properties or initial states (units are again here):
cell.branch(0).set(\"K_gK\", 0.01) # modify potassium conductance.\ncell.set(\"v\", -65.0) # modify initial voltage.\n
"},{"location":"tutorial/01_morph_neurons/#stimulate-the-cell","title":"Stimulate the cell","text":"We next stimulate one of the compartments with a step current. For this, we first define the step current (units are again here):
dt = 0.025\nt_max = 10.0\ntime_vec = np.arange(0, t_max+dt, dt)\ncurrent = jx.step_current(i_delay=1.0, i_dur=2.0, i_amp=0.08, delta_t=dt, t_max=t_max)\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = plt.plot(time_vec, current)\n
We then stimulate one of the compartments of the cell with this step current:
cell.delete_stimuli()\ncell.branch(0).loc(0.0).stimulate(current)\n
Added 1 external_states. See `.externals` for details.\n
"},{"location":"tutorial/01_morph_neurons/#define-recordings","title":"Define recordings","text":"Next, you have to define where to record the voltage. In this case, we will record the voltage at two locations:
cell.delete_recordings()\ncell.branch(0).loc(0.0).record(\"v\")\ncell.branch(3).loc(1.0).record(\"v\")\n
Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
We can again visualize these locations to understand where we inserted recordings:
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax)\n_ = cell.branch(0).loc(0.0).vis(ax=ax, col=\"b\")\n_ = cell.branch(3).loc(1.0).vis(ax=ax, col=\"g\")\n
"},{"location":"tutorial/01_morph_neurons/#simulate-the-cell-response","title":"Simulate the cell response","text":"Having set up the cell, inserted stimuli and recordings, we are now ready to run a simulation with jx.integrate
:
voltages = jx.integrate(cell, delta_t=dt)\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (2, 402)\n
The jx.integrate
function returns an array of shape (num_recordings, num_timepoints)
. In our case, we inserted 2
recordings and we simulated for 10ms at a 0.025 time step, which leads to 402 time steps.
We can now visualize the voltage response:
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(voltages[0], c=\"b\")\n_ = ax.plot(voltages[1], c=\"orange\")\n
At the location of the first recording (in blue) the cell spiked, whereas at the second recording, it did not. This makes sense because we only inserted sodium and potassium channels into the first branch, but not in the entire cell.
Congrats! You have just run your first morphologically detailed neuron simulation in Jaxley
. We suggest to continue by learning how to build networks. If you are only interested in single cell simulations, you can directly jump to learning how to speed up simulations. If you want to simulate detailed morphologies from SWC files, checkout our tutorial on working with detailed morphologies.
In this tutorial, you will learn how to:
.edges
attribute to inspect and change synaptic parametersHere is a code snippet which you will learn to understand in this tutorial:
import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\n\n\n# Define a network. `cell` is defined as in previous tutorial.\nnet = jx.Network([cell for _ in range(11)])\n\n# Define synapses.\nfully_connect(\n net.cell(range(10)),\n net.cell(10),\n IonotropicSynapse(),\n)\n\n# Change synaptic parameters.\nnet.select(edges=[0, 1]).set(\"IonotropicSynapse_gS\", 0.1) # nS\n\n# Visualize the network.\nnet.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\nnet.vis(ax=ax, detail=\"full\", layers=[10, 1]) # or `detail=\"point\"`.\n
In the previous tutorial, you learned how to build single cells with morphological detail, how to insert stimuli and recordings, and how to run a first simulation. In this tutorial, we will define networks of multiple cells and connect them with synapses. Let\u2019s get started:
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect, connect\n
"},{"location":"tutorial/02_small_network/#define-the-network","title":"Define the network","text":"First, we define a cell as you saw in the previous tutorial.
comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n
We can assemble multiple cells into a network by using jx.Network
, which takes a list of jx.Cell
s. Here, we assemble 11 cells into a network:
num_cells = 11\nnet = jx.Network([cell for _ in range(num_cells)])\n
At this point, we can already visualize this network:
net.compute_xyz()\nnet.rotate(180)\nfig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n
Note: you can use move_to
to have more control over the location of cells, e.g.: network.cell(i).move_to(x=0, y=200)
.
As you can see, the neurons are not connected yet. Let\u2019s fix this by connecting neurons with synapses. We will build a network consisting of two layers: 10 neurons in the input layer and 1 neuron in the output layer.
We can use Jaxley
\u2019s fully_connect
method to connect these layers:
pre = net.cell(range(10))\npost = net.cell(10)\nfully_connect(pre, post, IonotropicSynapse())\n
Let\u2019s visualize this again:
fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n
As you can see, the full_connect
method inserted one synapse (in blue) from every neuron in the first layer to the output neuron. The fully_connect
method builds this synapse from the zero-eth compartment and zero-eth branch of the presynaptic neuron onto a random branch of the postsynaptic neuron. If you want more control over the pre- and post-synaptic branches, you can use the connect
method:
pre = net.cell(0).branch(5).loc(1.0)\npost = net.cell(10).branch(0).loc(0.0)\nconnect(pre, post, IonotropicSynapse())\n
fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n
"},{"location":"tutorial/02_small_network/#inspecting-and-changing-synaptic-parameters","title":"Inspecting and changing synaptic parameters","text":"You can inspect synaptic parameters via the .edges
attribute:
net.edges\n
global_edge_index global_pre_comp_index global_post_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 0 286 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 1 1 28 298 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 2 2 56 286 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 3 3 84 295 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 4 4 112 302 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 5 5 140 288 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0 6 6 168 287 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 7 7 196 305 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 8 8 224 299 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 9 9 252 284 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0 10 10 23 280 IonotropicSynapse 0 0.875 0.125 0.0001 0.0 0.025 0.2 0 To modify a parameter of all synapses you can again use .set()
:
net.set(\"IonotropicSynapse_gS\", 0.0003) # nS\n
To modify individual syanptic parameters, use the .select()
method. Below, we change the values of the first two synapses:
net.select(edges=[0, 1]).set(\"IonotropicSynapse_gS\", 0.0004) # nS\n
For more details on how to flexibly set synaptic parameters (e.g., by cell type, or by pre-synaptic cell index,\u2026), see this tutorial.
"},{"location":"tutorial/02_small_network/#stimulating-recording-and-simulating-the-network","title":"Stimulating, recording, and simulating the network","text":"We will now set up a simulation of the network. This works exactly as it does for single neurons:
# Stimulus.\ni_delay = 3.0 # ms\ni_amp = 0.05 # nA\ni_dur = 2.0 # ms\n\n# Duration and step size.\ndt = 0.025 # ms\nt_max = 50.0 # ms\n
time_vec = jnp.arange(0.0, t_max + dt, dt)\n
As a simple example, we insert sodium, potassium, and leak into every compartment of every cell of the network.
net.insert(Na())\nnet.insert(K())\nnet.insert(Leak())\n
We stimulate every neuron in the input layer and record the voltage from the output neuron:
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\nnet.delete_stimuli()\nfor stim_ind in range(10):\n net.cell(stim_ind).branch(0).loc(0.0).stimulate(current)\n\nnet.delete_recordings()\nnet.cell(10).branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
Finally, we can again run the network simulation and plot the result:
s = jx.integrate(net, delta_t=dt)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T)\n
That\u2019s it! You now know how to simulate networks of morphologically detailed neurons. We recommend that you now have a look at how you can speed up your simulation. To learn more about handling synaptic parameters, we recommend to check out this tutorial.
"},{"location":"tutorial/04_jit_and_vmap/","title":"Speeding up simulations","text":"In this tutorial, you will learn how to:
Jaxley
jit
to compile your simulations and make them faster vmap
to parallelize simulations on GPUs Here is a code snippet which you will learn to understand in this tutorial:
from jax import jit, vmap\n\n\ncell = ... # See tutorial on Basics of Jaxley.\n\ndef simulate(params):\n param_state = None\n param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n param_state = cell.data_set(\"K_gK\", params[1], param_state)\n return jx.integrate(cell, param_state=param_state, delta_t=0.025)\n\n# Define 100 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(100, 2))\n\n# Fast for-loops with jit compilation.\njitted_simulate = jit(simulate)\nvoltages = [jitted_simulate(params) for params in all_params]\n\n# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate, in_axes=(0,))\nvoltages = vmapped_simulate(all_params)\n
In the previous tutorials, you learned how to build single cells or networks and how to change their parameters. In this tutorial, you will learn how to speed up such simulations by many orders of magnitude. This can be achieved in to ways:
Let\u2019s get started!
"},{"location":"tutorial/04_jit_and_vmap/#using-gpu-or-cpu","title":"Using GPU or CPU","text":"In Jaxley
you can set whether you want to use gpu
or cpu
with the following lines at the beginning of your script:
from jax import config\nconfig.update(\"jax_platform_name\", \"cpu\")\n
JAX
(and Jaxley
) also allow to choose between float32
and float64
. Especially on GPUs, float32
will be faster, but we have experienced stability issues when simulating morphologically detailed neurons with float32
.
config.update(\"jax_enable_x64\", True) # Set to false to use `float32`.\n
Next, we will import relevant libraries:
import matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit, vmap\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\n
"},{"location":"tutorial/04_jit_and_vmap/#building-the-cell-or-network","title":"Building the cell or network","text":"We first build a cell (or network) in the same way as we showed in the previous tutorials:
dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n\ncell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n\ncell.delete_stimuli()\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)\ncell.branch(0).loc(0.0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
"},{"location":"tutorial/04_jit_and_vmap/#parameter-sweeps","title":"Parameter sweeps","text":"Assume you want to run the same cell with many different values for the sodium and potassium conductance, for example for genetic algorithms or for parameter sweeps. To do this efficiently in Jaxley
, you have to use the data_set()
method (in combination with jit
and vmap
, as shown later):
def simulate(params):\n param_state = None\n param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n param_state = cell.data_set(\"K_gK\", params[1], param_state)\n return jx.integrate(cell, param_state=param_state, delta_t=dt)\n
The .data_set()
method takes three arguments:
1) the name of the parameter you want to set. Jaxley
allows to set the following parameters: \u201cradius\u201d, \u201clength\u201d, \u201caxial_resistivity\u201d, as well as all parameters of channels and synapses. 2) the value of the parameter. 3) a param_state
which is initialized as None
and is modified by .data_set()
. This has to be passed to jx.integrate()
.
Having done this, the simplest (but least efficient) way to perform the parameter sweep is to run a for-loop over many parameter sets:
# Define 5 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(5, 2))\n\nvoltages = jnp.asarray([simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (5, 1, 402)\n
The resulting voltages have shape (num_simulations, num_recordings, num_timesteps)
.
In addition to running sweeps across multiple parameters, you can also run sweeeps across multiple stimuli (e.g. step current stimuli of different amplitudes. You can achieve this with the data_stimulate()
method:
def simulate(i_amp):\n current = jx.step_current(1.0, 1.0, i_amp, 0.025, 10.0)\n\n data_stimuli = None\n data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)\n return jx.integrate(cell, data_stimuli=data_stimuli)\n
"},{"location":"tutorial/04_jit_and_vmap/#speeding-up-for-loops-via-jit-compilation","title":"Speeding up for loops via jit
compilation","text":"We can speed up such parameter sweeps (or stimulus sweeps) with jit
compilation. jit
compilation will compile the simulation when it is run for the first time, such that every other simulation will be must faster. This can be achieved by defining a new function which uses JAX
\u2019s jit()
:
jitted_simulate = jit(simulate)\n
# First run, will be slow.\nvoltages = jitted_simulate(all_params[0])\n
# More runs, will be much faster.\nvoltages = jnp.asarray([jitted_simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (5, 1, 402)\n
jit
compilation can be up to 10k times faster, especially for small simulations with few compartments. For very large models, the gain obtained with jit
will be much smaller (jit
may even provide no speed up at all).
vmap
","text":"Another way to speed up parameter sweeps is with GPU parallelization. Parallelization in Jaxley
can be achieved by using vmap
of JAX
. To do this, we first create a new function that handles multiple parameter sets directly:
# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate)\n
We can then run this method on all parameter sets (all_params.shape == (100, 2)
), and Jaxley
will automatically parallelize across them. Of course, you will only get a speed-up if you have a GPU available and you specified gpu
as device in the beginning of this tutorial.
voltages = vmapped_simulate(all_params)\n
GPU parallelization with vmap
can give a large speed-up, which can easily be 2-3 orders of magnitude.
jit
and vmap
","text":"Finally, you can also combine using jit
and vmap
. For example, you can run multiple batches of many parallel simulations. Each batch can be parallelized with vmap
and simulating each batch can be compiled with jit
:
jitted_vmapped_simulate = jit(vmap(simulate))\n
for batch in range(10):\n all_params = jnp.asarray(np.random.rand(5, 2))\n voltages_batch = jitted_vmapped_simulate(all_params)\n
That\u2019s all you have to know about jit
and vmap
! If you have worked through this and the previous tutorials, you should be ready to set up your first network simulations.
If you want to learn more, we recommend you to read the tutorial on building channel and synapse models.
Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.
Finally, if you want to learn more about JAX, check out their tutorial on jit or their tutorial on vmap.
"},{"location":"tutorial/05_channel_and_synapse_models/","title":"Building ion channel models","text":"In this tutorial, you will learn how to:
Jaxley
This tutorial assumes that you have already learned how to build basic simulations.
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\n
First, we define a cell as you saw in the previous tutorial:
comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n
You have also already learned how to insert preconfigured channels into Jaxley
models:
cell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n
In this tutorial, we will show you how to build your own channel and synapse models.
"},{"location":"tutorial/05_channel_and_synapse_models/#your-own-channel","title":"Your own channel","text":"Below is how you can define your own channel. We will go into detail about individual parts of the code in the next couple of cells.
import jax.numpy as jnp\nfrom jaxley.channels import Channel\nfrom jaxley.solver_gate import solve_gate_exponential\n\n\ndef exp_update_alpha(x, y):\n return x / (jnp.exp(x / y) - 1.0)\n\nclass Potassium(Channel):\n \"\"\"Potassium channel.\"\"\"\n\n def __init__(self, name = None):\n self.current_is_in_mA_per_cm2 = True\n super().__init__(name)\n self.channel_params = {\"gK_new\": 1e-4}\n self.channel_states = {\"n_new\": 0.0}\n self.current_name = \"i_K\"\n\n def update_states(self, states, dt, v, params):\n \"\"\"Update state.\"\"\"\n ns = states[\"n_new\"]\n alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n beta = 0.125 * jnp.exp(-(v + 65) / 80)\n new_n = solve_gate_exponential(ns, dt, alpha, beta)\n return {\"n_new\": new_n}\n\n def compute_current(self, states, v, params):\n \"\"\"Return current.\"\"\"\n ns = states[\"n_new\"]\n kd_conds = params[\"gK_new\"] * ns**4 # S/cm^2\n\n e_kd = -77.0 \n return kd_conds * (v - e_kd)\n\n def init_state(self, states, v, params, delta_t):\n alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n beta = 0.125 * jnp.exp(-(v + 65) / 80)\n return {\"n_new\": alpha / (alpha + beta)}\n
Let\u2019s look at each part of this in detail.
The below is simply a helper function for the solver of the gate variables:
def exp_update_alpha(x, y):\n return x / (jnp.exp(x / y) - 1.0)\n
Next, we define our channel as a class. It should inherit from the Channel
class and define channel_params
, channel_states
, and current_name
. You also need to set self.current_is_in_mA_per_cm2=True
as the first line on your __init__()
method. This is to acknowledge that your current is returned in mA/cm2
(not in uA/cm2
, as would have been required in Jaxley versions 0.4.0 or older).
class Potassium(Channel):\n \"\"\"Potassium channel.\"\"\"\n\n def __init__(self, name=None):\n self.current_is_in_mA_per_cm2 = True\n super().__init__(name)\n self.channel_params = {\"gK_new\": 1e-4}\n self.channel_states = {\"n_new\": 0.0}\n self.current_name = \"i_K\"\n
Next, we have the update_states()
method, which updates the gating variables:
def update_states(self, states, dt, v, params):\n
Every channel you define must have an update_states()
method which takes exactly these five arguments (self, states, dt, v, params). The inputs states
to the update_states
method is a dictionary which contains all states that are updated (including states of other channels). v
is a jnp.ndarray
which contains the voltage of a single compartment (shape ()
). Let\u2019s get the state of the potassium channel which we are building here:
ns = states[\"n_new\"]\n
Next, we update the state of the channel. In this example, we do this with exponential Euler, but you can implement any solver yourself:
alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\nbeta = 0.125 * jnp.exp(-(v + 65) / 80)\nnew_n = solve_gate_exponential(ns, dt, alpha, beta)\nreturn {\"n_new\": new_n}\n
A channel also needs a compute_current()
method which returns the current throught the channel:
def compute_current(self, states, v, params):\n ns = states[\"n_new\"]\n\n # Multiply with 1000 to convert Siemens to milli Siemens.\n kd_conds = params[\"gK_new\"] * ns**4 # S/cm^2\n\n e_kd = -77.0 \n current = kd_conds * (v - e_kd)\n return current\n
Finally, the init_state()
method can be implemented optionally. It can be used to automatically compute the initial state based on the voltage when cell.init_states()
is run.
Alright, done! We can now insert this channel into any jx.Module
such as our cell:
cell.insert(Potassium())\n
cell.delete_stimuli()\ncurrent = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0)\ncell.branch(0).comp(0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).comp(0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
s = jx.integrate(cell)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n
"},{"location":"tutorial/05_channel_and_synapse_models/#your-own-synapse","title":"Your own synapse","text":"The parts below assume that you have already learned how to build network simulations in Jaxley
.
Note that again, a synapse needs to have the two functions update_states
and compute_current
with all input arguments shown below.
The below is an example of how to define your own synapse model in Jaxley
:
import jax.numpy as jnp\nfrom jaxley.synapses.synapse import Synapse\n\n\nclass TestSynapse(Synapse):\n \"\"\"\n Compute syanptic current and update syanpse state.\n \"\"\"\n def __init__(self, name = None):\n super().__init__(name)\n self.synapse_params = {\"gChol\": 0.001, \"eChol\": 0.0}\n self.synapse_states = {\"s_chol\": 0.1}\n\n def update_states(self, states, delta_t, pre_voltage, post_voltage, params):\n \"\"\"Return updated synapse state and current.\"\"\"\n s_inf = 1.0 / (1.0 + jnp.exp((-35.0 - pre_voltage) / 10.0))\n exp_term = jnp.exp(-delta_t)\n new_s = states[\"s_chol\"] * exp_term + s_inf * (1.0 - exp_term)\n return {\"s_chol\": new_s}\n\n def compute_current(self, states, pre_voltage, post_voltage, params):\n g_syn = params[\"gChol\"] * states[\"s_chol\"]\n return g_syn * (post_voltage - params[\"eChol\"])\n
As you can see above, synapses follow closely how channels are defined. The main difference is that the compute_current
method takes two voltages: the pre-synaptic voltage (a jnp.ndarray
of shape ()
) and the post-synaptic voltage (a jnp.ndarray
of shape ()
).
net = jx.Network([cell for _ in range(3)])\n
from jaxley.connect import connect\n\npre = net.cell(0).branch(0).loc(0.0)\npost = net.cell(1).branch(0).loc(0.0)\nconnect(pre, post, TestSynapse())\n
net.cell(0).branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, 0.025, 10.0))\nfor i in range(3):\n net.cell(i).branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
s = jx.integrate(net)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n
That\u2019s it! You are now ready to build your own custom simulations and equip them with channel and synapse models!
This tutorial does not have an immediate follow-up tutorial. If you have not done so already, you can check out our tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.
"},{"location":"tutorial/06_groups/","title":"Defining groups","text":"In this tutorial, you will learn how to:
Jaxley
Here is a code snippet which you will learn to understand in this tutorial:
from jax import jit, vmap\n\n\nnet = ... # See tutorial on Basics of Jaxley.\n\nnet.cell(0).add_to_group(\"fast_spiking\")\nnet.cell(1).add_to_group(\"slow_spiking\")\n\ndef simulate(params):\n param_state = None\n param_state = net.fast_spiking.data_set(\"HH_gNa\", params[0], param_state)\n param_state = net.slow_spiking.data_set(\"HH_gNa\", params[1], param_state)\n return jx.integrate(net, param_state=param_state)\n\n# Define sodium for fast and slow spiking neurons.\nparams = jnp.asarray([1.0, 0.1])\n\n# Run simulation.\nvoltages = simulate(params)\n
In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport time\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n
First, we define a network as you saw in the previous tutorial:
comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1])\nnetwork = jx.Network([cell for _ in range(3)])\n\npre = network.cell([0, 1])\npost = network.cell([2])\nfully_connect(pre, post, IonotropicSynapse())\n\nnetwork.insert(Na())\nnetwork.insert(K())\nnetwork.insert(Leak())\n
"},{"location":"tutorial/06_groups/#group-apical-dendrites","title":"Group: apical dendrites","text":"Assume that, in each of the five neurons in this network, the second and forth branch are apical dendrites. We can define this as:
for cell_ind in range(3):\n network.cell(cell_ind).branch(1).add_to_group(\"apical\")\n network.cell(cell_ind).branch(3).add_to_group(\"apical\")\n
After this, we can access network.apical
as we previously accesses anything else:
network.apical.set(\"radius\", 0.3)\n
network.apical.view\n
View with 3 different channels. Use `.nodes` for details.\n
"},{"location":"tutorial/06_groups/#group-fast-spiking","title":"Group: fast spiking","text":"Similarly, you could define a group of fast-spiking cells. Assume that the first and second cell are fast-spiking:
network.cell(0).add_to_group(\"fast_spiking\")\nnetwork.cell(1).add_to_group(\"fast_spiking\")\n
network.fast_spiking.set(\"Na_gNa\", 0.4)\n
network.fast_spiking.view\n
View with 3 different channels. Use `.nodes` for details.\n
"},{"location":"tutorial/06_groups/#groups-from-swc-files","title":"Groups from SWC files","text":"If you are reading .swc
morphologigies, you can automatically assign groups with
jx.read_swc(file_name, nseg=n, assign_groups=True).\n
After that, you can directly use cell.soma
, cell.apical
, cell.basal
, or cell.axon
."},{"location":"tutorial/06_groups/#how-groups-are-interpreted-by-make_trainable","title":"How groups are interpreted by .make_trainable()
","text":"If you make a parameter of a group
trainable, then it will be treated as a single shared parameter for a given property:
network.fast_spiking.make_trainable(\"Na_gNa\")\n
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n
As such, get_parameters()
returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:
network.get_parameters()\n
[{'Na_gNa': Array([0.4], dtype=float64)}]\n
If, instead, you would want a separate parameter for every fast-spiking cell, you should not use the group, but instead do the following (remember that fast-spiking neurons had indices [0,1]):
network.cell([0,1]).make_trainable(\"axial_resistivity\")\n
Number of newly added trainable parameters: 2. Total number of trainable parameters: 3\n
network.get_parameters()\n
[{'Na_gNa': Array([0.4], dtype=float64)},\n {'axial_resistivity': Array([5000., 5000.], dtype=float64)}]\n
This generated two parameters for the axial resistivitiy, each corresponding to one cell.
"},{"location":"tutorial/06_groups/#summary","title":"Summary","text":"Groups allow you to organize your simulation in a more intuitive way, and they allow to perform parameter sharing with make_trainable()
.
In this tutorial, you will learn how to train biophysical models in Jaxley
. This includes the following:
Here is a code snippet which you will learn to understand in this tutorial:
from jax import jit, vmap, value_and_grad\nimport jaxley as jx\nimport jaxley.optimize.transforms as jt\n\nnet = ... # See tutorial on the basics of `Jaxley`.\n\n# Define which parameters to train.\nnet.cell(\"all\").make_trainable(\"HH_gNa\")\nnet.IonotropicSynapse.make_trainable(\"IonotropicSynapse_gS\")\nparameters = net.get_parameters()\n\n# Define parameter transform and apply it to the parameters.\ntransform = jx.ParamTransform([\n {\"IonotropicSynapse_gS\": jt.SigmoidTransform(0.0, 1.0)},\n {\"HH_gNa\":jt.SigmoidTransform(0.0, 1, 0)}\n])\n\nopt_params = transform.inverse(parameters)\n\n# Define simulation and batch it across stimuli.\ndef simulate(params, datapoint):\n current = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amps=datapoint, dt=0.025, t_max=5.0)\n data_stimuli = net.cell(0).branch(0).comp(0).data_stimulate(current, None)\n return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_inds=[20, 20], delta_t=0.025)\n\nbatch_simulate = vmap(simulate, in_axes=(None, 0))\n\n# Define loss function and its gradient.\ndef loss_fn(opt_params, datapoints, label):\n params = transform.forward(opt_params)\n voltages = batch_simulate(params, datapoints)\n return jnp.abs(jnp.mean(voltages) - label)\n\ngrad_fn = jit(value_and_grad(loss_fn, argnums=0))\n\n# Define data and dataloader.\ndata = jnp.asarray(np.random.randn(100, 3))\ndataloader = Dataset.from_tensor_slices((inputs, labels))\ndataloader = dataloader.shuffle(dataloader.cardinality()).batch(4)\n\n# Define the optimizer.\noptimizer = optax.Adam(lr=0.01)\nopt_state = optimizer.init_state(opt_params)\n\nfor epoch in range(10):\n for batch in dataloader:\n stimuli = batch[0].numpy()\n labels = batch[1].numpy()\n loss, gradient = grad_fn(opt_params, stimuli, labels)\n\n # Optimizer step.\n updates, opt_state = optimizer.update(gradient, opt_state)\n opt_params = optax.apply_updates(opt_params, updates)\n
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, vmap, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Leak\nfrom jaxley.synapses import TanhRateSynapse\nfrom jaxley.connect import fully_connect\n
First, we define a network as you saw in the previous tutorial:
_ = np.random.seed(0) # For synaptic locations.\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell for _ in range(3)])\n\npre = net.cell([0, 1])\npost = net.cell([2])\nfully_connect(pre, post, TanhRateSynapse())\n\n# Change some default values of the tanh synapse.\nnet.TanhRateSynapse.set(\"TanhRateSynapse_x_offset\", -60.0)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_gS\", 1e-3)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_slope\", 0.1)\n\nnet.insert(Leak())\n
This network consists of three neurons arranged in two layers:
net.compute_xyz()\nnet.rotate(180)\nfig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = net.vis(ax=ax, detail=\"full\", layers=[2, 1], layer_kwargs={\"within_layer_offset\": 100.0, \"between_layer_offset\": 100.0}) \n
We consider the last neuron as the output neuron and record the voltage from there:
net.delete_recordings()\nnet.cell(0).branch(0).loc(0.0).record()\nnet.cell(1).branch(0).loc(0.0).record()\nnet.cell(2).branch(0).loc(0.0).record()\n
Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
"},{"location":"tutorial/07_gradient_descent/#defining-a-dataset","title":"Defining a dataset","text":"We will train this biophysical network on a classification task. The inputs will be values and the label is binary:
inputs = jnp.asarray(np.random.rand(100, 2))\nlabels = jnp.asarray((inputs[:, 0] + inputs[:, 1]) > 1.0)\n
fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(inputs[labels, 0], inputs[labels, 1])\n_ = ax.scatter(inputs[~labels, 0], inputs[~labels, 1])\n
labels = labels.astype(float)\n
"},{"location":"tutorial/07_gradient_descent/#defining-trainable-parameters","title":"Defining trainable parameters","text":"net.delete_trainables()\n
This follows the same API as .set()
seen in the previous tutorial. If you want to use a single parameter for all radius
es in the entire network, do:
net.make_trainable(\"radius\")\n
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n
We can also define parameters for individual compartments. To do this, use the \"all\"
key. The following defines a separate parameter the sodium conductance for every compartment in the entire network:
net.cell(\"all\").branch(\"all\").loc(\"all\").make_trainable(\"Leak_gLeak\")\n
Number of newly added trainable parameters: 18. Total number of trainable parameters: 19\n
"},{"location":"tutorial/07_gradient_descent/#making-synaptic-parameters-trainable","title":"Making synaptic parameters trainable","text":"Synaptic parameters can be made trainable in the exact same way. To use a single parameter for all syanptic conductances in the entire network, do
net.TanhRateSynapse.make_trainable(\"TanhRateSynapse_gS\")\n
Here, we use a different syanptic conductance for all syanpses. This can be done as follows:
net.TanhRateSynapse.edge(\"all\").make_trainable(\"TanhRateSynapse_gS\")\n
Number of newly added trainable parameters: 2. Total number of trainable parameters: 21\n
"},{"location":"tutorial/07_gradient_descent/#running-the-simulation","title":"Running the simulation","text":"Once all parameters are defined, you have to use .get_parameters()
to obtain all trainable parameters. This is also the time to check how many trainable parameters your network has:
params = net.get_parameters()\n
You can now run the simulation with the trainable parameters by passing them to the jx.integrate
function.
s = jx.integrate(net, params=params, t_max=10.0)\n
"},{"location":"tutorial/07_gradient_descent/#stimulating-the-network","title":"Stimulating the network","text":"The network above does not yet get any stimuli. We will use the 2D inputs from the dataset to stimulate the two input neurons. The amplitude of the step current corresponds to the input value. Below is the simulator that defines this:
def simulate(params, inputs):\n currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10, delta_t=0.025, t_max=10.0)\n\n data_stimuli = None\n data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n return jx.integrate(net, params=params, data_stimuli=data_stimuli, delta_t=0.025)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n
We can also inspect some traces:
traces = batched_simulate(params, inputs[:4])\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(traces[:, 2, :].T)\n
"},{"location":"tutorial/07_gradient_descent/#defining-a-loss-function","title":"Defining a loss function","text":"Let us define a loss function to be optimized:
def loss(params, inputs, labels):\n traces = batched_simulate(params, inputs) # Shape `(batchsize, num_recordings, timepoints)`.\n prediction = jnp.mean(traces[:, 2], axis=1) # Use the average over time of the output neuron (2) as prediction.\n prediction = (prediction + 72.0) / 5 # Such that the prediction is roughly in [0, 1].\n losses = jnp.abs(prediction - labels) # Mean absolute error loss.\n return jnp.mean(losses) # Average across the batch.\n
And we can use JAX
\u2019s inbuilt functions to take the gradient through the entire ODE:
jitted_grad = jit(value_and_grad(loss, argnums=0))\n
value, gradient = jitted_grad(params, inputs[:4], labels[:4])\n
"},{"location":"tutorial/07_gradient_descent/#defining-parameter-transformations","title":"Defining parameter transformations","text":"Before training, however, we will enforce for all parameters to be within a prespecified range (such that, e.g., conductances can not become negative)
import jaxley.optimize.transforms as jt\n
# Define a function to create appropriate transforms for each parameter\ndef create_transform(name):\n if name == \"axial_resistivity\":\n # Must be positive; apply Softplus and scale to match initialization\n return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(5000, 0)])\n elif name == \"length\":\n # Apply Softplus and affine transform for the 'length' parameter\n return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(10, 0)])\n else:\n # Default to a Softplus transform for other parameters\n return jt.SoftplusTransform(0)\n\n# Apply the transforms to the parameters\ntransforms = [{k: create_transform(k) for k in param} for param in params]\ntf = jt.ParamTransform(transforms)\n
transform = jx.ParamTransform([{\"radius\": jt.SigmoidTransform(0.1, 5.0)},\n {\"Leak_gLeak\":jt.SigmoidTransform(1e-5, 1e-3)},\n {\"TanhRateSynapse_gS\" : jt.SigmoidTransform(1e-5, 1e-2)}])\n
With these modify the loss function acocrdingly:
def loss(opt_params, inputs, labels):\n transform.forward(opt_params)\n\n traces = batched_simulate(params, inputs) # Shape `(batchsize, num_recordings, timepoints)`.\n prediction = jnp.mean(traces[:, 2], axis=1) # Use the average over time of the output neuron (2) as prediction.\n prediction = (prediction + 72.0) # Such that the prediction is around 0.\n losses = jnp.abs(prediction - labels) # Mean absolute error loss.\n return jnp.mean(losses) # Average across the batch.\n
"},{"location":"tutorial/07_gradient_descent/#using-checkpointing","title":"Using checkpointing","text":"Checkpointing allows to vastly reduce the memory requirements of training biophysical models (see also JAX\u2019s full tutorial on checkpointing).
t_max = 5.0\ndt = 0.025\n\nlevels = 2\ntime_points = t_max // dt + 2\ncheckpoints = [int(np.ceil(time_points**(1/levels))) for _ in range(levels)]\n
To enable checkpointing, we have to modify the simulate
function appropriately and use
jx.integrate(..., checkpoint_inds=checkpoints)\n
as done below: def simulate(params, inputs):\n currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10.0, delta_t=dt, t_max=t_max)\n\n data_stimuli = None\n data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_lengths=checkpoints)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n\n\ndef predict(params, inputs):\n traces = simulate(params, inputs) # Shape `(batchsize, num_recordings, timepoints)`.\n prediction = jnp.mean(traces[2]) # Use the average over time of the output neuron (2) as prediction.\n return prediction + 72.0 # Such that the prediction is around 0.\n\nbatched_predict = vmap(predict, in_axes=(None, 0))\n\n\ndef loss(opt_params, inputs, labels):\n params = transform.forward(opt_params)\n\n predictions = batched_predict(params, inputs)\n losses = jnp.abs(predictions - labels) # Mean absolute error loss.\n return jnp.mean(losses) # Average across the batch.\n\njitted_grad = jit(value_and_grad(loss, argnums=0))\n
"},{"location":"tutorial/07_gradient_descent/#training","title":"Training","text":"We will use the ADAM optimizer from the optax library to optimize the free parameters (you have to install the package with pip install optax
first):
import optax\n
opt_params = transform.inverse(params)\noptimizer = optax.adam(learning_rate=0.01)\nopt_state = optimizer.init(opt_params)\n
"},{"location":"tutorial/07_gradient_descent/#writing-a-dataloader","title":"Writing a dataloader","text":"Below, we just write our own (very simple) dataloader. Alternatively, you could use the dataloader from any deep learning library such as pytorch or tensorflow:
class Dataset:\n def __init__(self, inputs: np.ndarray, labels: np.ndarray):\n \"\"\"Simple Dataloader.\n\n Args:\n inputs: Array of shape (num_samples, num_dim)\n labels: Array of shape (num_samples,)\n \"\"\"\n assert len(inputs) == len(labels), \"Inputs and labels must have same length\"\n self.inputs = inputs\n self.labels = labels\n self.num_samples = len(inputs)\n self._rng_state = None\n self.batch_size = 1\n\n def shuffle(self, seed=None):\n \"\"\"Shuffle the dataset in-place\"\"\"\n self._rng_state = np.random.get_state()[1][0] if seed is None else seed\n np.random.seed(self._rng_state)\n indices = np.random.permutation(self.num_samples)\n self.inputs = self.inputs[indices]\n self.labels = self.labels[indices]\n return self\n\n def batch(self, batch_size):\n \"\"\"Create batches of the data\"\"\"\n self.batch_size = batch_size\n return self\n\n def __iter__(self):\n self.shuffle(seed=self._rng_state)\n for start in range(0, self.num_samples, self.batch_size):\n end = min(start + self.batch_size, self.num_samples)\n yield self.inputs[start:end], self.labels[start:end]\n self._rng_state += 1\n
"},{"location":"tutorial/07_gradient_descent/#training-loop","title":"Training loop","text":"batch_size = 4\ndataloader = Dataset(inputs, labels)\ndataloader = dataloader.shuffle(seed=0).batch(batch_size)\n\nfor epoch in range(10):\n epoch_loss = 0.0\n\n for batch_ind, batch in enumerate(dataloader):\n current_batch, label_batch = batch\n loss_val, gradient = jitted_grad(opt_params, current_batch, label_batch)\n updates, opt_state = optimizer.update(gradient, opt_state)\n opt_params = optax.apply_updates(opt_params, updates)\n epoch_loss += loss_val\n\n print(f\"epoch {epoch}, loss {epoch_loss}\")\n\nfinal_params = transform.forward(opt_params)\n
epoch 0, loss 25.033223182772293\nepoch 1, loss 21.00894915349165\nepoch 2, loss 15.092242959956026\nepoch 3, loss 9.061544660383163\nepoch 4, loss 6.925509860325612\nepoch 5, loss 6.273630037897756\nepoch 6, loss 6.1757316054693145\nepoch 7, loss 6.135132525725265\nepoch 8, loss 6.145608619185389\nepoch 9, loss 6.135660902068834\n
ntest = 32\npredictions = batched_predict(final_params, inputs[:ntest])\n
fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(labels[:ntest], predictions)\n_ = ax.set_xlabel(\"Label\")\n_ = ax.set_ylabel(\"Prediction\")\n
Indeed, the loss goes down and the network successfully classifies the patterns.
"},{"location":"tutorial/07_gradient_descent/#summary","title":"Summary","text":"Puh, this was a pretty dense tutorial with a lot of material. You should have learned how to:
This was the last \u201cbasic\u201d tutorial of the Jaxley
toolbox. If you want to learn more, check out our Advanced Tutorials. If anything is still unclear please create a discussion. If you find any bugs, please open an issue. Happy coding!
In this tutorial, you will learn how to:
Jaxley
Here is a code snippet which you will learn to understand in this tutorial:
import jaxley as jx\n\ncell = jx.read_swc(\"my_cell.swc\", ncomp=4)\ncell.branch(2).set_ncomp(2) # Modify the number of compartments of a branch.\n
To work with more complicated morphologies, Jaxley
supports importing morphological reconstructions via .swc
files. .swc
is currently the only supported format. Other formats like .asc
need to be converted to .swc
first, for example using the BlueBrain\u2019s morph-tool. For more information on the exact specifications of .swc
see here.
import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nimport matplotlib.pyplot as plt\n
To work with .swc
files, Jaxley
implements a custom .swc
reader. The reader traces the morphology and identifies all uninterrupted sections. These are then partitioned into branches, each of which will be approximated by a number of equally many compartments that can be simulated fully in parallel.
To demonstrate this, let\u2019s import an example morphology of a Layer 5 pyramidal cell and visualize it.
# import swc file into jx.Cell object\nfname = \"data/morph.swc\"\ncell = jx.read_swc(fname, ncomp=8) # Use four compartments per branch.\n\n# print shape (num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
(157, 1256)\n
local_comp_index global_comp_index local_branch_index global_branch_index local_cell_index global_cell_index 0 0 0 0 0 0 0 1 1 1 0 0 0 0 2 2 2 0 0 0 0 3 3 3 0 0 0 0 4 4 4 0 0 0 0 ... ... ... ... ... ... ... 1251 3 1251 156 156 0 0 1252 4 1252 156 156 0 0 1253 5 1253 156 156 0 0 1254 6 1254 156 156 0 0 1255 7 1255 156 156 0 0 1256 rows \u00d7 6 columns
As we can see, this yields a morphology that is approximated by 1256 compartments. Depending on the amount of detail that you need, you can also change the number of compartments in each branch:
cell = jx.read_swc(fname, ncomp=2)\n\n# print shape (num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
(157, 314)\n
local_comp_index global_comp_index local_branch_index global_branch_index local_cell_index global_cell_index 0 0 0 0 0 0 0 1 1 1 0 0 0 0 2 0 2 1 1 0 0 3 1 3 1 1 0 0 4 0 4 2 2 0 0 ... ... ... ... ... ... ... 309 1 309 154 154 0 0 310 0 310 155 155 0 0 311 1 311 155 155 0 0 312 0 312 156 156 0 0 313 1 313 156 156 0 0 314 rows \u00d7 6 columns
The above assigns the same number of compartments to every branch. To use a different number of compartments in individual branches, you can use .set_ncomp()
:
cell.branch(1).set_ncomp(4)\n
As you can see below, branch 0
has two compartments (because this is what was passed to jx.read_swc(..., ncomp=2)
), but branch 1
has four compartments:
cell.branch([0, 1]).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 0.050000 8.119000 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 0.050000 8.119000 5000.0 1.0 -70.0 0 0 1 0 2 0 1 0 3.120779 7.806172 5000.0 1.0 -70.0 0 1 2 1 3 0 1 1 3.120779 7.111231 5000.0 1.0 -70.0 0 1 3 1 4 0 1 2 3.120779 5.652394 5000.0 1.0 -70.0 0 1 4 1 5 0 1 3 3.120779 3.869247 5000.0 1.0 -70.0 0 1 5 1 Once imported the compartmentalized morphology can be viewed using vis
.
# visualize the cell\ncell.vis()\nplt.axis(\"off\")\nplt.title(\"L5PC\")\nplt.show()\n
vis
can be called on any jx.Module
and every View
of the module. This means we can also for example use vis
to highlight each branch. This can be done by iterating over each branch index and calling cell.branch(i).vis()
. Within the loop.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n# define colorwheel with 10 colors\ncolors = plt.cm.tab10.colors\nfor i, branch in enumerate(cell.branches):\n branch.vis(ax=ax, col=colors[i % 10])\nplt.axis(\"off\")\nplt.title(\"Branches\")\nplt.show()\n
While we only use two compartments to approximate each branch in this example, we can see the morphology is still plotted in great detail. This is because we always plot the full .swc
reconstruction irrespective of the number of compartments used. The morphology lives seperately in the cell.xyzr
attribute in a per branch fashion.
In addition to plotting the full morphology of the cell using points vis(type=\"scatter\")
or lines vis(type=\"line\")
, Jaxley
also supports plotting a detailed morphological vis(type=\"morph\")
or approximate compartmental reconstruction vis(type=\"comp\")
that correctly considers the thickness of the neurite. Note that \"comp\"
plots the lengths of each compartment which is equal to the length of the traced neurite. While neurites can be zigzaggy, the compartments that approximate them are straight lines. This can lead to miss-aligment of the compartment ends. For details see the documentation of vis
.
The morphologies can either be projected onto 2D or also rendered in 3D.
# visualize the cell\nfig, ax = plt.subplots(1, 4, figsize=(10, 3), layout=\"constrained\", sharex=True, sharey=True)\ncell.vis(ax=ax[0], type=\"morph\", dims=[0,1])\ncell.vis(ax=ax[1], type=\"comp\", dims=[0,1])\ncell.vis(ax=ax[2], type=\"scatter\", dims=[0,1], morph_plot_kwargs={\"s\": 1})\ncell.vis(ax=ax[3], type=\"line\", dims=[0,1])\nfig.suptitle(\"Comparison of plot types\")\nplt.show()\n
# set to interactive mode\n# %matplotlib notebook\n
# plot in 3D\nfig = plt.figure()\nax = fig.add_subplot(111, projection='3d')\ncell.vis(ax=ax, type=\"line\", dims=[2,0,1])\nax.view_init(elev=20, azim=5)\nplt.show()\n
Since Jaxley
supports grouping different branches or compartments together, we can also use the id
labels provided by the .swc
file to assign group labels to the jx.Cell
object.
print(list(cell.groups.keys()))\n\nfig, ax = plt.subplots(1, 1, figsize=(5, 5))\ncolors = plt.cm.tab10.colors\ncell.basal.vis(ax=ax, col=colors[2])\ncell.soma.vis(ax=ax, col=colors[1])\ncell.apical.vis(ax=ax, col=colors[0])\nplt.axis(\"off\")\nplt.title(\"Groups\")\nplt.show()\n
['soma', 'basal', 'apical', 'custom']\n
To build a network of morphologically detailed cells, we can now connect several reconstructed cells together and also visualize the network. However, since all cells are going to have the same center, Jaxley
will naively plot all of them on top of each other. To seperate out the cells, we therefore have to move them to a new location first.
net = jx.Network([cell]*5)\njx.connect(net[0,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[4,0,0], IonotropicSynapse())\n\njx.connect(net[1,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[4,0,0], IonotropicSynapse())\n\nnet.rotate(-90)\n\nnet.cell(0).move(0, 300)\nnet.cell(1).move(0, 500)\n\nnet.cell(2).move(900, 200)\nnet.cell(3).move(900, 400)\nnet.cell(4).move(900, 600)\n\nnet.vis()\nplt.axis(\"off\")\nplt.show()\n
Congrats! You have now learned how to vizualize and build networks out of very complex morphologies. To simulate this network, you can follow the steps in the tutorial on how to build a network.
"},{"location":"tutorial/09_advanced_indexing/","title":"Customizing synaptic parameters","text":"In this tutorial, you will learn how to:
select()
method to fully customize network simulations with Jaxley
. copy_node_property_to_edges()
method to flexibly modify synapses. Here is a code snippet which you will learn to understand in this tutorial:
net = ... # See tutorial on Basics of Jaxley.\n\n# Set synaptic conductance of the synapse with index 0 and 1.\nnet.select(edges=[0, 1]).set(\"Ionotropic_gS\", 0.1)\n\n# Set synaptic conductance of all synapses that have cells 3 or 4 as presynaptic neuron.\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [3, 4]\")\nnet.select(edges=df.index).set(\"Ionotropic_gS\", 0.2)\n\n# Set synaptic conductance of all synapses that\n# 1) have cells 2 or 3 as presynaptic neuron and\n# 2) has cell 5 as postsynaptic neuron\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [2, 3]\")\ndf = df.query(\"post_global_cell_index == 5\")\nnet.select(edges=df.index).set(\"Ionotropic_gS\", 0.3)\n
In a previous tutorial you learned how to set parameters of a jx.Network
. In that tutorial, we briefly mentioned the select()
method which allowed to set individual synapses to particular values. In this tutorial, we will go into detail in how you can fully customize your Jaxley
simulation.
Let\u2019s go!
import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.connect import fully_connect\nfrom jaxley.synapses import IonotropicSynapse\n
"},{"location":"tutorial/09_advanced_indexing/#preface-building-the-network","title":"Preface: Building the network","text":"We first build a network consisting of six neurons, in the same way as we showed in the previous tutorials:
dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0])\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
"},{"location":"tutorial/09_advanced_indexing/#setting-individual-synapse-parameters","title":"Setting individual synapse parameters","text":"As always, you can use the .edges
table to inspect synaptic parameters of the network:
net.edges\n
global_edge_index pre_global_comp_index post_global_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 0 13 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 1 1 0 19 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 2 2 0 20 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 3 3 4 12 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 4 4 4 16 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 5 5 4 21 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 6 6 8 13 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 7 7 8 17 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 8 8 8 21 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 This table has nine rows, each corresponding to one synapse. This makes sense because we fully connected three neurons (0, 1, 2) to three other neurons (3, 4, 5), giving a total of 3x3=9
synapses.
You can modify parameters of individual synapses as follows:
net.select(edges=[3, 4, 5]).set(\"IonotropicSynapse_gS\", 0.2)\n
Above, we are modifying the synapses with indices [3, 4, 5]
(i.e., the indices of the net.edges
DataFrame). The resulting values are indeed changed:
net.edges.IonotropicSynapse_gS\n
0 0.0001\n1 0.0001\n2 0.0001\n3 0.2000\n4 0.2000\n5 0.2000\n6 0.0001\n7 0.0001\n8 0.0001\nName: IonotropicSynapse_gS, dtype: float64\n
"},{"location":"tutorial/09_advanced_indexing/#example-1-setting-synaptic-parameters-which-connect-particular-neurons","title":"Example 1: Setting synaptic parameters which connect particular neurons","text":"This is great, but setting synaptic parameters just by their index can be exhausting, in particular in very large networks. Instead, we would want to, for example, set the maximal conductance of all synapses that connect from cell 0 or 1 to any other neuron.
In Jaxley
, such customization can be achieved by filtering the .edges
dataframe accordingly, as shown below:
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1]\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.23)\n
net.edges.IonotropicSynapse_gS\n
0 0.2300\n1 0.2300\n2 0.2300\n3 0.2300\n4 0.2300\n5 0.2300\n6 0.0001\n7 0.0001\n8 0.0001\nName: IonotropicSynapse_gS, dtype: float64\n
Indeed, the first six synapses now have the value 0.23
! Let\u2019s look at the individual lines to understand how this worked:
We want to set parameter by cell index. However, by default, the pre- or post-synaptic cell-indices are not listed in net.edges
. We can add the cell index to the .edges
dataframe by calling .copy_node_property_to_edges()
:
net.copy_node_property_to_edges(\"global_cell_index\")\n
After this, the pre- and post-synaptic cell indices are listed in net.edges
as pre_global_cell_index
and post_global_cell_index
.
Next, we take .edges
, which is a pandas DataFrame:
df = net.edges\n
We then modify this DataFrame to only contain those rows where the global cell index is in 0 or 1:
df = df.query(\"pre_global_cell_index in [0, 1]\")\n
For the above step, you use any column of the DataFrame to filter it (you can see all columns with df.columns
). Note that, while we used .query()
here, you can really filter the pandas DataFrame however you want. For example, the query
above is identical to df = df[df[\"pre_global_cell_index\"].isin([0, 1])]
.
Finally, we use the .select()
method, which returns a subset of the Network
at the specified indices. This subset of the network can be modified with .set()
:
net.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.23)\n
"},{"location":"tutorial/09_advanced_indexing/#example-2-setting-parameters-given-pre-and-post-synaptic-cell-indices","title":"Example 2: Setting parameters given pre- and post-synaptic cell indices","text":"Say you want to select all synapses that have cells 1 or 2 as presynaptic neuron and cell 4 or 5 as postsynaptic neuron.
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
Just like before, we can simply use .query()
as already shown above. However, this time, call .query()
to twice to filter by pre- and post-synaptic cell indices:
net.copy_node_property_to_edges(\"global_cell_index\")\n\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [1, 2]\")\ndf = df.query(\"post_global_cell_index in [4, 5]\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.3)\n
net.edges.IonotropicSynapse_gS\n
0 0.0001\n1 0.0001\n2 0.0001\n3 0.0001\n4 0.3000\n5 0.3000\n6 0.0001\n7 0.3000\n8 0.3000\nName: IonotropicSynapse_gS, dtype: float64\n
"},{"location":"tutorial/09_advanced_indexing/#example-3-applying-this-strategy-to-cell-level-parameters","title":"Example 3: Applying this strategy to cell level parameters","text":"You had previously seen that you can modify parameters with, e.g., net.cell(0).set(...)
. However, if you need more flexibility than this, you can also use the above strategy to modify cell-level parameters:
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\ndf = net.nodes\ndf = df.query(\"global_cell_index in [0, 1]\")\nnet.select(nodes=df.index).set(\"radius\", 0.1)\n
"},{"location":"tutorial/09_advanced_indexing/#example-4-flexibly-setting-parameters-based-on-their-groups","title":"Example 4: Flexibly setting parameters based on their groups
","text":"If you are using groups, as shown in this tutorial, then you can also use this for querying synapses. To demonstrate this, let\u2019s create a group of excitatory neurons (e.g., cells 0, 3, 5):
# Redefine network.\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.cell([0, 3, 5]).add_to_group(\"exc\")\n
Now, say we want all synapses that start from these excitatory neurons. You can do this as follows:
# First, we have to identify which cells are in the `exc` group.\nindices_of_excitatory_cells = net.exc.nodes[\"global_cell_index\"].unique().tolist() # [0, 3, 5]\n\n# Then we can proceed as before:\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(f\"pre_global_cell_index in {indices_of_excitatory_cells}\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.4)\n
"},{"location":"tutorial/09_advanced_indexing/#example-5-setting-synaptic-parameters-based-on-properties-of-the-presynaptic-cell","title":"Example 5: Setting synaptic parameters based on properties of the presynaptic cell","text":"Let\u2019s discuss one more example: Imagine we only want to modify those synapses whose presynaptic compartment has a sodium channel. Let\u2019s first add a sodium channel to some of the cells:
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.cell(0).branch(0).comp(0).insert(Na())\nnet.cell(2).branch(1).comp(1).insert(Na())\n
Now, let us query which cells have the desired synapses:
df = net.nodes\ndf = df.query(\"Na\")\nindices_of_sodium_compartments = df[\"global_comp_index\"].unique().tolist()\n
indices_of_sodium_compartments
lists all compartments which contained sodium:
print(indices_of_sodium_compartments)\n
[0, 11]\n
Then, we can proceed as always and filter for the global pre-synaptic compartment index:
df = net.edges\ndf = df.query(f\"pre_global_comp_index in {indices_of_sodium_compartments}\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.6)\n
net.edges.IonotropicSynapse_gS\n
0 0.6000\n1 0.6000\n2 0.6000\n3 0.0001\n4 0.0001\n5 0.0001\n6 0.0001\n7 0.0001\n8 0.0001\nName: IonotropicSynapse_gS, dtype: float64\n
Indeed, only synapses coming from the first neuron were modified (as its presynaptic compartment contained sodium), in contrast to synapses from neuron 2 (whose presynaptic compartment did not).
"},{"location":"tutorial/09_advanced_indexing/#summary","title":"Summary","text":"In this tutorial, you learned how to fully customize your Jaxley
simulation. This works by querying rows from the .edges
DataFrame.
In this tutorial, you will learn how to:
Here is a code snippet which you will learn to understand in this tutorial:
net = ... # See tutorial on Basics of Jaxley.\n\n# The same parameter for all synapses\nnet.make_trainable(\"Ionotropic_gS\")\n\n# An individual parameter for every synapse.\nnet.select(edges=\"all\").make_trainable(\"Ionotropic_gS\")\n\n# Share synaptic conductances emerging from the same neurons.\nnet.copy_node_property_to_edges(\"cell_index\")\nsub_net = net.select(edges=[0, 1, 2])\nsub_net.edges[\"controlled_by_param\"] = sub_net.edges[\"pre_global_cell_index\"]\nsub_net.make_trainable(\"Ionotropic_gS\")\n
In a previous tutorial about training networks, we briefly touched on parameter sharing. In this tutorial, we will show you how you can flexibly share parameters within a network.
import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.connect import fully_connect\nfrom jaxley.synapses import IonotropicSynapse\n
"},{"location":"tutorial/10_advanced_parameter_sharing/#preface-building-the-network","title":"Preface: Building the network","text":"We first build a network consisting of six neurons, in the same way as we showed in the previous tutorials:
dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0])\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
"},{"location":"tutorial/10_advanced_parameter_sharing/#sharing-parameters-by-modifying-controlled_by_param","title":"Sharing parameters by modifying controlled_by_param
","text":"net.copy_node_property_to_edges(\"global_cell_index\")\n\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1, 2]\")\nsubnetwork = net.select(edges=df.index)\n\ndf = subnetwork.edges\ndf[\"controlled_by_param\"] = df[\"pre_global_cell_index\"]\nsubnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 3. Total number of trainable parameters: 3\n
Let\u2019s look at this line by line. First, we exactly follow the previous tutorial in selecting the synapses which we are interested in training (i.e., the ones whose presynaptic neuron has index 0, 1, 2):
df = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1, 2]\")\nsubnetwork = net.select(edges=df.index)\n
As second step, we enable parameter sharing. This is done by setting the controlled_by_param
. Synapses that have the same value in controlled_by_param
will be shared. Let\u2019s inspect controlled_by_param
before we modify it:
subnetwork.edges[[\"pre_global_cell_index\", \"controlled_by_param\"]]\n
pre_global_cell_index controlled_by_param 0 0 0 1 0 1 2 0 2 3 1 3 4 1 4 5 1 5 6 2 6 7 2 7 8 2 8 Every synapse has a different value. Because of this, no synaptic parameters will be shared. To enable parameter sharing we override the controlled_by_param
column with the presynaptic cell index:
df = subnetwork.edges\ndf[\"controlled_by_param\"] = df[\"pre_global_cell_index\"]\n
df[[\"pre_global_cell_index\", \"controlled_by_param\"]]\n
pre_global_cell_index controlled_by_param 0 0 0 1 0 0 2 0 0 3 1 1 4 1 1 5 1 1 6 2 2 7 2 2 8 2 2 Now, all we have to do is to make these synaptic parameters trainable with the make_trainable()
method:
subnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 3. Total number of trainable parameters: 6\n
It correctly says that we added three parameters (because we have three cells, and we share individual synaptic parameters). We now have 6 trainable parameters in total (because we already added 3 trainable parameters above).
"},{"location":"tutorial/10_advanced_parameter_sharing/#a-more-involved-example-sharing-by-pre-and-post-synaptic-cell-type","title":"A more involved example: sharing by pre- and post-synaptic cell type","text":"As an example, consider the following: We have a fully connected network of six cells. Each cell falls into one of three cell types:
from typing import Union, List\n
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell(\"all\"), net.cell(\"all\"), IonotropicSynapse())\n\nnet.cell([0, 1]).add_to_group(\"exc\")\nnet.cell([2, 3]).add_to_group(\"inh\")\nnet.cell([4, 5]).add_to_group(\"unknown\")\n
We want to make all synapses that start from excitatory or inhibitory neurons trainable. In addition, we want to use the same parameter for synapses if they have the same pre- and post-synaptic cell type.
To achieve this, we will first want a column in net.nodes
which indicates the cell type.
for group, inds in net.groups.items():\n net.nodes.loc[inds, \"cell_type\"] = group\n
net.nodes[\"cell_type\"]\n
0 exc\n1 exc\n2 exc\n3 exc\n4 exc\n5 exc\n6 exc\n7 exc\n8 inh\n9 inh\n10 inh\n11 inh\n12 inh\n13 inh\n14 inh\n15 inh\n16 unknown\n17 unknown\n18 unknown\n19 unknown\n20 unknown\n21 unknown\n22 unknown\n23 unknown\nName: cell_type, dtype: object\n
The cell_type
is now part of the net.nodes
. However, we would like to do parameter sharing of synapses based on the pre- and post-synaptic node values. To do so, we import the cell_type
column into net.edges
. To do this, we use the .copy_node_property_to_edges()
which the name of the property you are copying from nodes:
net.copy_node_property_to_edges(\"cell_type\")\n
After this, you have columns in the .edges
which indicate the pre- and post-synaptic cell type:
net.edges[[\"pre_cell_type\", \"post_cell_type\"]]\n
pre_cell_type post_cell_type 0 exc exc 1 exc exc 2 exc inh 3 exc inh 4 exc unknown 5 exc unknown 6 exc exc 7 exc exc 8 exc inh 9 exc inh 10 exc unknown 11 exc unknown 12 inh exc 13 inh exc 14 inh inh 15 inh inh 16 inh unknown 17 inh unknown 18 inh exc 19 inh exc 20 inh inh 21 inh inh 22 inh unknown 23 inh unknown 24 unknown exc 25 unknown exc 26 unknown inh 27 unknown inh 28 unknown unknown 29 unknown unknown 30 unknown exc 31 unknown exc 32 unknown inh 33 unknown inh 34 unknown unknown 35 unknown unknown Next, we specify which parts of the network we actually want to change (in this case, all synapses which have excitatory or inhibitory presynaptic neurons):
df = net.edges\ndf = df.query(f\"pre_cell_type in ['exc', 'inh']\")\nprint(f\"There are {len(df)} synapses to be changed.\")\n\nsubnetwork = net.select(edges=df.index)\n
There are 24 synapses to be changed.\n
As the last step, we again have to specify parameter sharing by setting controlled_by_param
. In this case, we want to share parameters that have the same pre- and post-synaptic neuron. We achieve this by grouping the synpases by their pre- and post-synaptic cell type (see pd.DataFrame.groupby for details):
# Step 6: use groupby to specify parameter sharing and make the parameters trainable.\nsubnetwork.edges[\"controlled_by_param\"] = subnetwork.edges.groupby([\"pre_cell_type\", \"post_cell_type\"]).ngroup()\nsubnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 6. Total number of trainable parameters: 6\n
This created six trainable parameters, which makes sense as we have two types of pre-synaptic neurons (excitatory and inhibitory) and each has three options for the postsynaptic neuron (pre, post, unknown).
"},{"location":"tutorial/10_advanced_parameter_sharing/#summary","title":"Summary","text":"In this tutorial, you learned how you can flexibly share synaptic parameters. This works by first using select()
to identify which synapses to make trainable, and by then modifying controlled_by_param
to customize parameter sharing.
The official documentation for Jaxley has moved to jaxley.readthedocs.io. The website you are currently on will be taken down in the future.
Jaxley
is a differentiable simulator for biophysical neuron models in JAX. Its key features are:
jit
-compilation, making it as fast as other packages while being fully written in python Jaxley
allows to simulate biophysical neuron models on CPU, GPU, or TPU:
import matplotlib.pyplot as plt\nfrom jax import config\n\nimport jaxley as jx\nfrom jaxley.channels import HH\n\nconfig.update(\"jax_platform_name\", \"cpu\") # Or \"gpu\" / \"tpu\".\n\ncell = jx.Cell() # Define cell.\ncell.insert(HH()) # Insert channels.\n\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.stimulate(current) # Stimulate with step current.\ncell.record(\"v\") # Record voltage.\n\nv = jx.integrate(cell) # Run simulation.\nplt.plot(v.T) # Plot voltage trace.\n
If you want to learn more, we have tutorials on how to:
Jaxley
is available on pypi
:
pip install jaxley\n
This will install Jaxley
with CPU support. If you want GPU support, follow the instructions on the JAX
github repository to install JAX
with GPU support (in addition to installing Jaxley
). For example, for NVIDIA GPUs, run pip install -U \"jax[cuda12]\"\n
"},{"location":"#feedback-and-contributions","title":"Feedback and Contributions","text":"We welcome any feedback on how Jaxley
is working for your neuron models and are happy to receive bug reports, pull requests and other feedback (see contribute). We wish to maintain a positive community, please read our Code of Conduct.
Apache License Version 2.0 (Apache-2.0)
"},{"location":"#citation","title":"Citation","text":"If you use Jaxley
, consider citing the corresponding paper:
@article{deistler2024differentiable,\n doi = {10.1101/2024.08.21.608979},\n year = {2024},\n publisher = {Cold Spring Harbor Laboratory},\n author = {Deistler, Michael and Kadhim, Kyra L. and Pals, Matthijs and Beck, Jonas and Huang, Ziwei and Gloeckler, Manuel and Lappalainen, Janne K. and Schr{\\\"o}der, Cornelius and Berens, Philipp and Gon{\\c c}alves, Pedro J. and Macke, Jakob H.},\n title = {Differentiable simulation enables large-scale training of detailed biophysical models of neural dynamics},\n journal = {bioRxiv}\n}\n
"},{"location":"code_of_conduct/","title":"Contributor Covenant Code of Conduct","text":""},{"location":"code_of_conduct/#our-pledge","title":"Our Pledge","text":"We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
"},{"location":"code_of_conduct/#our-standards","title":"Our Standards","text":"Examples of behavior that contributes to a positive environment for our community include:
Examples of unacceptable behavior include:
Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful.
Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate.
"},{"location":"code_of_conduct/#scope","title":"Scope","text":"This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
"},{"location":"code_of_conduct/#enforcement","title":"Enforcement","text":"Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting jaxley
developer Michael Deistler via email (michael.deistler@uni-tuebingen.de). All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the reporter of any incident.
"},{"location":"code_of_conduct/#enforcement-guidelines","title":"Enforcement Guidelines","text":"Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
"},{"location":"code_of_conduct/#1-correction","title":"1. Correction","text":"Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
Consequence: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.
"},{"location":"code_of_conduct/#2-warning","title":"2. Warning","text":"Community Impact: A violation through a single incident or series of actions.
Consequence: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.
"},{"location":"code_of_conduct/#3-temporary-ban","title":"3. Temporary Ban","text":"Community Impact: A serious violation of community standards, including sustained inappropriate behavior.
Consequence: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
"},{"location":"code_of_conduct/#4-permanent-ban","title":"4. Permanent Ban","text":"Community Impact: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
Consequence: A permanent ban from any sort of public interaction within the community.
"},{"location":"code_of_conduct/#attribution","title":"Attribution","text":"This Code of Conduct is adapted from the Contributor Covenant, version 2.1, available at https://www.contributor-covenant.org/version/2/1/code_of_conduct.html.
Community Impact Guidelines were inspired by Mozilla\u2019s code of conduct enforcement ladder.
For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.
"},{"location":"contribute/","title":"Guide","text":""},{"location":"contribute/#user-experiences-bugs-and-feature-requests","title":"User experiences, bugs, and feature requests","text":"To report bugs and suggest features (including better documentation), please head over to issues on GitHub.
"},{"location":"contribute/#code-contributions","title":"Code contributions","text":"In general, we use pull requests to make changes to Jaxley
. So, if you are planning to make a contribution, please fork, create a feature branch and then make a PR from your feature branch to the upstream Jaxley
(details).
Clone the repo and install via setup.py
using pip install -e \".[dev]\"
(the dev flag installs development and testing dependencies).
For docstrings and comments, we use Google Style.
Code needs to pass through the following tools, which are installed alongside Jaxley
:
black: Automatic code formatting for Python. You can run black manually from the console using black .
in the top directory of the repository, which will format all files.
isort: Used to consistently order imports. You can run isort manually from the console using isort
in the top directory.
black
and isort
are checked as part of our CI actions. If these checks fail please make sure you have installed the latest versions for each of them and run them locally.
Most of the documentation is written in markdown (basic markdown guide).
You can directly fix mistakes and suggest clearer formulations in markdown files simply by initiating a PR on through GitHub. Click on documentation file and look for the little pencil at top right.
"},{"location":"credits/","title":"Credits","text":"Jaxley
is a collaborative project between the groups of Jakob Macke (Uni T\u00fcbingen), Pedro Gon\u00e7alves (KU Leuven / NERF), and Philipp Berens (Uni T\u00fcbingen).
Jaxley
is licensed under the Apache License Version 2.0 (Apache-2.0) and
Copyright (C) 2024 Michael Deistler, Jakob H. Macke, Pedro J. Goncalves, Philipp Berens.
"},{"location":"credits/#important-dependencies-and-prior-art","title":"Important dependencies and prior art","text":"This work was supported by the German Research Foundation (DFG) through Germany\u2019s Excellence Strategy (EXC 2064 \u2013 Project number 390727645) and the CRC 1233 \u201cRobust Vision\u201d, the German Federal Ministry of Education and Research (Tu\u0308bingen AI Center, FKZ: 01IS18039A), the \u2018Certification and Foundations of Safe Machine Learning Systems in Healthcare\u2019 project funded by the Carl Zeiss Foundation, and the European Union (ERC, \u201cDeepCoMechTome\u201d, ref. 101089288, \u201cNextMechMod\u201d, ref. 101039115).
"},{"location":"faq/","title":"Frequently asked questions","text":"Jaxley
? Jaxley
use? See also the discussion page and the issue tracker on the Jaxley
GitHub repository for recent questions and problems.
Jaxley
is available on PyPI
:
pip install jaxley\n
This will install Jaxley
with CPU support. If you want GPU support, follow the instructions on the JAX
github repository to install JAX
with GPU support (in addition to installing Jaxley
). For example, for NVIDIA GPUs, run pip install -U \"jax[cuda12]\"\n
"},{"location":"install/#install-from-source","title":"Install from source","text":"You can also install Jaxley
from source:
git clone https://github.com/jaxleyverse/jaxley.git\ncd jaxley\npip install -e .\n
Note that pip>=21.3
is required to install the editable version with pyproject.toml
see pip docs.
Jaxley
use?","text":"Jaxley
uses the same units as the NEURON
simulator, which are listed here.
All module
s (i.e., compartments, branches, cells, and networks) in Jaxley
can be saved and loaded with pickle:
import jaxley as jx\nimport pickle\n\n# ... define network, cell, etc.\nnetwork = jx.Network([cell1, cell2])\n\n# Save.\nwith open(\"path/to/file.pkl\", \"wb\") as handle:\n pickle.dump(network, handle)\n\n# Load.\nwith open(\"path/to/file.pkl\", \"rb\") as handle:\n network = pickle.load(handle)\n
"},{"location":"faq/question_03/","title":"What kinds of models can be implemented in Jaxley
?","text":"Jaxley
focuses on biophysical, Hodgkin-Huxley-type models. You can think of Jaxley
like the NEURON
simulator written in JAX
.
Jaxley
allows to simulate the following types of models, as well as networks thereof:
For all of these models, Jaxley
is flexible and accurate. For example, it can flexibly add new channel models, use different kinds of synapses (conductance-based, tanh, \u2026), and it can insert different kinds of channels in different branches (or compartments) within single cells. Like NEURON
, Jaxley
implements a backward-Euler solver for stable numerical solution of multi-compartment neurons.
However, Jaxley
does not implement the following types of models:
connect(pre, post, synapse_type)
","text":"Connect two compartments with a chemical synapse.
The pre- and postsynaptic compartments must be different compartments of the same network.
Parameters:
Name Type Description Defaultpre
View
View of the presynaptic compartment.
requiredpost
View
View of the postsynaptic compartment.
requiredsynapse_type
Synapse
The synapse to append
required Source code injaxley/connect.py
def connect(\n pre: \"View\",\n post: \"View\",\n synapse_type: \"Synapse\",\n):\n \"\"\"Connect two compartments with a chemical synapse.\n\n The pre- and postsynaptic compartments must be different compartments of the\n same network.\n\n Args:\n pre: View of the presynaptic compartment.\n post: View of the postsynaptic compartment.\n synapse_type: The synapse to append\n \"\"\"\n assert is_same_network(\n pre, post\n ), \"Pre and post compartments must be part of the same network.\"\n\n pre.base._append_multiple_synapses(pre.nodes, post.nodes, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.connectivity_matrix_connect","title":"connectivity_matrix_connect(pre_cell_view, post_cell_view, synapse_type, connectivity_matrix)
","text":"Appends multiple connections which build a custom connected network.
Connects pre- and postsynaptic cells according to a custom connectivity matrix. Entries > 0 in the matrix indicate a connection between the corresponding cells. Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Parameters:
Name Type Description Defaultpre_cell_view
View
View of the presynaptic cell.
requiredpost_cell_view
View
View of the postsynaptic cell.
requiredsynapse_type
Synapse
The synapse to append.
requiredconnectivity_matrix
ndarray[bool]
A boolean matrix indicating the connections between cells.
required Source code injaxley/connect.py
def connectivity_matrix_connect(\n pre_cell_view: \"View\",\n post_cell_view: \"View\",\n synapse_type: \"Synapse\",\n connectivity_matrix: np.ndarray[bool],\n):\n \"\"\"Appends multiple connections which build a custom connected network.\n\n Connects pre- and postsynaptic cells according to a custom connectivity matrix.\n Entries > 0 in the matrix indicate a connection between the corresponding cells.\n Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n Args:\n pre_cell_view: View of the presynaptic cell.\n post_cell_view: View of the postsynaptic cell.\n synapse_type: The synapse to append.\n connectivity_matrix: A boolean matrix indicating the connections between cells.\n \"\"\"\n # Get pre- and postsynaptic cell indices.\n pre_cell_inds = pre_cell_view._cells_in_view\n post_cell_inds = post_cell_view._cells_in_view\n # setting scope ensure that this works indep of current scope\n pre_nodes = pre_cell_view.scope(\"local\").branch(0).comp(0).nodes\n pre_nodes[\"index\"] = pre_nodes.index\n pre_cell_nodes = pre_nodes.set_index(\"global_cell_index\")\n\n assert connectivity_matrix.shape == (\n len(pre_cell_inds),\n len(post_cell_inds),\n ), \"Connectivity matrix must have shape (num_pre, num_post).\"\n assert connectivity_matrix.dtype == bool, \"Connectivity matrix must be boolean.\"\n\n # get connection pairs from connectivity matrix\n from_idx, to_idx = np.where(connectivity_matrix)\n pre_cell_inds = pre_cell_inds[from_idx]\n post_cell_inds = post_cell_inds[to_idx]\n\n # Sample random postsynaptic compartments (global comp indices).\n global_post_indices = np.hstack(\n [\n sample_comp(post_cell_view.scope(\"global\").cell(cell_idx))\n for cell_idx in post_cell_inds\n ]\n )\n post_rows = post_cell_view.nodes.loc[global_post_indices]\n\n # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, \"index\"].to_numpy()\n pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes\n\n pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.fully_connect","title":"fully_connect(pre_cell_view, post_cell_view, synapse_type)
","text":"Appends multiple connections which build a fully connected layer.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Parameters:
Name Type Description Defaultpre_cell_view
View
View of the presynaptic cell.
requiredpost_cell_view
View
View of the postsynaptic cell.
requiredsynapse_type
Synapse
The synapse to append.
required Source code injaxley/connect.py
def fully_connect(\n pre_cell_view: \"View\",\n post_cell_view: \"View\",\n synapse_type: \"Synapse\",\n):\n \"\"\"Appends multiple connections which build a fully connected layer.\n\n Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n Args:\n pre_cell_view: View of the presynaptic cell.\n post_cell_view: View of the postsynaptic cell.\n synapse_type: The synapse to append.\n \"\"\"\n # Get pre- and postsynaptic cell indices.\n num_pre = len(pre_cell_view._cells_in_view)\n num_post = len(post_cell_view._cells_in_view)\n\n # Infer indices of (random) postsynaptic compartments.\n global_post_indices = (\n post_cell_view.nodes.groupby(\"global_cell_index\")\n .sample(num_pre, replace=True)\n .index.to_numpy()\n )\n global_post_indices = global_post_indices.reshape((-1, num_pre), order=\"F\").ravel()\n post_rows = post_cell_view.nodes.loc[global_post_indices]\n\n # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n pre_rows = pre_cell_view.scope(\"local\").branch(0).comp(0).nodes.copy()\n # Repeat rows `num_post` times. See SO 50788508.\n pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)\n\n pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/connect/#jaxley.connect.is_same_network","title":"is_same_network(pre, post)
","text":"Check if views are from the same network.
Source code injaxley/connect.py
def is_same_network(pre: \"View\", post: \"View\") -> bool:\n \"\"\"Check if views are from the same network.\"\"\"\n is_in_net = \"network\" in pre.base.__class__.__name__.lower()\n is_in_same_net = pre.base is post.base\n return is_in_net and is_in_same_net\n
"},{"location":"reference/connect/#jaxley.connect.sample_comp","title":"sample_comp(cell_view, num=1, replace=True)
","text":"Sample a compartment from a cell.
Returns View with shape (num, num_cols).
Source code injaxley/connect.py
def sample_comp(cell_view: \"View\", num: int = 1, replace=True) -> \"CompartmentView\":\n \"\"\"Sample a compartment from a cell.\n\n Returns View with shape (num, num_cols).\"\"\"\n return np.random.choice(cell_view._comps_in_view, num, replace=replace)\n
"},{"location":"reference/connect/#jaxley.connect.sparse_connect","title":"sparse_connect(pre_cell_view, post_cell_view, synapse_type, p)
","text":"Appends multiple connections which build a sparse, randomly connected layer.
Connections are from branch 0 location 0 to a randomly chosen branch and loc.
Parameters:
Name Type Description Defaultpre_cell_view
View
View of the presynaptic cell.
requiredpost_cell_view
View
View of the postsynaptic cell.
requiredsynapse_type
Synapse
The synapse to append.
requiredp
float
Probability of connection.
required Source code injaxley/connect.py
def sparse_connect(\n pre_cell_view: \"View\",\n post_cell_view: \"View\",\n synapse_type: \"Synapse\",\n p: float,\n):\n \"\"\"Appends multiple connections which build a sparse, randomly connected layer.\n\n Connections are from branch 0 location 0 to a randomly chosen branch and loc.\n\n Args:\n pre_cell_view: View of the presynaptic cell.\n post_cell_view: View of the postsynaptic cell.\n synapse_type: The synapse to append.\n p: Probability of connection.\n \"\"\"\n # Get pre- and postsynaptic cell indices.\n pre_cell_inds = pre_cell_view._cells_in_view\n post_cell_inds = post_cell_view._cells_in_view\n num_pre = len(pre_cell_inds)\n num_post = len(post_cell_inds)\n\n num_connections = np.random.binomial(num_pre * num_post, p)\n pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)\n post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)\n\n # Sort the synapses only for convenience of inspecting `.edges`.\n sorting = np.argsort(pre_syn_neurons)\n pre_syn_neurons = pre_syn_neurons[sorting]\n post_syn_neurons = post_syn_neurons[sorting]\n\n # Post-synapse is a randomly chosen branch and compartment.\n global_post_indices = [\n sample_comp(post_cell_view.scope(\"global\").cell(cell_idx))\n for cell_idx in post_syn_neurons\n ]\n global_post_indices = (\n np.hstack(global_post_indices) if len(global_post_indices) > 1 else []\n )\n post_rows = post_cell_view.base.nodes.loc[global_post_indices]\n\n # Pre-synapse is at the zero-eth branch and zero-eth compartment.\n global_pre_indices = pre_cell_view.base._cumsum_ncomp_per_cell[pre_syn_neurons]\n pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]\n\n if len(pre_rows) > 0:\n pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)\n
"},{"location":"reference/integration/","title":"Simulation","text":""},{"location":"reference/integration/#jaxley.integrate.add_clamps","title":"add_clamps(externals, external_inds, data_clamps=None)
","text":"Adds clamps to the external inputs.
Parameters:
Name Type Description Defaultexternals
Dict
Current external inputs.
requiredexternal_inds
Dict
Current external indices.
requireddata_clamps
Optional[Tuple[str, ndarray, DataFrame]]
Additional data clamps. Defaults to None.
None
Returns:
Type DescriptionTuple[Dict, Dict]
Tuple[Dict, Dict]: Updated external inputs and indices.
Source code injaxley/integrate.py
def add_clamps(\n externals: Dict,\n external_inds: Dict,\n data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,\n) -> Tuple[Dict, Dict]:\n \"\"\"Adds clamps to the external inputs.\n\n Args:\n externals (Dict): Current external inputs.\n external_inds (Dict): Current external indices.\n data_clamps (Optional[Tuple[str, jnp.ndarray, pd.DataFrame]], optional): Additional data clamps. Defaults to None.\n\n Returns:\n Tuple[Dict, Dict]: Updated external inputs and indices.\n \"\"\"\n # If a clamp is inserted, add it to the external inputs.\n if data_clamps is not None:\n state_name, clamps, inds = data_clamps\n if state_name in externals.keys():\n externals[state_name] = jnp.concatenate([externals[state_name], clamps])\n external_inds[state_name] = jnp.concatenate(\n [external_inds[state_name], inds.index.to_numpy()]\n )\n else:\n externals[state_name] = clamps\n external_inds[state_name] = inds.index.to_numpy()\n\n return externals, external_inds\n
"},{"location":"reference/integration/#jaxley.integrate.add_stimuli","title":"add_stimuli(externals, external_inds, data_stimuli=None)
","text":"Extends the external inputs with the stimuli.
Parameters:
Name Type Description Defaultexternals
Dict
Current external inputs.
requiredexternal_inds
Dict
Current external indices.
requireddata_stimuli
Optional[Tuple[ndarray, DataFrame]]
Additional data stimuli. Defaults to None.
None
Returns:
Type DescriptionTuple[Dict, Dict]
Tuple[Dict, Dict]: Updated external inputs and indices.
Source code injaxley/integrate.py
def add_stimuli(\n externals: Dict,\n external_inds: Dict,\n data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n) -> Tuple[Dict, Dict]:\n \"\"\"Extends the external inputs with the stimuli.\n\n Args:\n externals (Dict): Current external inputs.\n external_inds (Dict): Current external indices.\n data_stimuli (Optional[Tuple[jnp.ndarray, pd.DataFrame]], optional): Additional data stimuli. Defaults to None.\n\n Returns:\n Tuple[Dict, Dict]: Updated external inputs and indices.\n \"\"\"\n # If stimulus is inserted, add it to the external inputs.\n if \"i\" in externals.keys() or data_stimuli is not None:\n if \"i\" in externals.keys():\n if data_stimuli is not None:\n externals[\"i\"] = jnp.concatenate([externals[\"i\"], data_stimuli[1]])\n external_inds[\"i\"] = jnp.concatenate(\n [external_inds[\"i\"], data_stimuli[2].index.to_numpy()]\n )\n else:\n externals[\"i\"] = data_stimuli[1]\n external_inds[\"i\"] = data_stimuli[2].index.to_numpy()\n\n return externals, external_inds\n
"},{"location":"reference/integration/#jaxley.integrate.build_init_and_step_fn","title":"build_init_and_step_fn(module, voltage_solver='jaxley.stone', solver='bwd_euler')
","text":"This function returns the init_fn
and step_fn
which initialize the parameters and states of the neuron model and then step through the model
Parameters:
Name Type Description Defaultmodule
Module
A Module
object that e.g. a cell.
voltage_solver
str
Voltage solver used in step. Defaults to \u201cjaxley.stone\u201d.
'jaxley.stone'
solver
str
ODE solver. Defaults to \u201cbwd_euler\u201d.
'bwd_euler'
Returns:
Type DescriptionTuple[Callable, Callable]
init_fn, step_fn: Functions that initialize the state and parameters, and perform a single integration step, respectively.
Source code injaxley/integrate.py
def build_init_and_step_fn(\n module: Module,\n voltage_solver: str = \"jaxley.stone\",\n solver: str = \"bwd_euler\",\n) -> Tuple[Callable, Callable]:\n \"\"\"This function returns the `init_fn` and `step_fn` which initialize the\n parameters and states of the neuron model and then step through the model\n\n Args:\n module (Module): A `Module` object that e.g. a cell.\n voltage_solver (str, optional): Voltage solver used in step. Defaults to \"jaxley.stone\".\n solver (str, optional): ODE solver. Defaults to \"bwd_euler\".\n\n Returns:\n init_fn, step_fn: Functions that initialize the state and parameters, and perform\n a single integration step, respectively.\n \"\"\"\n # Initialize the external inputs and their indices.\n external_inds = module.external_inds.copy()\n\n def init_fn(\n params: List[Dict[str, jnp.ndarray]],\n all_states: Optional[Dict] = None,\n param_state: Optional[List[Dict]] = None,\n delta_t: float = 0.025,\n ) -> Tuple[Dict, Dict]:\n \"\"\"Initializes the parameters and states of the neuron model.\n\n Args:\n params (List[Dict[str, jnp.ndarray]]): List of trainable parameters.\n all_states (Optional[Dict], optional): State if alread initialized. Defaults to None.\n param_state (Optional[List[Dict]], optional): Parameters returned by `data_set`.. Defaults to None.\n delta_t (float, optional): Step size. Defaults to 0.025.\n\n Returns:\n Tuple[Dict, Dict]: All states and parameters.\n \"\"\"\n # Make the `trainable_params` of the same shape as the `param_state`, such that\n # they can be processed together by `get_all_parameters`.\n pstate = params_to_pstate(params, module.indices_set_by_trainables)\n if param_state is not None:\n pstate += param_state\n\n all_params = module.get_all_parameters(pstate, voltage_solver=voltage_solver)\n all_states = (\n module.get_all_states(pstate, all_params, delta_t)\n if all_states is None\n else all_states\n )\n return all_states, all_params\n\n def step_fn(\n all_states: Dict,\n all_params: Dict,\n externals: Dict,\n external_inds: Dict = external_inds,\n delta_t: float = 0.025,\n ) -> Dict:\n \"\"\"Performs a single integration step with step size delta_t.\n\n Args:\n all_states (Dict): Current state of the neuron model.\n all_params (Dict): Current parameters of the neuron model.\n externals (Dict): External inputs.\n external_inds (Dict, optional): External indices. Defaults to `module.external_inds`.\n delta_t (float, optional): Time step. Defaults to 0.025.\n\n Returns:\n Dict: Updated states.\n \"\"\"\n state = all_states\n state = module.step(\n state,\n delta_t,\n external_inds,\n externals,\n params=all_params,\n solver=solver,\n voltage_solver=voltage_solver,\n )\n return state\n\n return init_fn, step_fn\n
"},{"location":"reference/integration/#jaxley.integrate.integrate","title":"integrate(module, params=[], *, param_state=None, data_stimuli=None, data_clamps=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.stone', checkpoint_lengths=None, all_states=None, return_states=False)
","text":"Solves ODE and simulates neuron model.
Parameters:
Name Type Description Defaultparams
List[Dict[str, ndarray]]
Trainable parameters returned by get_parameters()
.
[]
param_state
Optional[List[Dict]]
Parameters returned by data_set
.
None
data_stimuli
Optional[Tuple[ndarray, DataFrame]]
Outputs of .data_stimulate()
, only needed if stimuli change across function calls.
None
data_clamps
Optional[Tuple[str, ndarray, DataFrame]]
Outputs of .data_clamp()
, only needed if clamps change across function calls.
None
t_max
Optional[float]
Duration of the simulation in milliseconds. If t_max
is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If t_max
is smaller, then the stimulus with be truncated.
None
delta_t
float
Time step of the solver in milliseconds.
0.025
solver
str
Which ODE solver to use. Either of [\u201cfwd_euler\u201d, \u201cbwd_euler\u201d, \u201ccrank_nicolson\u201d].
'bwd_euler'
tridiag_solver
Algorithm to solve tridiagonal systems. The different options only affect bwd_euler
and crank_nicolson
solvers. Either of [\u201cstone\u201d, \u201cthomas\u201d], where stone
is much faster on GPU for long branches with many compartments and thomas
is slightly faster on CPU (thomas
is used in NEURON).
checkpoint_lengths
Optional[List[int]]
Number of timesteps at every level of checkpointing. The prod(checkpoint_lengths)
must be larger or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths)
timesteps, and the result is posthoc truncated to the desired simulation length. Therefore, a poor choice of checkpoint_lengths
can lead to longer simulation time. If None
, no checkpointing is applied.
None
all_states
Optional[Dict]
An optional initial state that was returned by a previous jx.integrate(..., return_states=True)
run. Overrides potentially trainable initial states.
None
return_states
bool
If True, it returns all states such that the current state of the Module
can be set with set_states
.
False
Source code in jaxley/integrate.py
def integrate(\n module: Module,\n params: List[Dict[str, jnp.ndarray]] = [],\n *,\n param_state: Optional[List[Dict]] = None,\n data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None,\n t_max: Optional[float] = None,\n delta_t: float = 0.025,\n solver: str = \"bwd_euler\",\n voltage_solver: str = \"jaxley.stone\",\n checkpoint_lengths: Optional[List[int]] = None,\n all_states: Optional[Dict] = None,\n return_states: bool = False,\n) -> jnp.ndarray:\n \"\"\"\n Solves ODE and simulates neuron model.\n\n Args:\n params: Trainable parameters returned by `get_parameters()`.\n param_state: Parameters returned by `data_set`.\n data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change\n across function calls.\n data_clamps: Outputs of `.data_clamp()`, only needed if clamps change across\n function calls.\n t_max: Duration of the simulation in milliseconds. If `t_max` is greater than\n the length of the stimulus input, the stimulus will be padded at the end\n with zeros. If `t_max` is smaller, then the stimulus with be truncated.\n delta_t: Time step of the solver in milliseconds.\n solver: Which ODE solver to use. Either of [\"fwd_euler\", \"bwd_euler\",\n \"crank_nicolson\"].\n tridiag_solver: Algorithm to solve tridiagonal systems. The different options\n only affect `bwd_euler` and `crank_nicolson` solvers. Either of [\"stone\",\n \"thomas\"], where `stone` is much faster on GPU for long branches\n with many compartments and `thomas` is slightly faster on CPU (`thomas` is\n used in NEURON).\n checkpoint_lengths: Number of timesteps at every level of checkpointing. The\n `prod(checkpoint_lengths)` must be larger or equal to the desired number of\n simulated timesteps. Warning: the simulation is run for\n `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated\n to the desired simulation length. Therefore, a poor choice of\n `checkpoint_lengths` can lead to longer simulation time. If `None`, no\n checkpointing is applied.\n all_states: An optional initial state that was returned by a previous\n `jx.integrate(..., return_states=True)` run. Overrides potentially\n trainable initial states.\n return_states: If True, it returns all states such that the current state of\n the `Module` can be set with `set_states`.\n \"\"\"\n\n assert module.initialized, \"Module is not initialized, run `._initialize()`.\"\n module.to_jax() # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n\n # Initialize the external inputs and their indices.\n externals = module.externals.copy()\n external_inds = module.external_inds.copy()\n\n # If stimulus is inserted, add it to the external inputs.\n externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)\n\n # If a clamp is inserted, add it to the external inputs.\n externals, external_inds = add_clamps(externals, external_inds, data_clamps)\n\n if not externals.keys():\n # No stimulus was inserted and no clamp was set.\n assert (\n t_max is not None\n ), \"If no stimulus or clamp are inserted you have to specify the simulation duration at `jx.integrate(..., t_max=)`.\"\n\n for key in externals.keys():\n externals[key] = externals[key].T # Shape `(time, num_stimuli)`.\n\n if module.recordings.empty:\n raise ValueError(\"No recordings are set. Please set them.\")\n rec_inds = module.recordings.rec_index.to_numpy()\n rec_states = module.recordings.state.to_numpy()\n\n # Shorten or pad stimulus depending on `t_max`.\n if t_max is not None:\n t_max_steps = int(t_max // delta_t + 1)\n\n # Pad or truncate the stimulus.\n for key in externals.keys():\n if t_max_steps > externals[key].shape[0]:\n if key == \"i\":\n pad = jnp.zeros(\n (t_max_steps - externals[\"i\"].shape[0], externals[\"i\"].shape[1])\n )\n externals[\"i\"] = jnp.concatenate((externals[\"i\"], pad))\n else:\n raise NotImplementedError(\n \"clamp must be at least as long as simulation.\"\n )\n else:\n externals[key] = externals[key][:t_max_steps, :]\n\n init_fn, step_fn = build_init_and_step_fn(\n module, voltage_solver=voltage_solver, solver=solver\n )\n all_states, all_params = init_fn(params, all_states, param_state, delta_t)\n\n def _body_fun(state, externals):\n state = step_fn(state, all_params, externals, external_inds, delta_t)\n recs = jnp.asarray(\n [\n state[rec_state][rec_ind]\n for rec_state, rec_ind in zip(rec_states, rec_inds)\n ]\n )\n return state, recs\n\n # If necessary, pad the stimulus with zeros in order to simulate sufficiently long.\n # The total simulation length will be `prod(checkpoint_lengths)`. At the end, we\n # return only the first `nsteps_to_return` elements (plus the initial state).\n if externals:\n example_key = list(externals.keys())[0]\n nsteps_to_return = len(externals[example_key])\n else:\n nsteps_to_return = t_max_steps\n\n if checkpoint_lengths is None:\n checkpoint_lengths = [nsteps_to_return]\n length = nsteps_to_return\n else:\n length = prod(checkpoint_lengths)\n size_difference = length - nsteps_to_return\n assert (\n nsteps_to_return <= length\n ), \"The desired simulation duration is longer than `prod(nested_length)`.\"\n if externals:\n dummy_external = jnp.zeros(\n (size_difference, externals[example_key].shape[1])\n )\n for key in externals.keys():\n externals[key] = jnp.concatenate([externals[key], dummy_external])\n\n # Record the initial state.\n init_recs = jnp.asarray(\n [\n all_states[rec_state][rec_ind]\n for rec_state, rec_ind in zip(rec_states, rec_inds)\n ]\n )\n init_recording = jnp.expand_dims(init_recs, axis=0)\n\n # Run simulation.\n all_states, recordings = nested_checkpoint_scan(\n _body_fun,\n all_states,\n externals,\n length=length,\n nested_lengths=checkpoint_lengths,\n )\n recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T\n return (recs, all_states) if return_states else recs\n
"},{"location":"reference/integration/#jaxley.solver_gate.exponential_euler","title":"exponential_euler(x, dt, x_inf, x_tau)
","text":"An exact solver for the linear dynamical system dx = -(x - x_inf) / x_tau
.
jaxley/solver_gate.py
def exponential_euler(\n x: jnp.ndarray,\n dt: float,\n x_inf: jnp.ndarray,\n x_tau: jnp.ndarray,\n):\n \"\"\"An exact solver for the linear dynamical system `dx = -(x - x_inf) / x_tau`.\"\"\"\n exp_term = save_exp(-dt / x_tau)\n return x * exp_term + x_inf * (1.0 - exp_term)\n
"},{"location":"reference/integration/#jaxley.solver_gate.save_exp","title":"save_exp(x, max_value=20.0)
","text":"Clip the input to a maximum value and return its exponential.
Source code injaxley/solver_gate.py
def save_exp(x, max_value: float = 20.0):\n \"\"\"Clip the input to a maximum value and return its exponential.\"\"\"\n x = jnp.clip(x, a_max=max_value)\n return jnp.exp(x)\n
"},{"location":"reference/integration/#jaxley.solver_gate.solve_inf_gate_exponential","title":"solve_inf_gate_exponential(x, dt, s_inf, tau_s)
","text":"solves dx/dt = (s_inf - x) / tau_s via exponential Euler
Parameters:
Name Type Description Defaultx
ndarray
gate variable
requireddt
float
time_delta
requireds_inf
ndarray
description
requiredtau_s
ndarray
description
requiredReturns:
Name Type Description_type_
updated gate
Source code injaxley/solver_gate.py
def solve_inf_gate_exponential(\n x: jnp.ndarray,\n dt: float,\n s_inf: jnp.ndarray,\n tau_s: jnp.ndarray,\n):\n \"\"\"solves dx/dt = (s_inf - x) / tau_s\n via exponential Euler\n\n Args:\n x (jnp.ndarray): gate variable\n dt (float): time_delta\n s_inf (jnp.ndarray): _description_\n tau_s (jnp.ndarray): _description_\n\n Returns:\n _type_: updated gate\n \"\"\"\n slope = -1.0 / tau_s\n exp_term = save_exp(slope * dt)\n return x * exp_term + s_inf * (1.0 - exp_term)\n
"},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_explicit","title":"step_voltage_explicit(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)
","text":"Solve one timestep of branched nerve equations with explicit (forward) Euler.
Source code injaxley/solver_voltage.py
def step_voltage_explicit(\n voltages: jnp.ndarray,\n voltage_terms: jnp.ndarray,\n constant_terms: jnp.ndarray,\n axial_conductances: jnp.ndarray,\n internal_node_inds: jnp.ndarray,\n sinks: jnp.ndarray,\n sources: jnp.ndarray,\n types: jnp.ndarray,\n ncomp_per_branch: jnp.ndarray,\n par_inds: jnp.ndarray,\n child_inds: jnp.ndarray,\n nbranches: int,\n solver: str,\n delta_t: float,\n idx: JaxleySolveIndexer,\n debug_states,\n) -> jnp.ndarray:\n \"\"\"Solve one timestep of branched nerve equations with explicit (forward) Euler.\"\"\"\n voltages = jnp.reshape(voltages, (nbranches, -1))\n voltage_terms = jnp.reshape(voltage_terms, (nbranches, -1))\n constant_terms = jnp.reshape(constant_terms, (nbranches, -1))\n\n update = _voltage_vectorfield(\n voltages,\n voltage_terms,\n constant_terms,\n types,\n sources,\n sinks,\n axial_conductances,\n par_inds,\n child_inds,\n nbranches,\n solver,\n delta_t,\n idx,\n debug_states,\n )\n new_voltates = voltages + delta_t * update\n return new_voltates.ravel(order=\"C\")\n
"},{"location":"reference/integration/#jaxley.solver_voltage.step_voltage_implicit_with_jaxley_spsolve","title":"step_voltage_implicit_with_jaxley_spsolve(voltages, voltage_terms, constant_terms, axial_conductances, internal_node_inds, sinks, sources, types, ncomp_per_branch, par_inds, child_inds, nbranches, solver, delta_t, idx, debug_states)
","text":"Solve one timestep of branched nerve equations with implicit (backward) Euler.
Source code injaxley/solver_voltage.py
def step_voltage_implicit_with_jaxley_spsolve(\n voltages: jnp.ndarray,\n voltage_terms: jnp.ndarray,\n constant_terms: jnp.ndarray,\n axial_conductances: jnp.ndarray,\n internal_node_inds: jnp.ndarray,\n sinks: jnp.ndarray,\n sources: jnp.ndarray,\n types: jnp.ndarray,\n ncomp_per_branch: jnp.ndarray,\n par_inds: jnp.ndarray,\n child_inds: jnp.ndarray,\n nbranches: int,\n solver: str,\n delta_t: float,\n idx: JaxleySolveIndexer,\n debug_states,\n):\n \"\"\"Solve one timestep of branched nerve equations with implicit (backward) Euler.\"\"\"\n # Build diagonals.\n c2c = np.isin(types, [0, 1, 2])\n total_ncomp = idx.cumsum_ncomp[-1]\n diags = jnp.ones(total_ncomp)\n\n # if-case needed because `.at` does not allow empty inputs, but the input is\n # empty for compartments.\n if len(sinks[c2c]) > 0:\n diags = diags.at[idx.mask(sinks[c2c])].add(delta_t * axial_conductances[c2c])\n\n diags = diags.at[idx.mask(internal_node_inds)].add(delta_t * voltage_terms)\n\n # Build solves.\n solves = jnp.zeros(total_ncomp)\n solves = solves.at[idx.mask(internal_node_inds)].add(\n voltages + delta_t * constant_terms\n )\n\n # Build upper and lower within the branch.\n c2c = types == 0 # c2c = compartment-to-compartment.\n\n # Build uppers.\n uppers = jnp.zeros(total_ncomp)\n upper_inds = sources[c2c] > sinks[c2c]\n sinks_upper = sinks[c2c][upper_inds]\n if len(sinks_upper) > 0:\n uppers = uppers.at[idx.mask(sinks_upper)].add(\n -delta_t * axial_conductances[c2c][upper_inds]\n )\n\n # Build lowers.\n lowers = jnp.zeros(total_ncomp)\n lower_inds = sources[c2c] < sinks[c2c]\n sinks_lower = sinks[c2c][lower_inds]\n if len(sinks_lower) > 0:\n lowers = lowers.at[idx.mask(sinks_lower)].add(\n -delta_t * axial_conductances[c2c][lower_inds]\n )\n\n # Build branchpoint conductances.\n branchpoint_conds_parents = axial_conductances[types == 1]\n branchpoint_conds_children = axial_conductances[types == 2]\n branchpoint_weights_parents = axial_conductances[types == 3]\n branchpoint_weights_children = axial_conductances[types == 4]\n all_branchpoint_vals = jnp.concatenate(\n [branchpoint_weights_parents, branchpoint_weights_children]\n )\n # Find unique group identifiers\n num_branchpoints = len(branchpoint_conds_parents)\n branchpoint_diags = -group_and_sum(\n all_branchpoint_vals, idx.branchpoint_group_inds, num_branchpoints\n )\n branchpoint_solves = jnp.zeros((num_branchpoints,))\n\n branchpoint_conds_children = -delta_t * branchpoint_conds_children\n branchpoint_conds_parents = -delta_t * branchpoint_conds_parents\n\n # Here, I move all child and parent indices towards a branchpoint into a larger\n # vector. This is wasteful, but it makes indexing much easier. JIT compiling\n # makes the speed difference negligible.\n # Children.\n bp_conds_children = jnp.zeros(nbranches)\n bp_weights_children = jnp.zeros(nbranches)\n # Parents.\n bp_conds_parents = jnp.zeros(nbranches)\n bp_weights_parents = jnp.zeros(nbranches)\n\n # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n # `len(inds) == 0` is the case for branches and compartments.\n if num_branchpoints > 0:\n bp_conds_children = bp_conds_children.at[child_inds].set(\n branchpoint_conds_children\n )\n bp_weights_children = bp_weights_children.at[child_inds].set(\n branchpoint_weights_children\n )\n bp_conds_parents = bp_conds_parents.at[par_inds].set(branchpoint_conds_parents)\n bp_weights_parents = bp_weights_parents.at[par_inds].set(\n branchpoint_weights_parents\n )\n\n # Triangulate the linear system of equations.\n (\n diags,\n lowers,\n solves,\n uppers,\n branchpoint_diags,\n branchpoint_solves,\n bp_weights_children,\n bp_conds_parents,\n ) = _triang_branched(\n lowers,\n diags,\n uppers,\n solves,\n bp_conds_children,\n bp_conds_parents,\n bp_weights_children,\n bp_weights_parents,\n branchpoint_diags,\n branchpoint_solves,\n solver,\n ncomp_per_branch,\n idx,\n debug_states,\n )\n\n # Backsubstitute the linear system of equations.\n (\n solves,\n lowers,\n diags,\n bp_weights_parents,\n branchpoint_solves,\n bp_conds_children,\n ) = _backsub_branched(\n lowers,\n diags,\n uppers,\n solves,\n bp_conds_children,\n bp_conds_parents,\n bp_weights_children,\n bp_weights_parents,\n branchpoint_diags,\n branchpoint_solves,\n solver,\n ncomp_per_branch,\n idx,\n debug_states,\n )\n return solves.ravel(order=\"C\")[idx.mask(internal_node_inds)]\n
"},{"location":"reference/mechanisms/","title":"Channels","text":""},{"location":"reference/mechanisms/#channel","title":"Channel","text":"Channel base class. All channels inherit from this class.
As in NEURON, a Channel
is considered a distributed process, which means that its conductances are to be specified in S/cm2
and its currents are to be specified in uA/cm2
.
jaxley/channels/channel.py
class Channel:\n \"\"\"Channel base class. All channels inherit from this class.\n\n As in NEURON, a `Channel` is considered a distributed process, which means that its\n conductances are to be specified in `S/cm2` and its currents are to be specified in\n `uA/cm2`.\"\"\"\n\n _name = None\n channel_params = None\n channel_states = None\n current_name = None\n\n def __init__(self, name: Optional[str] = None):\n contact = (\n \"If you have any questions, please reach out via email to \"\n \"michael.deistler@uni-tuebingen.de or create an issue on Github: \"\n \"https://github.com/jaxleyverse/jaxley/issues. Thank you!\"\n )\n if (\n not hasattr(self, \"current_is_in_mA_per_cm2\")\n or not self.current_is_in_mA_per_cm2\n ):\n raise ValueError(\n \"The channel you are using is deprecated. \"\n \"In Jaxley version 0.5.0, we changed the unit of the current returned \"\n \"by `compute_current` of channels from `uA/cm^2` to `mA/cm^2`. Please \"\n \"update your channel model (by dividing the resulting current by 1000) \"\n \"and set `self.current_is_in_mA_per_cm2=True` as the first line \"\n f\"in the `__init__()` method of your channel. {contact}\"\n )\n\n self._name = name if name else self.__class__.__name__\n\n @property\n def name(self) -> Optional[str]:\n \"\"\"The name of the channel (by default, this is the class name).\"\"\"\n return self._name\n\n def change_name(self, new_name: str):\n \"\"\"Change the channel name.\n\n Args:\n new_name: The new name of the channel.\n\n Returns:\n Renamed channel, such that this function is chainable.\n \"\"\"\n old_prefix = self._name + \"_\"\n new_prefix = new_name + \"_\"\n\n self._name = new_name\n self.channel_params = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.channel_params.items()\n }\n\n self.channel_states = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.channel_states.items()\n }\n return self\n\n def update_states(\n self, states, dt, v, params\n ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"Return the updated states.\"\"\"\n raise NotImplementedError\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Given channel states and voltage, return the current through the channel.\n\n Args:\n states: All states of the compartment.\n v: Voltage of the compartment in mV.\n params: Parameters of the channel (conductances in `S/cm2`).\n\n Returns:\n Current in `uA/cm2`.\n \"\"\"\n raise NotImplementedError\n\n def init_state(\n self,\n states: Dict[str, jnp.ndarray],\n v: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n delta_t: float,\n ):\n \"\"\"Initialize states of channel.\"\"\"\n return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.name","title":"name: Optional[str]
property
","text":"The name of the channel (by default, this is the class name).
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.change_name","title":"change_name(new_name)
","text":"Change the channel name.
Parameters:
Name Type Description Defaultnew_name
str
The new name of the channel.
requiredReturns:
Type DescriptionRenamed channel, such that this function is chainable.
Source code injaxley/channels/channel.py
def change_name(self, new_name: str):\n \"\"\"Change the channel name.\n\n Args:\n new_name: The new name of the channel.\n\n Returns:\n Renamed channel, such that this function is chainable.\n \"\"\"\n old_prefix = self._name + \"_\"\n new_prefix = new_name + \"_\"\n\n self._name = new_name\n self.channel_params = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.channel_params.items()\n }\n\n self.channel_states = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.channel_states.items()\n }\n return self\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.compute_current","title":"compute_current(states, v, params)
","text":"Given channel states and voltage, return the current through the channel.
Parameters:
Name Type Description Defaultstates
Dict[str, ndarray]
All states of the compartment.
requiredv
Voltage of the compartment in mV.
requiredparams
Dict[str, ndarray]
Parameters of the channel (conductances in S/cm2
).
Returns:
Type DescriptionCurrent in uA/cm2
.
jaxley/channels/channel.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Given channel states and voltage, return the current through the channel.\n\n Args:\n states: All states of the compartment.\n v: Voltage of the compartment in mV.\n params: Parameters of the channel (conductances in `S/cm2`).\n\n Returns:\n Current in `uA/cm2`.\n \"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize states of channel.
Source code injaxley/channels/channel.py
def init_state(\n self,\n states: Dict[str, jnp.ndarray],\n v: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n delta_t: float,\n):\n \"\"\"Initialize states of channel.\"\"\"\n return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.channel.Channel.update_states","title":"update_states(states, dt, v, params)
","text":"Return the updated states.
Source code injaxley/channels/channel.py
def update_states(\n self, states, dt, v, params\n) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"Return the updated states.\"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#hh","title":"HH","text":" Bases: Channel
Hodgkin-Huxley channel.
Source code injaxley/channels/hh.py
class HH(Channel):\n \"\"\"Hodgkin-Huxley channel.\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gNa\": 0.12,\n f\"{prefix}_gK\": 0.036,\n f\"{prefix}_gLeak\": 0.0003,\n f\"{prefix}_eNa\": 50.0,\n f\"{prefix}_eK\": -77.0,\n f\"{prefix}_eLeak\": -54.3,\n }\n self.channel_states = {\n f\"{prefix}_m\": 0.2,\n f\"{prefix}_h\": 0.2,\n f\"{prefix}_n\": 0.2,\n }\n self.current_name = f\"i_HH\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Return updated HH channel state.\"\"\"\n prefix = self._name\n m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current through HH channels.\"\"\"\n prefix = self._name\n m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n gNa = params[f\"{prefix}_gNa\"] * (m**3) * h # S/cm^2\n gK = params[f\"{prefix}_gK\"] * n**4 # S/cm^2\n gLeak = params[f\"{prefix}_gLeak\"] # S/cm^2\n\n return (\n gNa * (v - params[f\"{prefix}_eNa\"])\n + gK * (v - params[f\"{prefix}_eK\"])\n + gLeak * (v - params[f\"{prefix}_eLeak\"])\n )\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_m, beta_m = self.m_gate(v)\n alpha_h, beta_h = self.h_gate(v)\n alpha_n, beta_n = self.n_gate(v)\n return {\n f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n }\n\n @staticmethod\n def m_gate(v):\n alpha = 0.1 * _vtrap(-(v + 40), 10)\n beta = 4.0 * save_exp(-(v + 65) / 18)\n return alpha, beta\n\n @staticmethod\n def h_gate(v):\n alpha = 0.07 * save_exp(-(v + 65) / 20)\n beta = 1.0 / (save_exp(-(v + 35) / 10) + 1)\n return alpha, beta\n\n @staticmethod\n def n_gate(v):\n alpha = 0.01 * _vtrap(-(v + 55), 10)\n beta = 0.125 * save_exp(-(v + 65) / 80)\n return alpha, beta\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.compute_current","title":"compute_current(states, v, params)
","text":"Return current through HH channels.
Source code injaxley/channels/hh.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current through HH channels.\"\"\"\n prefix = self._name\n m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n\n gNa = params[f\"{prefix}_gNa\"] * (m**3) * h # S/cm^2\n gK = params[f\"{prefix}_gK\"] * n**4 # S/cm^2\n gLeak = params[f\"{prefix}_gLeak\"] # S/cm^2\n\n return (\n gNa * (v - params[f\"{prefix}_eNa\"])\n + gK * (v - params[f\"{prefix}_eK\"])\n + gLeak * (v - params[f\"{prefix}_eLeak\"])\n )\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/hh.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_m, beta_m = self.m_gate(v)\n alpha_h, beta_h = self.h_gate(v)\n alpha_n, beta_n = self.n_gate(v)\n return {\n f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n f\"{prefix}_n\": alpha_n / (alpha_n + beta_n),\n }\n
"},{"location":"reference/mechanisms/#jaxley.channels.hh.HH.update_states","title":"update_states(states, dt, v, params)
","text":"Return updated HH channel state.
Source code injaxley/channels/hh.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Return updated HH channel state.\"\"\"\n prefix = self._name\n m, h, n = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"], states[f\"{prefix}_n\"]\n new_m = solve_gate_exponential(m, dt, *self.m_gate(v))\n new_h = solve_gate_exponential(h, dt, *self.h_gate(v))\n new_n = solve_gate_exponential(n, dt, *self.n_gate(v))\n return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h, f\"{prefix}_n\": new_n}\n
"},{"location":"reference/mechanisms/#pospischil","title":"Pospischil","text":" Bases: Channel
Leak current
Source code injaxley/channels/pospischil.py
class Leak(Channel):\n \"\"\"Leak current\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gLeak\": 1e-4,\n f\"{prefix}_eLeak\": -70.0,\n }\n self.channel_states = {}\n self.current_name = f\"i_{prefix}\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"No state to update.\"\"\"\n return {}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n gLeak = params[f\"{prefix}_gLeak\"] # S/cm^2\n return gLeak * (v - params[f\"{prefix}_eLeak\"])\n\n def init_state(self, states, v, params, delta_t):\n return {}\n
Bases: Channel
Sodium channel
Source code injaxley/channels/pospischil.py
class Na(Channel):\n \"\"\"Sodium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gNa\": 50e-3,\n \"eNa\": 50.0,\n \"vt\": -60.0, # Global parameter, not prefixed with `Na`.\n }\n self.channel_states = {f\"{prefix}_m\": 0.2, f\"{prefix}_h\": 0.2}\n self.current_name = f\"i_Na\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n gNa = params[f\"{prefix}_gNa\"] * (m**3) * h # S/cm^2\n\n current = gNa * (v - params[\"eNa\"])\n return current\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n return {\n f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n }\n\n @staticmethod\n def m_gate(v, vt):\n v_alpha = v - vt - 13.0\n alpha = 0.32 * efun(-0.25 * v_alpha) / 0.25\n\n v_beta = v - vt - 40.0\n beta = 0.28 * efun(0.2 * v_beta) / 0.2\n return alpha, beta\n\n @staticmethod\n def h_gate(v, vt):\n v_alpha = v - vt - 17.0\n alpha = 0.128 * save_exp(-v_alpha / 18.0)\n\n v_beta = v - vt - 40.0\n beta = 4.0 / (save_exp(-v_beta / 5.0) + 1.0)\n return alpha, beta\n
Bases: Channel
Potassium channel
Source code injaxley/channels/pospischil.py
class K(Channel):\n \"\"\"Potassium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gK\": 5e-3,\n \"eK\": -90.0,\n \"vt\": -60.0, # Global parameter, not prefixed with `Na`.\n }\n self.channel_states = {f\"{prefix}_n\": 0.2}\n self.current_name = f\"i_K\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n n = states[f\"{prefix}_n\"]\n new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n return {f\"{prefix}_n\": new_n}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n n = states[f\"{prefix}_n\"]\n\n gK = params[f\"{prefix}_gK\"] * (n**4) # S/cm^2\n\n return gK * (v - params[\"eK\"])\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n\n @staticmethod\n def n_gate(v, vt):\n v_alpha = v - vt - 15.0\n alpha = 0.032 * efun(-0.2 * v_alpha) / 0.2\n\n v_beta = v - vt - 10.0\n beta = 0.5 * save_exp(-v_beta / 40.0)\n return alpha, beta\n
Bases: Channel
Slow M Potassium channel
Source code injaxley/channels/pospischil.py
class Km(Channel):\n \"\"\"Slow M Potassium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gKm\": 0.004e-3,\n f\"{prefix}_taumax\": 4000.0,\n f\"eK\": -90.0,\n }\n self.channel_states = {f\"{prefix}_p\": 0.2}\n self.current_name = f\"i_K\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n p = states[f\"{prefix}_p\"]\n new_p = solve_inf_gate_exponential(\n p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n )\n return {f\"{prefix}_p\": new_p}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n p = states[f\"{prefix}_p\"]\n\n gKm = params[f\"{prefix}_gKm\"] * p # S/cm^2\n return gKm * (v - params[\"eK\"])\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n\n @staticmethod\n def p_gate(v, taumax):\n v_p = v + 35.0\n p_inf = 1.0 / (1.0 + save_exp(-0.1 * v_p))\n\n tau_p = taumax / (3.3 * save_exp(0.05 * v_p) + save_exp(-0.05 * v_p))\n\n return p_inf, tau_p\n
Bases: Channel
L-type Calcium channel
Source code injaxley/channels/pospischil.py
class CaL(Channel):\n \"\"\"L-type Calcium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gCaL\": 0.1e-3,\n \"eCa\": 120.0,\n }\n self.channel_states = {f\"{prefix}_q\": 0.2, f\"{prefix}_r\": 0.2}\n self.current_name = f\"i_Ca\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r # S/cm^2\n\n return gCaL * (v - params[\"eCa\"])\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_q, beta_q = self.q_gate(v)\n alpha_r, beta_r = self.r_gate(v)\n return {\n f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n }\n\n @staticmethod\n def q_gate(v):\n v_alpha = -v - 27.0\n alpha = 0.055 * efun(v_alpha / 3.8) * 3.8\n\n v_beta = -v - 75.0\n beta = 0.94 * save_exp(v_beta / 17.0)\n return alpha, beta\n\n @staticmethod\n def r_gate(v):\n v_alpha = -v - 13.0\n alpha = 0.000457 * save_exp(v_alpha / 50)\n\n v_beta = -v - 15.0\n beta = 0.0065 / (save_exp(v_beta / 28.0) + 1)\n return alpha, beta\n
Bases: Channel
T-type Calcium channel
Source code injaxley/channels/pospischil.py
class CaT(Channel):\n \"\"\"T-type Calcium channel\"\"\"\n\n def __init__(self, name: Optional[str] = None):\n self.current_is_in_mA_per_cm2 = True\n\n super().__init__(name)\n prefix = self._name\n self.channel_params = {\n f\"{prefix}_gCaT\": 0.4e-4,\n f\"{prefix}_vx\": 2.0,\n \"eCa\": 120.0, # Global parameter, not prefixed with `CaT`.\n }\n self.channel_states = {f\"{prefix}_u\": 0.2}\n self.current_name = f\"i_Ca\"\n\n def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n ):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n u = states[f\"{prefix}_u\"]\n new_u = solve_inf_gate_exponential(\n u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n )\n return {f\"{prefix}_u\": new_u}\n\n def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n ):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n u = states[f\"{prefix}_u\"]\n s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u # S/cm^2\n\n return gCaT * (v - params[\"eCa\"])\n\n def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n\n @staticmethod\n def u_gate(v, vx):\n v_u1 = v + vx + 81.0\n u_inf = 1.0 / (1.0 + save_exp(v_u1 / 4))\n\n tau_u = (30.8 + (211.4 + save_exp((v + vx + 113.2) / 5.0))) / (\n 3.7 * (1 + save_exp((v + vx + 84.0) / 3.2))\n )\n\n return u_inf, tau_u\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n gLeak = params[f\"{prefix}_gLeak\"] # S/cm^2\n return gLeak * (v - params[f\"{prefix}_eLeak\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Leak.update_states","title":"update_states(states, dt, v, params)
","text":"No state to update.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"No state to update.\"\"\"\n return {}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n\n gNa = params[f\"{prefix}_gNa\"] * (m**3) * h # S/cm^2\n\n current = gNa * (v - params[\"eNa\"])\n return current\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_m, beta_m = self.m_gate(v, params[\"vt\"])\n alpha_h, beta_h = self.h_gate(v, params[\"vt\"])\n return {\n f\"{prefix}_m\": alpha_m / (alpha_m + beta_m),\n f\"{prefix}_h\": alpha_h / (alpha_h + beta_h),\n }\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Na.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n m, h = states[f\"{prefix}_m\"], states[f\"{prefix}_h\"]\n new_m = solve_gate_exponential(m, dt, *self.m_gate(v, params[\"vt\"]))\n new_h = solve_gate_exponential(h, dt, *self.h_gate(v, params[\"vt\"]))\n return {f\"{prefix}_m\": new_m, f\"{prefix}_h\": new_h}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n n = states[f\"{prefix}_n\"]\n\n gK = params[f\"{prefix}_gK\"] * (n**4) # S/cm^2\n\n return gK * (v - params[\"eK\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_n, beta_n = self.n_gate(v, params[\"vt\"])\n return {f\"{prefix}_n\": alpha_n / (alpha_n + beta_n)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.K.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n n = states[f\"{prefix}_n\"]\n new_n = solve_gate_exponential(n, dt, *self.n_gate(v, params[\"vt\"]))\n return {f\"{prefix}_n\": new_n}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n p = states[f\"{prefix}_p\"]\n\n gKm = params[f\"{prefix}_gKm\"] * p # S/cm^2\n return gKm * (v - params[\"eK\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_p, beta_p = self.p_gate(v, params[f\"{prefix}_taumax\"])\n return {f\"{prefix}_p\": alpha_p / (alpha_p + beta_p)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.Km.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n p = states[f\"{prefix}_p\"]\n new_p = solve_inf_gate_exponential(\n p, dt, *self.p_gate(v, params[f\"{prefix}_taumax\"])\n )\n return {f\"{prefix}_p\": new_p}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n gCaL = params[f\"{prefix}_gCaL\"] * (q**2) * r # S/cm^2\n\n return gCaL * (v - params[\"eCa\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_q, beta_q = self.q_gate(v)\n alpha_r, beta_r = self.r_gate(v)\n return {\n f\"{prefix}_q\": alpha_q / (alpha_q + beta_q),\n f\"{prefix}_r\": alpha_r / (alpha_r + beta_r),\n }\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaL.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n q, r = states[f\"{prefix}_q\"], states[f\"{prefix}_r\"]\n new_q = solve_gate_exponential(q, dt, *self.q_gate(v))\n new_r = solve_gate_exponential(r, dt, *self.r_gate(v))\n return {f\"{prefix}_q\": new_q, f\"{prefix}_r\": new_r}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.compute_current","title":"compute_current(states, v, params)
","text":"Return current.
Source code injaxley/channels/pospischil.py
def compute_current(\n self, states: Dict[str, jnp.ndarray], v, params: Dict[str, jnp.ndarray]\n):\n \"\"\"Return current.\"\"\"\n prefix = self._name\n u = states[f\"{prefix}_u\"]\n s_inf = 1.0 / (1.0 + save_exp(-(v + params[f\"{prefix}_vx\"] + 57.0) / 6.2))\n\n gCaT = params[f\"{prefix}_gCaT\"] * (s_inf**2) * u # S/cm^2\n\n return gCaT * (v - params[\"eCa\"])\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.init_state","title":"init_state(states, v, params, delta_t)
","text":"Initialize the state such at fixed point of gate dynamics.
Source code injaxley/channels/pospischil.py
def init_state(self, states, v, params, delta_t):\n \"\"\"Initialize the state such at fixed point of gate dynamics.\"\"\"\n prefix = self._name\n alpha_u, beta_u = self.u_gate(v, params[f\"{prefix}_vx\"])\n return {f\"{prefix}_u\": alpha_u / (alpha_u + beta_u)}\n
"},{"location":"reference/mechanisms/#jaxley.channels.pospischil.CaT.update_states","title":"update_states(states, dt, v, params)
","text":"Update state.
Source code injaxley/channels/pospischil.py
def update_states(\n self,\n states: Dict[str, jnp.ndarray],\n dt,\n v,\n params: Dict[str, jnp.ndarray],\n):\n \"\"\"Update state.\"\"\"\n prefix = self._name\n u = states[f\"{prefix}_u\"]\n new_u = solve_inf_gate_exponential(\n u, dt, *self.u_gate(v, params[f\"{prefix}_vx\"])\n )\n return {f\"{prefix}_u\": new_u}\n
"},{"location":"reference/mechanisms/#synapses","title":"Synapses","text":""},{"location":"reference/mechanisms/#synapse","title":"Synapse","text":"Base class for a synapse.
As in NEURON, a Synapse
is considered a point process, which means that its conductances are to be specified in uS
and its currents are to be specified in nA
.
jaxley/synapses/synapse.py
class Synapse:\n \"\"\"Base class for a synapse.\n\n As in NEURON, a `Synapse` is considered a point process, which means that its\n conductances are to be specified in `uS` and its currents are to be specified in\n `nA`.\n \"\"\"\n\n _name = None\n synapse_params = None\n synapse_states = None\n\n def __init__(self, name: Optional[str] = None):\n self._name = name if name else self.__class__.__name__\n\n @property\n def name(self) -> Optional[str]:\n return self._name\n\n def change_name(self, new_name: str):\n \"\"\"Change the synapse name.\n\n Args:\n new_name: The new name of the channel.\n\n Returns:\n Renamed channel, such that this function is chainable.\n \"\"\"\n old_prefix = self._name + \"_\"\n new_prefix = new_name + \"_\"\n\n self._name = new_name\n self.synapse_params = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.synapse_params.items()\n }\n\n self.synapse_states = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.synapse_states.items()\n }\n return self\n\n def update_states(\n states: Dict[str, jnp.ndarray],\n delta_t: float,\n pre_voltage: jnp.ndarray,\n post_voltage: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n ) -> Dict[str, jnp.ndarray]:\n \"\"\"ODE update step.\n\n Args:\n states: States of the synapse.\n delta_t: Time step in `ms`.\n pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n params: Parameters of the synapse. Conductances in `uS`.\n\n Returns:\n Updated states.\"\"\"\n raise NotImplementedError\n\n def compute_current(\n states: Dict[str, jnp.ndarray],\n pre_voltage: jnp.ndarray,\n post_voltage: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n ) -> jnp.ndarray:\n \"\"\"Return current through one synapse in `nA`.\n\n Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n Args:\n states: States of the synapse.\n pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n params: Parameters of the synapse. Conductances in `uS`.\n\n Returns:\n Current through the synapse in `nA`, shape `()`.\n \"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.change_name","title":"change_name(new_name)
","text":"Change the synapse name.
Parameters:
Name Type Description Defaultnew_name
str
The new name of the channel.
requiredReturns:
Type DescriptionRenamed channel, such that this function is chainable.
Source code injaxley/synapses/synapse.py
def change_name(self, new_name: str):\n \"\"\"Change the synapse name.\n\n Args:\n new_name: The new name of the channel.\n\n Returns:\n Renamed channel, such that this function is chainable.\n \"\"\"\n old_prefix = self._name + \"_\"\n new_prefix = new_name + \"_\"\n\n self._name = new_name\n self.synapse_params = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.synapse_params.items()\n }\n\n self.synapse_states = {\n (\n new_prefix + key[len(old_prefix) :]\n if key.startswith(old_prefix)\n else key\n ): value\n for key, value in self.synapse_states.items()\n }\n return self\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)
","text":"Return current through one synapse in nA
.
Internally, we use jax.vmap
to vectorize this function across many synapses.
Parameters:
Name Type Description Defaultstates
Dict[str, ndarray]
States of the synapse.
requiredpre_voltage
ndarray
Voltage of the presynaptic compartment, shape ()
.
post_voltage
ndarray
Voltage of the postsynaptic compartment, shape ()
.
params
Dict[str, ndarray]
Parameters of the synapse. Conductances in uS
.
Returns:
Type Descriptionndarray
Current through the synapse in nA
, shape ()
.
jaxley/synapses/synapse.py
def compute_current(\n states: Dict[str, jnp.ndarray],\n pre_voltage: jnp.ndarray,\n post_voltage: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n) -> jnp.ndarray:\n \"\"\"Return current through one synapse in `nA`.\n\n Internally, we use `jax.vmap` to vectorize this function across many synapses.\n\n Args:\n states: States of the synapse.\n pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n params: Parameters of the synapse. Conductances in `uS`.\n\n Returns:\n Current through the synapse in `nA`, shape `()`.\n \"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#jaxley.synapses.synapse.Synapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)
","text":"ODE update step.
Parameters:
Name Type Description Defaultstates
Dict[str, ndarray]
States of the synapse.
requireddelta_t
float
Time step in ms
.
pre_voltage
ndarray
Voltage of the presynaptic compartment, shape ()
.
post_voltage
ndarray
Voltage of the postsynaptic compartment, shape ()
.
params
Dict[str, ndarray]
Parameters of the synapse. Conductances in uS
.
Returns:
Type DescriptionDict[str, ndarray]
Updated states.
Source code injaxley/synapses/synapse.py
def update_states(\n states: Dict[str, jnp.ndarray],\n delta_t: float,\n pre_voltage: jnp.ndarray,\n post_voltage: jnp.ndarray,\n params: Dict[str, jnp.ndarray],\n) -> Dict[str, jnp.ndarray]:\n \"\"\"ODE update step.\n\n Args:\n states: States of the synapse.\n delta_t: Time step in `ms`.\n pre_voltage: Voltage of the presynaptic compartment, shape `()`.\n post_voltage: Voltage of the postsynaptic compartment, shape `()`.\n params: Parameters of the synapse. Conductances in `uS`.\n\n Returns:\n Updated states.\"\"\"\n raise NotImplementedError\n
"},{"location":"reference/mechanisms/#ionotropic-synapse","title":"Ionotropic Synapse","text":" Bases: Synapse
Compute synaptic current and update synapse state for a generic ionotropic synapse.
The synapse state \u201cs\u201d is the probability that a postsynaptic receptor channel is open, and this depends on the amount of neurotransmitter released, which is in turn dependent on the presynaptic voltage.
The synaptic parameters areL. F. Abbott and E. Marder, \u201cModeling Small Networks,\u201d in Methods in Neuronal Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.
Source code injaxley/synapses/ionotropic.py
class IonotropicSynapse(Synapse):\n \"\"\"\n Compute synaptic current and update synapse state for a generic ionotropic synapse.\n\n The synapse state \"s\" is the probability that a postsynaptic receptor channel is\n open, and this depends on the amount of neurotransmitter released, which is in turn\n dependent on the presynaptic voltage.\n\n The synaptic parameters are:\n - gS: the maximal conductance across the postsynaptic membrane (uS)\n - e_syn: the reversal potential across the postsynaptic membrane (mV)\n - k_minus: the rate constant of neurotransmitter unbinding from the postsynaptic\n receptor (s^-1)\n\n Details of this implementation can be found in the following book chapter:\n L. F. Abbott and E. Marder, \"Modeling Small Networks,\" in Methods in Neuronal\n Modeling, C. Koch and I. Sergev, Eds. Cambridge: MIT Press, 1998.\n\n \"\"\"\n\n def __init__(self, name: Optional[str] = None):\n super().__init__(name)\n prefix = self._name\n self.synapse_params = {\n f\"{prefix}_gS\": 1e-4,\n f\"{prefix}_e_syn\": 0.0,\n f\"{prefix}_k_minus\": 0.025,\n }\n self.synapse_states = {f\"{prefix}_s\": 0.2}\n\n def update_states(\n self,\n states: Dict,\n delta_t: float,\n pre_voltage: float,\n post_voltage: float,\n params: Dict,\n ) -> Dict:\n \"\"\"Return updated synapse state and current.\"\"\"\n prefix = self._name\n v_th = -35.0 # mV\n delta = 10.0 # mV\n\n s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n slope = -1.0 / tau_s\n exp_term = save_exp(slope * delta_t)\n new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n return {f\"{prefix}_s\": new_s}\n\n def compute_current(\n self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n ) -> float:\n prefix = self._name\n g_syn = params[f\"{prefix}_gS\"] * states[f\"{prefix}_s\"]\n return g_syn * (post_voltage - params[f\"{prefix}_e_syn\"])\n
"},{"location":"reference/mechanisms/#jaxley.synapses.ionotropic.IonotropicSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)
","text":"Return updated synapse state and current.
Source code injaxley/synapses/ionotropic.py
def update_states(\n self,\n states: Dict,\n delta_t: float,\n pre_voltage: float,\n post_voltage: float,\n params: Dict,\n) -> Dict:\n \"\"\"Return updated synapse state and current.\"\"\"\n prefix = self._name\n v_th = -35.0 # mV\n delta = 10.0 # mV\n\n s_inf = 1.0 / (1.0 + save_exp((v_th - pre_voltage) / delta))\n tau_s = (1.0 - s_inf) / params[f\"{prefix}_k_minus\"]\n\n slope = -1.0 / tau_s\n exp_term = save_exp(slope * delta_t)\n new_s = states[f\"{prefix}_s\"] * exp_term + s_inf * (1.0 - exp_term)\n return {f\"{prefix}_s\": new_s}\n
"},{"location":"reference/mechanisms/#tanh-rate-synapse","title":"TanH Rate Synapse","text":" Bases: Synapse
Compute synaptic current for tanh synapse (no state).
Source code injaxley/synapses/tanh_rate.py
class TanhRateSynapse(Synapse):\n \"\"\"\n Compute synaptic current for tanh synapse (no state).\n \"\"\"\n\n def __init__(self, name: Optional[str] = None):\n super().__init__(name)\n prefix = self._name\n self.synapse_params = {\n f\"{prefix}_gS\": 1e-4,\n f\"{prefix}_x_offset\": -70.0,\n f\"{prefix}_slope\": 1.0,\n }\n self.synapse_states = {}\n\n def update_states(\n self,\n states: Dict,\n delta_t: float,\n pre_voltage: float,\n post_voltage: float,\n params: Dict,\n ) -> Dict:\n \"\"\"Return updated synapse state and current.\"\"\"\n return {}\n\n def compute_current(\n self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n ) -> float:\n \"\"\"Return updated synapse state and current.\"\"\"\n prefix = self._name\n current = (\n -1\n * params[f\"{prefix}_gS\"]\n * jnp.tanh(\n (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n )\n )\n return current\n
"},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.compute_current","title":"compute_current(states, pre_voltage, post_voltage, params)
","text":"Return updated synapse state and current.
Source code injaxley/synapses/tanh_rate.py
def compute_current(\n self, states: Dict, pre_voltage: float, post_voltage: float, params: Dict\n) -> float:\n \"\"\"Return updated synapse state and current.\"\"\"\n prefix = self._name\n current = (\n -1\n * params[f\"{prefix}_gS\"]\n * jnp.tanh(\n (pre_voltage - params[f\"{prefix}_x_offset\"]) * params[f\"{prefix}_slope\"]\n )\n )\n return current\n
"},{"location":"reference/mechanisms/#jaxley.synapses.tanh_rate.TanhRateSynapse.update_states","title":"update_states(states, delta_t, pre_voltage, post_voltage, params)
","text":"Return updated synapse state and current.
Source code injaxley/synapses/tanh_rate.py
def update_states(\n self,\n states: Dict,\n delta_t: float,\n pre_voltage: float,\n post_voltage: float,\n params: Dict,\n) -> Dict:\n \"\"\"Return updated synapse state and current.\"\"\"\n return {}\n
"},{"location":"reference/modules/","title":"Modules","text":""},{"location":"reference/modules/#module","title":"Module","text":" Bases: ABC
Module base class.
Modules are everything that can be passed to jx.integrate
, i.e. compartments, branches, cells, and networks.
This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks).
Modules can be traversed and modified using the at
, cell
, branch
, comp
, edge
, and loc
methods. The scope
method can be used to toggle between global and local indices. Traversal of Modules will return a View
of itself, that has a modified set of attributes, which only consider the part of the Module that is in view.
For developers: The above has consequences for how to operate on Module
and which changes take affect where. The following guidelines should be followed (copied from View
):
self.base
. In order to enssure that these changes only affects whatever is currently in view self._nodes_in_view
, or self._edges_in_view
among others have to be used. Operating on nodes currently in view can for example be done with self.base.node.loc[self._nodes_in_view]
.xyzr
, needs to modified when View is instantiated. I.e. xyzr
of cell.branch(0)
, should be [self.base.xyzr[0]]
This could be achieved via: [self.base.xyzr[b] for b in self._branches_in_view]
.For developers: If you want to add a new method to Module
, here is an example of how to make methods of Module compatible with View:
.. code-block:: python
# Use data in view to return something.\ndef count_small_branches(self):\n # no need to use self.base.attr + viewed indices,\n # since no change is made to the attr in question (nodes)\n comp_lens = self.nodes[\"length\"]\n branch_lens = comp_lens.groupby(\"global_branch_index\").sum()\n return np.sum(branch_lens < 10)\n\n# Change data in view.\ndef change_attr_in_view(self):\n # changes to attrs have to be made via self.base.attr + viewed indices\n a = func1(self.base.attr1[self._cells_in_view])\n b = func2(self.base.attr2[self._edges_in_view])\n self.base.attr3[self._branches_in_view] = a + b\n
Source code in jaxley/modules/base.py
class Module(ABC):\n \"\"\"Module base class.\n\n Modules are everything that can be passed to `jx.integrate`, i.e. compartments,\n branches, cells, and networks.\n\n This base class defines the scaffold for all jaxley modules (compartments,\n branches, cells, networks).\n\n Modules can be traversed and modified using the `at`, `cell`, `branch`, `comp`,\n `edge`, and `loc` methods. The `scope` method can be used to toggle between\n global and local indices. Traversal of Modules will return a `View` of itself,\n that has a modified set of attributes, which only consider the part of the Module\n that is in view.\n\n For developers: The above has consequences for how to operate on `Module` and which\n changes take affect where. The following guidelines should be followed (copied from\n `View`):\n\n 1. We consider a Module to have everything in view.\n 2. Views can display and keep track of how a module is traversed. But(!),\n do not support making changes or setting variables. This still has to be\n done in the base Module, i.e. `self.base`. In order to enssure that these\n changes only affects whatever is currently in view `self._nodes_in_view`,\n or `self._edges_in_view` among others have to be used. Operating on nodes\n currently in view can for example be done with\n `self.base.node.loc[self._nodes_in_view]`.\n 3. Every attribute of Module that changes based on what's in view, i.e. `xyzr`,\n needs to modified when View is instantiated. I.e. `xyzr` of `cell.branch(0)`,\n should be `[self.base.xyzr[0]]` This could be achieved via:\n `[self.base.xyzr[b] for b in self._branches_in_view]`.\n\n For developers: If you want to add a new method to `Module`, here is an example of\n how to make methods of Module compatible with View:\n\n .. code-block:: python\n\n # Use data in view to return something.\n def count_small_branches(self):\n # no need to use self.base.attr + viewed indices,\n # since no change is made to the attr in question (nodes)\n comp_lens = self.nodes[\"length\"]\n branch_lens = comp_lens.groupby(\"global_branch_index\").sum()\n return np.sum(branch_lens < 10)\n\n # Change data in view.\n def change_attr_in_view(self):\n # changes to attrs have to be made via self.base.attr + viewed indices\n a = func1(self.base.attr1[self._cells_in_view])\n b = func2(self.base.attr2[self._edges_in_view])\n self.base.attr3[self._branches_in_view] = a + b\n \"\"\"\n\n def __init__(self):\n self.ncomp: int = None\n self.total_nbranches: int = 0\n self.nbranches_per_cell: List[int] = None\n\n self.groups = {}\n\n self.nodes: Optional[pd.DataFrame] = None\n self._scope = \"local\" # defaults to local scope\n self._nodes_in_view: np.ndarray = None\n self._edges_in_view: np.ndarray = None\n\n self.edges = pd.DataFrame(\n columns=[\n \"global_edge_index\",\n \"pre_global_comp_index\",\n \"post_global_comp_index\",\n \"pre_locs\",\n \"post_locs\",\n \"type\",\n \"type_ind\",\n ]\n )\n\n self._cumsum_nbranches: Optional[np.ndarray] = None\n\n self.comb_parents: jnp.ndarray = jnp.asarray([-1])\n\n self.initialized_morph: bool = False\n self.initialized_syns: bool = False\n\n # List of all types of `jx.Synapse`s.\n self.synapses: List = []\n self.synapse_param_names = []\n self.synapse_state_names = []\n self.synapse_names = []\n\n # List of types of all `jx.Channel`s.\n self.channels: List[Channel] = []\n self.membrane_current_names: List[str] = []\n\n # For trainable parameters.\n self.indices_set_by_trainables: List[jnp.ndarray] = []\n self.trainable_params: List[Dict[str, jnp.ndarray]] = []\n self.allow_make_trainable: bool = True\n self.num_trainable_params: int = 0\n\n # For recordings.\n self.recordings: pd.DataFrame = pd.DataFrame().from_dict({})\n\n # For stimuli or clamps.\n # E.g. `self.externals = {\"v\": zeros(1000,2), \"i\": ones(1000, 2)}`\n # for 1000 timesteps and two compartments.\n self.externals: Dict[str, jnp.ndarray] = {}\n # E.g. `self.external)inds = {\"v\": jnp.asarray([0,1]), \"i\": jnp.asarray([2,3])}`\n self.external_inds: Dict[str, jnp.ndarray] = {}\n\n # x, y, z coordinates and radius.\n self.xyzr: List[np.ndarray] = []\n self._radius_generating_fns = None # Defined by `.read_swc()`.\n\n # For debugging the solver. Will be empty by default and only filled if\n # `self._init_morph_for_debugging` is run.\n self.debug_states = {}\n\n # needs to be set at the end\n self.base: Module = self\n\n def __repr__(self):\n return f\"{type(self).__name__} with {len(self.channels)} different channels. Use `.nodes` for details.\"\n\n def __str__(self):\n return f\"jx.{type(self).__name__}\"\n\n def __dir__(self):\n base_dir = object.__dir__(self)\n return sorted(base_dir + self.synapse_names + list(self.group_nodes.keys()))\n\n def __getattr__(self, key):\n # Ensure that hidden methods such as `__deepcopy__` still work.\n if key.startswith(\"__\"):\n return super().__getattribute__(key)\n\n # intercepts calls to groups\n if key in self.base.groups:\n view = (\n self.select(self.groups[key])\n if key in self.groups\n else self.select(None)\n )\n view._set_controlled_by_param(key)\n return view\n\n # intercepts calls to channels\n if key in [c._name for c in self.base.channels]:\n channel_names = [c._name for c in self.channels]\n inds = self.nodes.index[self.nodes[key]].to_numpy()\n view = self.select(inds) if key in channel_names else self.select(None)\n view._set_controlled_by_param(key)\n return view\n\n # intercepts calls to synapse types\n if key in self.base.synapse_names:\n syn_inds = self.edges[self.edges[\"type\"] == key][\n \"global_edge_index\"\n ].to_numpy()\n orig_scope = self._scope\n view = (\n self.scope(\"global\").edge(syn_inds).scope(orig_scope)\n if key in self.synapse_names\n else self.select(None)\n )\n view._set_controlled_by_param(key) # overwrites param set by edge\n # Ensure synapse param sharing works with `edge`\n # `edge` will be removed as part of #463\n view.edges[\"local_edge_index\"] = np.arange(len(view.edges))\n return view\n\n def _childviews(self) -> List[str]:\n \"\"\"Returns levels that module can be viewed at.\n\n I.e. for net -> [cell, branch, comp]. For branch -> [comp]\"\"\"\n levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n if self._current_view in levels:\n children = levels[levels.index(self._current_view) + 1 :]\n return children\n return []\n\n def _has_childview(self, key: str) -> bool:\n child_views = self._childviews()\n return key in child_views\n\n def __getitem__(self, index):\n \"\"\"Lazy indexing of the module.\"\"\"\n supported_parents = [\"network\", \"cell\", \"branch\"] # cannot index into comp\n\n not_group_view = self._current_view not in self.groups\n assert (\n self._current_view in supported_parents or not_group_view\n ), \"Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof.\"\n index = index if isinstance(index, tuple) else (index,)\n\n child_views = self._childviews()\n assert len(index) <= len(child_views), \"Too many indices.\"\n view = self\n for i, child in zip(index, child_views):\n view = view._at_nodes(child, i)\n return view\n\n def _update_local_indices(self) -> pd.DataFrame:\n \"\"\"Compute local indices from the global indices that are in view.\n This is recomputed everytime a View is created.\"\"\"\n rerank = lambda df: df.rank(method=\"dense\").astype(int) - 1\n\n def reorder_cols(\n df: pd.DataFrame, cols: List[str], first: bool = True\n ) -> pd.DataFrame:\n \"\"\"Move cols to front/back.\n\n Args:\n df: DataFrame to reorder.\n cols: List of columns to place before/after remaining columns.\n first: If True, cols are placed in front, otherwise at the end.\n\n Returns:\n DataFrame with reordered columns.\"\"\"\n new_cols = [col for col in df.columns if first == (col in cols)]\n new_cols += [col for col in df.columns if first != (col in cols)]\n return df[new_cols]\n\n def reindex_a_by_b(\n df: pd.DataFrame, a: str, b: Optional[Union[str, List[str]]] = None\n ) -> pd.DataFrame:\n \"\"\"Reindex based on a different col or several columns\n for b=[0,0,1,1,2,2,2] -> a=[0,1,0,1,0,1,2]\"\"\"\n grouped_df = df.groupby(b) if b is not None else df\n df.loc[:, a] = rerank(grouped_df[a])\n return df\n\n index_names = [\"cell_index\", \"branch_index\", \"comp_index\"] # order is important\n global_idx_cols = [f\"global_{name}\" for name in index_names]\n local_idx_cols = [f\"local_{name}\" for name in index_names]\n idcs = self.nodes[global_idx_cols]\n\n # update local indices of nodes\n idcs = reindex_a_by_b(idcs, global_idx_cols[0])\n idcs = reindex_a_by_b(idcs, global_idx_cols[1], global_idx_cols[0])\n idcs = reindex_a_by_b(idcs, global_idx_cols[2], global_idx_cols[:2])\n idcs.columns = [col.replace(\"global\", \"local\") for col in global_idx_cols]\n self.nodes[local_idx_cols] = idcs[local_idx_cols].astype(int)\n\n # move indices to the front of the dataframe; move controlled_by_param to the end\n # move indices of current scope to the front and the others to the back\n not_scope = \"global\" if self._scope == \"local\" else \"local\"\n self.nodes = reorder_cols(\n self.nodes, [f\"{self._scope}_{name}\" for name in index_names], first=True\n )\n self.nodes = reorder_cols(\n self.nodes, [f\"{not_scope}_{name}\" for name in index_names], first=False\n )\n\n self.edges = reorder_cols(self.edges, [\"global_edge_index\"])\n self.nodes = reorder_cols(self.nodes, [\"controlled_by_param\"], first=False)\n self.edges = reorder_cols(self.edges, [\"controlled_by_param\"], first=False)\n\n def _init_view(self):\n \"\"\"Init attributes critical for View.\n\n Needs to be called at init of a Module.\"\"\"\n parent = self.__class__.__name__.lower()\n self._current_view = \"comp\" if parent == \"compartment\" else parent\n self._nodes_in_view = self.nodes.index.to_numpy()\n self._edges_in_view = self.edges.index.to_numpy()\n self.nodes[\"controlled_by_param\"] = 0\n\n def _compute_coords_of_comp_centers(self) -> np.ndarray:\n \"\"\"Compute xyz coordinates of compartment centers.\n\n Centers are the midpoint between the comparment endpoints on the morphology\n as defined by xyzr.\n\n Note: For sake of performance, interpolation is not done for each branch\n individually, but only once along a concatenated (and padded) array of all branches.\n This means for ncomps = [2,4] and normalized cum_branch_lens of [[0,1],[0,1]] we would\n interpolate xyz at the locations comp_ends = [[0,0.5,1], [0,0.25,0.5,0.75,1]],\n where 0 is the start of the branch and 1 is the end point at the full branch_len.\n To avoid do this in one go we set comp_ends = [0,0.5,1,2,2.25,2.5,2.75,3], and\n norm_cum_branch_len = [0,1,2,3] incrememting and also padding them by 1 to\n avoid overlapping branch_lens i.e. norm_cum_branch_len = [0,1,1,2] for only\n incrementing.\n \"\"\"\n nodes_by_branches = self.nodes.groupby(\"global_branch_index\")\n ncomps = nodes_by_branches[\"global_comp_index\"].nunique().to_numpy()\n\n comp_ends = [\n np.linspace(0, 1, ncomp + 1) + 2 * i for i, ncomp in enumerate(ncomps)\n ]\n comp_ends = np.hstack(comp_ends)\n\n comp_ends = comp_ends.reshape(-1)\n cum_branch_lens = []\n for i, xyzr in enumerate(self.xyzr):\n branch_len = np.sqrt(np.sum(np.diff(xyzr[:, :3], axis=0) ** 2, axis=1))\n cum_branch_len = np.cumsum(np.concatenate([np.array([0]), branch_len]))\n max_len = cum_branch_len.max()\n # add padding like above\n cum_branch_len = cum_branch_len / (max_len if max_len > 0 else 1) + 2 * i\n cum_branch_len[np.isnan(cum_branch_len)] = 0\n cum_branch_lens.append(cum_branch_len)\n cum_branch_lens = np.hstack(cum_branch_lens)\n xyz = np.vstack(self.xyzr)[:, :3]\n xyz = v_interp(comp_ends, cum_branch_lens, xyz).T\n centers = (xyz[:-1] + xyz[1:]) / 2 # unaware of inter vs intra comp centers\n cum_ncomps = np.cumsum(ncomps)\n # this means centers between comps have to be removed here\n between_comp_inds = (cum_ncomps + np.arange(len(cum_ncomps)))[:-1]\n centers = np.delete(centers, between_comp_inds, axis=0)\n return centers\n\n def compute_compartment_centers(self):\n \"\"\"Add compartment centers to nodes dataframe\"\"\"\n centers = self._compute_coords_of_comp_centers()\n self.base.nodes.loc[self._nodes_in_view, [\"x\", \"y\", \"z\"]] = centers\n\n def _reformat_index(self, idx: Any, dtype: type = int) -> np.ndarray:\n \"\"\"Transforms different types of indices into an array.\n\n Takes slice, list, array, ints, range and None and transforms\n it into array of indices. If index == \"all\" it returns \"all\"\n to be handled downstream.\n\n Args:\n idx: index that specifies at which locations to view the module.\n dtype: defaults to int, but can also reformat float for use in `loc`\n\n Returns:\n array of indices of shape (N,)\"\"\"\n if is_str_all(idx): # also asserts that the only allowed str == \"all\"\n return idx\n\n np_dtype = np.int64 if dtype is int else np.float64\n idx = np.array([], dtype=dtype) if idx is None else idx\n idx = np.array([idx]) if isinstance(idx, (dtype, np_dtype)) else idx\n idx = np.array(idx) if isinstance(idx, (list, range, pd.Index)) else idx\n\n idx = np.arange(len(self.base.nodes))[idx] if isinstance(idx, slice) else idx\n if idx.dtype == bool:\n shape = (*self.shape, len(self.edges))\n which_idx = len(idx) == np.array(shape)\n assert np.any(which_idx), \"Index not matching num of cells/branches/comps.\"\n dim = shape[np.where(which_idx)[0][0]]\n idx = np.arange(dim)[idx]\n assert isinstance(idx, np.ndarray), \"Invalid type\"\n assert idx.dtype in [np_dtype, bool], \"Invalid dtype\"\n return idx.reshape(-1)\n\n def _set_controlled_by_param(self, key: str):\n \"\"\"Determines which parameters are shared in `make_trainable`.\n\n Adds column to nodes/edges dataframes to read of shared params from.\n\n Args:\n key: key specifying group / view that is in control of the params.\"\"\"\n if key in [\"comp\", \"branch\", \"cell\"]:\n self.nodes[\"controlled_by_param\"] = self.nodes[f\"global_{key}_index\"]\n self.edges[\"controlled_by_param\"] = 0\n elif key == \"edge\":\n self.edges[\"controlled_by_param\"] = np.arange(len(self.edges))\n elif key == \"filter\":\n self.nodes[\"controlled_by_param\"] = np.arange(len(self.nodes))\n self.edges[\"controlled_by_param\"] = np.arange(len(self.edges))\n else:\n self.nodes[\"controlled_by_param\"] = 0\n self.edges[\"controlled_by_param\"] = 0\n self._current_view = key\n\n def select(\n self, nodes: np.ndarray = None, edges: np.ndarray = None, sorted: bool = False\n ) -> View:\n \"\"\"Return View of the module filtered by specific node or edges indices.\n\n Args:\n nodes: indices of nodes to view. If None, all nodes are viewed.\n edges: indices of edges to view. If None, all edges are viewed.\n sorted: if True, nodes and edges are sorted.\n\n Returns:\n View for subset of selected nodes and/or edges.\"\"\"\n\n nodes = self._reformat_index(nodes) if nodes is not None else None\n nodes = self._nodes_in_view if is_str_all(nodes) else nodes\n nodes = np.sort(nodes) if sorted else nodes\n\n edges = self._reformat_index(edges) if edges is not None else None\n edges = self._edges_in_view if is_str_all(edges) else edges\n edges = np.sort(edges) if sorted else edges\n\n view = View(self, nodes, edges)\n view._set_controlled_by_param(\"filter\")\n return view\n\n def set_scope(self, scope: str):\n \"\"\"Toggle between \"global\" or \"local\" scope.\n\n Determines if global or local indices are used for viewing the module.\n\n Args:\n scope: either \"global\" or \"local\".\"\"\"\n assert scope in [\"global\", \"local\"], \"Invalid scope.\"\n self._scope = scope\n\n def scope(self, scope: str) -> View:\n \"\"\"Return a View of the module with the specified scope.\n\n For example `cell.scope(\"global\").branch(2).scope(\"local\").comp(1)`\n will return the 1st compartment of branch 2.\n\n Args:\n scope: either \"global\" or \"local\".\n\n Returns:\n View with the specified scope.\"\"\"\n view = self.view\n view.set_scope(scope)\n return view\n\n def _at_nodes(self, key: str, idx: Any) -> View:\n \"\"\"Return a View of the module filtering `nodes` by specified key and index.\n\n Keys can be `cell`, `branch`, `comp` and determine which index is used to filter.\n \"\"\"\n base_name = self.base.__class__.__name__\n assert self.base._has_childview(key), f\"{base_name} does not support {key}.\"\n idx = self._reformat_index(idx)\n idx = self.nodes[self._scope + f\"_{key}_index\"] if is_str_all(idx) else idx\n where = self.nodes[self._scope + f\"_{key}_index\"].isin(idx)\n inds = self.nodes.index[where].to_numpy()\n\n view = View(self, nodes=inds)\n view._set_controlled_by_param(key)\n return view\n\n def _at_edges(self, key: str, idx: Any) -> View:\n \"\"\"Return a View of the module filtering `edges` by specified key and index.\n\n Keys can be `pre`, `post`, `edge` and determine which index is used to filter.\n \"\"\"\n idx = self._reformat_index(idx)\n idx = self.edges[self._scope + f\"_{key}_index\"] if is_str_all(idx) else idx\n where = self.edges[self._scope + f\"_{key}_index\"].isin(idx)\n inds = self.edges.index[where].to_numpy()\n\n view = View(self, edges=inds)\n view._set_controlled_by_param(key)\n return view\n\n def cell(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected cell(s).\n\n Args:\n idx: index of the cell to view.\n\n Returns:\n View of the module at the specified cell index.\"\"\"\n return self._at_nodes(\"cell\", idx)\n\n def branch(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected branches(s).\n\n Args:\n idx: index of the branch to view.\n\n Returns:\n View of the module at the specified branch index.\"\"\"\n return self._at_nodes(\"branch\", idx)\n\n def comp(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected compartments(s).\n\n Args:\n idx: index of the comp to view.\n\n Returns:\n View of the module at the specified compartment index.\"\"\"\n return self._at_nodes(\"comp\", idx)\n\n def edge(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected synapse edges(s).\n\n Args:\n idx: index of the edge to view.\n\n Returns:\n View of the module at the specified edge index.\"\"\"\n return self._at_edges(\"edge\", idx)\n\n def loc(self, at: Any) -> View:\n \"\"\"Return a View of the module at the selected branch location(s).\n\n Args:\n at: location along the branch.\n\n Returns:\n View of the module at the specified branch location.\"\"\"\n global_comp_idxs = []\n for i in self._branches_in_view:\n ncomp = self.base.ncomp_per_branch[i]\n comp_locs = np.linspace(0, 1, ncomp)\n at = comp_locs if is_str_all(at) else self._reformat_index(at, dtype=float)\n comp_edges = np.linspace(0, 1 + 1e-10, ncomp + 1)\n idx = np.digitize(at, comp_edges) - 1 + self.base.cumsum_ncomp[i]\n global_comp_idxs.append(idx)\n global_comp_idxs = np.concatenate(global_comp_idxs)\n orig_scope = self._scope\n # global scope needed to select correct comps, for i.e. branches w. ncomp=[1,2]\n # loc(0.9) will correspond to different local branches (0 vs 1).\n view = self.scope(\"global\").comp(global_comp_idxs).scope(orig_scope)\n view._current_view = \"loc\"\n return view\n\n @property\n def _comps_in_view(self):\n \"\"\"Lists the global compartment indices which are currently part of the view.\"\"\"\n # method also exists in View. this copy forgoes need to instantiate a View\n return self.nodes[\"global_comp_index\"].unique()\n\n @property\n def _branches_in_view(self):\n \"\"\"Lists the global branch indices which are currently part of the view.\"\"\"\n # method also exists in View. this copy forgoes need to instantiate a View\n return self.nodes[\"global_branch_index\"].unique()\n\n @property\n def _cells_in_view(self):\n \"\"\"Lists the global cell indices which are currently part of the view.\"\"\"\n # method also exists in View. this copy forgoes need to instantiate a View\n return self.nodes[\"global_cell_index\"].unique()\n\n def _iter_submodules(self, name: str):\n \"\"\"Iterate over submoduleslevel.\n\n Used for `cells`, `branches`, `comps`.\"\"\"\n col = self._scope + f\"_{name}_index\"\n idxs = self.nodes[col].unique()\n for idx in idxs:\n yield self._at_nodes(name, idx)\n\n @property\n def cells(self):\n \"\"\"Iterate over all cells in the module.\n\n Returns a generator that yields a View of each cell.\"\"\"\n yield from self._iter_submodules(\"cell\")\n\n @property\n def branches(self):\n \"\"\"Iterate over all branches in the module.\n\n Returns a generator that yields a View of each branch.\"\"\"\n yield from self._iter_submodules(\"branch\")\n\n @property\n def comps(self):\n \"\"\"Iterate over all compartments in the module.\n Can be called on any module, i.e. `net.comps`, `cell.comps` or\n `branch.comps`. `__iter__` does not allow for this.\n\n Returns a generator that yields a View of each compartment.\"\"\"\n yield from self._iter_submodules(\"comp\")\n\n def __iter__(self):\n \"\"\"Iterate over parts of the module.\n\n Internally calls `cells`, `branches`, `comps` at the appropriate level.\n\n Example:\n\n .. code-block:: python\n\n for cell in network:\n for branch in cell:\n for comp in branch:\n print(comp.nodes.shape)\n \"\"\"\n next_level = self._childviews()[0]\n yield from self._iter_submodules(next_level)\n\n @property\n def shape(self) -> Tuple[int]:\n \"\"\"Returns the number of submodules contained in a module.\n\n .. code-block:: python\n\n network.shape = (num_cells, num_branches, num_compartments)\n cell.shape = (num_branches, num_compartments)\n branch.shape = (num_compartments,)\n \"\"\"\n cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n raw_shape = self.nodes[cols].nunique().to_list()\n\n # ensure (net.shape -> dim=3, cell.shape -> dim=2, branch.shape -> dim=1, comp.shape -> dim=0)\n levels = [\"network\", \"cell\", \"branch\", \"comp\"]\n module = self.base.__class__.__name__.lower()\n module = \"comp\" if module == \"compartment\" else module\n shape = tuple(raw_shape[levels.index(module) :])\n return shape\n\n def copy(\n self, reset_index: bool = False, as_module: bool = False\n ) -> Union[Module, View]:\n \"\"\"Extract part of a module and return a copy of its View or a new module.\n\n This can be used to call `jx.integrate` on part of a Module.\n\n Args:\n reset_index: if True, the indices of the new module are reset to start from 0.\n as_module: if True, a new module is returned instead of a View.\n\n Returns:\n A part of the module or a copied view of it.\"\"\"\n view = deepcopy(self)\n warnings.warn(\"This method is experimental, use at your own risk.\")\n # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they\n # start from 0/-1 and are contiguous\n if as_module:\n raise NotImplementedError(\"Not yet implemented.\")\n # initialize a new module with the same attributes\n return view\n\n @property\n def view(self):\n \"\"\"Return view of the module.\"\"\"\n return View(self, self._nodes_in_view, self._edges_in_view)\n\n @property\n def _module_type(self):\n \"\"\"Return type of the module (compartment, branch, cell, network) as string.\n\n This is used to perform asserts for some modules (e.g. network cannot use\n `set_ncomp`) without having to import the module in `base.py`.\"\"\"\n return self.__class__.__name__.lower()\n\n def _append_params_and_states(self, param_dict: Dict, state_dict: Dict):\n \"\"\"Insert the default params of the module (e.g. radius, length).\n\n This is run at `__init__()`. It does not deal with channels.\n \"\"\"\n for param_name, param_value in param_dict.items():\n self.base.nodes[param_name] = param_value\n for state_name, state_value in state_dict.items():\n self.base.nodes[state_name] = state_value\n\n def _gather_channels_from_constituents(self, constituents: List):\n \"\"\"Modify `self.channels` and `self.nodes` with channel info from constituents.\n\n This is run at `__init__()`. It takes all branches of constituents (e.g.\n of all branches when the are assembled into a cell) and adds columns to\n `.nodes` for the relevant channels.\n \"\"\"\n for module in constituents:\n for channel in module.channels:\n if channel._name not in [c._name for c in self.channels]:\n self.base.channels.append(channel)\n if channel.current_name not in self.membrane_current_names:\n self.base.membrane_current_names.append(channel.current_name)\n # Setting columns of channel names to `False` instead of `NaN`.\n for channel in self.base.channels:\n name = channel._name\n self.base.nodes.loc[self.nodes[name].isna(), name] = False\n\n @only_allow_module\n def to_jax(self):\n # TODO FROM #447: Make this work for View?\n \"\"\"Move `.nodes` to `.jaxnodes`.\n\n Before the actual simulation is run (via `jx.integrate`), all parameters of\n the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n they can be processed on GPU/TPU and such that the simulation can be\n differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n \"\"\"\n self.base.jaxnodes = {}\n for key, value in self.base.nodes.to_dict(orient=\"list\").items():\n inds = jnp.arange(len(value))\n self.base.jaxnodes[key] = jnp.asarray(value)[inds]\n\n # `jaxedges` contains only parameters (no indices).\n # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n # we allow parameter sharing.\n self.base.jaxedges = {}\n edges = self.base.edges.to_dict(orient=\"list\")\n for i, synapse in enumerate(self.base.synapses):\n condition = np.asarray(edges[\"type_ind\"]) == i\n for key in synapse.synapse_params:\n self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n for key in synapse.synapse_states:\n self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n\n def show(\n self,\n param_names: Optional[Union[str, List[str]]] = None,\n *,\n indices: bool = True,\n params: bool = True,\n states: bool = True,\n channel_names: Optional[List[str]] = None,\n ) -> pd.DataFrame:\n \"\"\"Print detailed information about the Module or a view of it.\n\n Args:\n param_names: The names of the parameters to show. If `None`, all parameters\n are shown.\n indices: Whether to show the indices of the compartments.\n params: Whether to show the parameters of the compartments.\n states: Whether to show the states of the compartments.\n channel_names: The names of the channels to show. If `None`, all channels are\n shown.\n\n Returns:\n A `pd.DataFrame` with the requested information.\n \"\"\"\n nodes = self.nodes.copy() # prevents this from being edited\n\n cols = []\n inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n scopes = [\"local\", \"global\"]\n inds = [f\"{s}_{i}\" for i in inds for s in scopes] if indices else []\n cols += inds\n cols += [ch._name for ch in self.channels] if channel_names else []\n cols += (\n sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n )\n cols += (\n sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n )\n\n if not param_names is None:\n cols = (\n inds + [c for c in cols if c in param_names]\n if params\n else list(param_names)\n )\n\n return nodes[cols]\n\n @only_allow_module\n def _init_morph(self):\n \"\"\"Initialize the morphology such that it can be processed by the solvers.\"\"\"\n self._init_morph_jaxley_spsolve()\n self._init_morph_jax_spsolve()\n self.initialized_morph = True\n\n @abstractmethod\n def _init_morph_jax_spsolve(self):\n \"\"\"Initialize the morphology for the JAX sparse solver.\"\"\"\n raise NotImplementedError\n\n @abstractmethod\n def _init_morph_jaxley_spsolve(self):\n \"\"\"Initialize the morphology for the custom Jaxley solver.\"\"\"\n raise NotImplementedError\n\n def _compute_axial_conductances(self, params: Dict[str, jnp.ndarray]):\n \"\"\"Given radius, length, r_a, compute the axial coupling conductances.\"\"\"\n return compute_axial_conductances(self._comp_edges, params)\n\n def set(self, key: str, val: Union[float, jnp.ndarray]):\n \"\"\"Set parameter of module (or its view) to a new value.\n\n Note that this function can not be called within `jax.jit` or `jax.grad`.\n Instead, it should be used set the parameters of the module **before** the\n simulation. Use `.data_set()` to set parameters during `jax.jit` or\n `jax.grad`.\n\n Args:\n key: The name of the parameter to set.\n val: The value to set the parameter to. If it is `jnp.ndarray` then it\n must be of shape `(len(num_compartments))`.\n \"\"\"\n if key in self.nodes.columns:\n not_nan = ~self.nodes[key].isna().to_numpy()\n self.base.nodes.loc[self._nodes_in_view[not_nan], key] = val\n elif key in self.edges.columns:\n not_nan = ~self.edges[key].isna().to_numpy()\n self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n else:\n raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n\n def data_set(\n self,\n key: str,\n val: Union[float, jnp.ndarray],\n param_state: Optional[List[Dict]],\n ):\n \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n Args:\n key: The name of the parameter to set.\n val: The value to set the parameter to. If it is `jnp.ndarray` then it\n must be of shape `(len(num_compartments))`.\n param_state: State of the setted parameters, internally used such that this\n function does not modify global state.\n \"\"\"\n # Note: `data_set` does not support arrays for `val`.\n is_node_param = key in self.nodes.columns\n data = self.nodes if is_node_param else self.edges\n viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view\n if key in data.columns:\n not_nan = ~data[key].isna()\n added_param_state = [\n {\n \"indices\": np.atleast_2d(viewed_inds[not_nan]),\n \"key\": key,\n \"val\": jnp.atleast_1d(jnp.asarray(val)),\n }\n ]\n if param_state is not None:\n param_state += added_param_state\n else:\n param_state = added_param_state\n else:\n raise KeyError(\"Key not recognized.\")\n return param_state\n\n def set_ncomp(\n self,\n ncomp: int,\n min_radius: Optional[float] = None,\n ):\n \"\"\"Set the number of compartments with which the branch is discretized.\n\n Args:\n ncomp: The number of compartments that the branch should be discretized\n into.\n min_radius: Only used if the morphology was read from an SWC file. If passed\n the radius is capped to be at least this value.\n\n Raises:\n - When there are stimuli in any compartment in the module.\n - When there are recordings in any compartment in the module.\n - When the channels of the compartments are not the same within the branch\n that is modified.\n - When the lengths of the compartments are not the same within the branch\n that is modified.\n - Unless the morphology was read from an SWC file, when the radiuses of the\n compartments are not the same within the branch that is modified.\n \"\"\"\n assert len(self.base.externals) == 0, \"No stimuli allowed!\"\n assert len(self.base.recordings) == 0, \"No recordings allowed!\"\n assert len(self.base.trainable_params) == 0, \"No trainables allowed!\"\n\n assert self.base._module_type != \"network\", \"This is not allowed for networks.\"\n assert not (\n self.base._module_type == \"cell\"\n and len(self._branches_in_view) == len(self.base._branches_in_view)\n ), \"This is not allowed for cells.\"\n\n # Update all attributes that are affected by compartment structure.\n view = self.nodes.copy()\n all_nodes = self.base.nodes\n start_idx = self.nodes[\"global_comp_index\"].to_numpy()[0]\n ncomp_per_branch = self.base.ncomp_per_branch\n channel_names = [c._name for c in self.base.channels]\n channel_param_names = list(\n chain(*[c.channel_params for c in self.base.channels])\n )\n channel_state_names = list(\n chain(*[c.channel_states for c in self.base.channels])\n )\n radius_generating_fns = self.base._radius_generating_fns\n\n within_branch_radiuses = view[\"radius\"].to_numpy()\n compartment_lengths = view[\"length\"].to_numpy()\n num_previous_ncomp = len(within_branch_radiuses)\n branch_indices = pd.unique(view[\"global_branch_index\"])\n\n error_msg = lambda name: (\n f\"You previously modified the {name} of individual compartments, but \"\n f\"now you are modifying the number of compartments in this branch. \"\n f\"This is not allowed. First build the morphology with `set_ncomp()` and \"\n f\"then modify the radiuses and lengths of compartments.\"\n )\n\n if (\n ~np.all(within_branch_radiuses == within_branch_radiuses[0])\n and radius_generating_fns is None\n ):\n raise ValueError(error_msg(\"radius\"))\n\n for property_name in [\"length\", \"capacitance\", \"axial_resistivity\"]:\n compartment_properties = view[property_name].to_numpy()\n if ~np.all(compartment_properties == compartment_properties[0]):\n raise ValueError(error_msg(property_name))\n\n if not (self.nodes[channel_names].var() == 0.0).all():\n raise ValueError(\n \"Some channel exists only in some compartments of the branch which you\"\n \"are trying to modify. This is not allowed. First specify the number\"\n \"of compartments with `.set_ncomp()` and then insert the channels\"\n \"accordingly.\"\n )\n\n if not (\n self.nodes[channel_param_names + channel_state_names].var() == 0.0\n ).all():\n raise ValueError(\n \"Some channel has different parameters or states between the \"\n \"different compartments of the branch which you are trying to modify. \"\n \"This is not allowed. First specify the number of compartments with \"\n \"`.set_ncomp()` and then insert the channels accordingly.\"\n )\n\n # Add new rows as the average of all rows. Special case for the length is below.\n average_row = self.nodes.mean(skipna=False)\n average_row = average_row.to_frame().T\n view = pd.concat([*[average_row] * ncomp], axis=\"rows\")\n\n # Set the correct datatype after having performed an average which cast\n # everything to float.\n integer_cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n view[integer_cols] = view[integer_cols].astype(int)\n\n # Whether or not a channel exists in a compartment is a boolean.\n boolean_cols = channel_names\n view[boolean_cols] = view[boolean_cols].astype(bool)\n\n # Special treatment for the lengths and radiuses. These are not being set as\n # the average because we:\n # 1) Want to maintain the total length of a branch.\n # 2) Want to use the SWC inferred radius.\n #\n # Compute new compartment lengths.\n comp_lengths = np.sum(compartment_lengths) / ncomp\n view[\"length\"] = comp_lengths\n\n # Compute new compartment radiuses.\n if radius_generating_fns is not None:\n view[\"radius\"] = build_radiuses_from_xyzr(\n radius_fns=radius_generating_fns,\n branch_indices=branch_indices,\n min_radius=min_radius,\n ncomp=ncomp,\n )\n else:\n view[\"radius\"] = within_branch_radiuses[0] * np.ones(ncomp)\n\n # Update `.nodes`.\n # 1) Delete N rows starting from start_idx\n number_deleted = num_previous_ncomp\n all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))\n\n # 2) Insert M new rows at the same location\n df1 = all_nodes.iloc[:start_idx] # Rows before the insertion point\n df2 = all_nodes.iloc[start_idx:] # Rows after the insertion point\n\n # 3) Combine the parts: before, new rows, and after\n all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)\n\n # Override `comp_index` to just be a consecutive list.\n all_nodes[\"global_comp_index\"] = np.arange(len(all_nodes))\n\n # Update compartment structure arguments.\n ncomp_per_branch[branch_indices] = ncomp\n ncomp = int(np.max(ncomp_per_branch))\n cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n internal_node_inds = np.arange(cumsum_ncomp[-1])\n\n self.base.nodes = all_nodes\n self.base.ncomp_per_branch = ncomp_per_branch\n self.base.ncomp = ncomp\n self.base.cumsum_ncomp = cumsum_ncomp\n self.base._internal_node_inds = internal_node_inds\n\n # Update the morphology indexing (e.g., `.comp_edges`).\n self.base._initialize()\n self.base._init_view()\n self.base._update_local_indices()\n\n def make_trainable(\n self,\n key: str,\n init_val: Optional[Union[float, list]] = None,\n verbose: bool = True,\n ):\n \"\"\"Make a parameter trainable.\n\n If a parameter is made trainable, it will be returned by `get_parameters()`\n and should then be passed to `jx.integrate(..., params=params)`.\n\n Args:\n key: Name of the parameter to make trainable.\n init_val: Initial value of the parameter. If `float`, the same value is\n used for every created parameter. If `list`, the length of the list has\n to match the number of created parameters. If `None`, the current\n parameter value is used and if parameter sharing is performed that the\n current parameter value is averaged over all shared parameters.\n verbose: Whether to print the number of parameters that are added and the\n total number of parameters.\n \"\"\"\n assert (\n self.allow_make_trainable\n ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n ncomps_per_branch = (\n self.base.nodes[\"global_branch_index\"].value_counts().to_numpy()\n )\n assert np.all(\n ncomps_per_branch == ncomps_per_branch[0]\n ), \"Parameter sharing is not allowed for modules containing branches with different numbers of compartments.\"\n\n data = self.nodes if key in self.nodes.columns else None\n data = self.edges if key in self.edges.columns else data\n\n assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n not_nan = ~data[key].isna()\n data = data.loc[not_nan]\n assert (\n len(data) > 0\n ), \"No settable parameters found in the selected compartments.\"\n\n grouped_view = data.groupby(\"controlled_by_param\")\n # Because of this `x.index.values` we cannot support `make_trainable()` on\n # the module level for synapse parameters (but only for `SynapseView`).\n inds_of_comps = list(\n grouped_view.apply(lambda x: x.index.values, include_groups=False)\n )\n indices_per_param = jnp.stack(inds_of_comps)\n # Sorted inds are only used to infer the correct starting values.\n param_vals = jnp.asarray(\n [data.loc[inds, key].to_numpy() for inds in inds_of_comps]\n )\n\n # Set the value which the trainable parameter should take.\n num_created_parameters = len(indices_per_param)\n if init_val is not None:\n if isinstance(init_val, float):\n new_params = jnp.asarray([init_val] * num_created_parameters)\n elif isinstance(init_val, list):\n assert (\n len(init_val) == num_created_parameters\n ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n new_params = jnp.asarray(init_val)\n else:\n raise ValueError(\n f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n )\n else:\n new_params = jnp.mean(param_vals, axis=1)\n self.base.trainable_params.append({key: new_params})\n self.base.indices_set_by_trainables.append(indices_per_param)\n self.base.num_trainable_params += num_created_parameters\n if verbose:\n print(\n f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}\"\n )\n\n def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):\n \"\"\"Write the trainables into `.nodes` and `.edges`.\n\n This allows to, e.g., visualize trained networks with `.vis()`.\n\n Args:\n trainable_params: The trainable parameters returned by `get_parameters()`.\n \"\"\"\n # We do not support views. Why? `jaxedges` does not have any NaN\n # elements, whereas edges does. Because of this, we already need special\n # treatment to make this function work, and it would be an even bigger hassle\n # if we wanted to support this.\n assert self.__class__.__name__ in [\n \"Compartment\",\n \"Branch\",\n \"Cell\",\n \"Network\",\n ], \"Only supports modules.\"\n\n # We could also implement this without casting the module to jax.\n # However, I think it allows us to reuse as much code as possible and it avoids\n # any kind of issues with indexing or parameter sharing (as this is fully\n # taken care of by `get_all_parameters()`).\n self.base.to_jax()\n pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)\n all_params = self.base.get_all_parameters(pstate, voltage_solver=\"jaxley.stone\")\n\n # The value for `delta_t` does not matter here because it is only used to\n # compute the initial current. However, the initial current cannot be made\n # trainable and so its value never gets used below.\n all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)\n\n # Loop only over the keys in `pstate` to avoid unnecessary computation.\n for parameter in pstate:\n key = parameter[\"key\"]\n if key in self.base.nodes.columns:\n vals_to_set = all_params if key in all_params.keys() else all_states\n self.base.nodes[key] = vals_to_set[key]\n\n # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n # we allow parameter sharing.\n edges = self.base.edges.to_dict(orient=\"list\")\n for i, synapse in enumerate(self.base.synapses):\n condition = np.asarray(edges[\"type_ind\"]) == i\n for key in list(synapse.synapse_params.keys()):\n self.base.edges.loc[condition, key] = all_params[key]\n for key in list(synapse.synapse_states.keys()):\n self.base.edges.loc[condition, key] = all_states[key]\n\n def distance(self, endpoint: \"View\") -> float:\n \"\"\"Return the direct distance between two compartments.\n This does not compute the pathwise distance (which is currently not\n implemented).\n Args:\n endpoint: The compartment to which to compute the distance to.\n \"\"\"\n assert len(self.xyzr) == 1 and len(endpoint.xyzr) == 1\n start_xyz = np.mean(self.xyzr[0][:, :3], axis=0)\n end_xyz = np.mean(endpoint.xyzr[0][:, :3], axis=0)\n return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))\n\n def delete_trainables(self):\n \"\"\"Removes all trainable parameters from the module.\"\"\"\n\n if isinstance(self, View):\n trainables_and_inds = self._filter_trainables(is_viewed=False)\n self.base.indices_set_by_trainables = trainables_and_inds[0]\n self.base.trainable_params = trainables_and_inds[1]\n self.base.num_trainable_params -= self.num_trainable_params\n else:\n self.base.indices_set_by_trainables = []\n self.base.trainable_params = []\n self.base.num_trainable_params = 0\n self._update_view()\n\n def add_to_group(self, group_name: str):\n \"\"\"Add a view of the module to a group.\n\n Groups can then be indexed. For example:\n\n .. code-block:: python\n\n net.cell(0).add_to_group(\"excitatory\")\n net.excitatory.set(\"radius\", 0.1)\n\n Args:\n group_name: The name of the group.\n \"\"\"\n if group_name not in self.base.groups:\n self.base.groups[group_name] = self._nodes_in_view\n else:\n self.base.groups[group_name] = np.unique(\n np.concatenate([self.base.groups[group_name], self._nodes_in_view])\n )\n\n def _get_state_names(self) -> Tuple[List, List]:\n \"\"\"Collect all recordable / clampable states in the membrane and synapses.\n\n Returns states seperated by comps and edges.\"\"\"\n channel_states = [name for c in self.channels for name in c.channel_states]\n synapse_states = [name for s in self.synapses for name in s.synapse_states]\n membrane_states = [\"v\", \"i\"] + self.membrane_current_names\n return channel_states + membrane_states, synapse_states\n\n def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n \"\"\"Get all trainable parameters.\n\n The returned parameters should be passed to `jx.integrate(..., params=params).\n\n Returns:\n A list of all trainable parameters in the form of\n [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n \"\"\"\n return self.trainable_params\n\n @only_allow_module\n def get_all_parameters(\n self, pstate: List[Dict], voltage_solver: str\n ) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n Runs `_compute_axial_conductances()` and return every parameter that is needed\n to solve the ODE. This includes conductances, radiuses, lengths,\n axial_resistivities, but also coupling conductances.\n\n This is done by first obtaining the current value of every parameter (not only\n the trainable ones) and then replacing the trainable ones with the value\n in `trainable_params()`. This function is run within `jx.integrate()`.\n\n pstate can be obtained by calling `params_to_pstate()`.\n\n .. code-block:: python\n\n params = module.get_parameters() # i.e. [0, 1, 2]\n pstate = params_to_pstate(params, module.indices_set_by_trainables)\n module.to_jax() # needed for call to module.jaxnodes\n\n Args:\n pstate: The state of the trainable parameters. pstate takes the form\n [{\n \"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]),\n \"val\": jnp.array([0.1, 0.2, 0.3])\n }, ...].\n voltage_solver: The voltage solver that is used. Since `jax.sparse` and\n `jaxley.xyz` require different formats of the axial conductances, this\n function will default to different building methods.\n\n Returns:\n A dictionary of all module parameters.\n \"\"\"\n params = {}\n for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n params[key] = self.base.jaxnodes[key]\n\n for channel in self.base.channels:\n for channel_params in channel.channel_params:\n params[channel_params] = self.base.jaxnodes[channel_params]\n\n for synapse_params in self.base.synapse_param_names:\n params[synapse_params] = self.base.jaxedges[synapse_params]\n\n # Override with those parameters set by `.make_trainable()`.\n for parameter in pstate:\n key = parameter[\"key\"]\n inds = parameter[\"indices\"]\n set_param = parameter[\"val\"]\n\n # This is needed since SynapseViews worked differently before.\n # This mimics the old behaviour and tranformes the new indices\n # to the old indices.\n # TODO FROM #447: Longterm this should be gotten rid of.\n # Instead edges should work similar to nodes (would also allow for\n # param sharing).\n synapse_inds = self.base.edges.groupby(\"type\").rank()[\"global_edge_index\"]\n synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()\n if key in self.base.synapse_param_names:\n inds = synapse_inds[inds]\n\n if key in params: # Only parameters, not initial states.\n # `inds` is of shape `(num_params, num_comps_per_param)`.\n # `set_param` is of shape `(num_params,)`\n # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n # `.set()` to work. This is done with `[:, None]`.\n params[key] = params[key].at[inds].set(set_param[:, None])\n\n # Compute conductance params and add them to the params dictionary.\n params[\"axial_conductances\"] = self.base._compute_axial_conductances(\n params=params\n )\n return params\n\n @only_allow_module\n def _get_states_from_nodes_and_edges(self) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Return states as they are set in the `.nodes` and `.edges` tables.\"\"\"\n self.base.to_jax() # Create `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`.\n states = {\"v\": self.base.jaxnodes[\"v\"]}\n # Join node and edge states into a single state dictionary.\n for channel in self.base.channels:\n for channel_states in channel.channel_states:\n states[channel_states] = self.base.jaxnodes[channel_states]\n for synapse_states in self.base.synapse_state_names:\n states[synapse_states] = self.base.jaxedges[synapse_states]\n return states\n\n @only_allow_module\n def get_all_states(\n self, pstate: List[Dict], all_params, delta_t: float\n ) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n Args:\n pstate: The state of the trainable parameters.\n all_params: All parameters of the module.\n delta_t: The time step.\n\n Returns:\n A dictionary of all states of the module.\n \"\"\"\n states = self.base._get_states_from_nodes_and_edges()\n\n # Override with the initial states set by `.make_trainable()`.\n for parameter in pstate:\n key = parameter[\"key\"]\n inds = parameter[\"indices\"]\n set_param = parameter[\"val\"]\n if key in states: # Only initial states, not parameters.\n # `inds` is of shape `(num_params, num_comps_per_param)`.\n # `set_param` is of shape `(num_params,)`\n # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n # `.set()` to work. This is done with `[:, None]`.\n states[key] = states[key].at[inds].set(set_param[:, None])\n\n # Add to the states the initial current through every channel.\n states, _ = self.base._channel_currents(\n states, delta_t, self.channels, self.nodes, all_params\n )\n\n # Add to the states the initial current through every synapse.\n states, _ = self.base._synapse_currents(\n states, self.synapses, all_params, delta_t, self.edges\n )\n return states\n\n @property\n def initialized(self) -> bool:\n \"\"\"Whether the `Module` is ready to be solved or not.\"\"\"\n return self.initialized_morph\n\n def _initialize(self):\n \"\"\"Initialize the module.\"\"\"\n self._init_morph()\n return self\n\n @only_allow_module\n def init_states(self, delta_t: float = 0.025):\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Initialize all mechanisms in their steady state.\n\n This considers the voltages and parameters of each compartment.\n\n Args:\n delta_t: Passed on to `channel.init_state()`.\n \"\"\"\n # Update states of the channels.\n channel_nodes = self.base.nodes\n states = self.base._get_states_from_nodes_and_edges()\n\n # We do not use any `pstate` for initializing. In principle, we could change\n # that by allowing an input `params` and `pstate` to this function.\n # `voltage_solver` could also be `jax.sparse` here, because both of them\n # build the channel parameters in the same way.\n params = self.base.get_all_parameters([], voltage_solver=\"jaxley.thomas\")\n\n for channel in self.base.channels:\n name = channel._name\n channel_indices = channel_nodes.loc[channel_nodes[name]][\n \"global_comp_index\"\n ].to_numpy()\n voltages = channel_nodes.loc[channel_indices, \"v\"].to_numpy()\n\n channel_param_names = list(channel.channel_params.keys())\n channel_state_names = list(channel.channel_states.keys())\n channel_states = query_channel_states_and_params(\n states, channel_state_names, channel_indices\n )\n channel_params = query_channel_states_and_params(\n params, channel_param_names, channel_indices\n )\n\n init_state = channel.init_state(\n channel_states, voltages, channel_params, delta_t\n )\n\n # `init_state` might not return all channel states. Only the ones that are\n # returned are updated here.\n for key, val in init_state.items():\n # Note that we are overriding `self.nodes` here, but `self.nodes` is\n # not used above to actually compute the current states (so there are\n # no issues with overriding states).\n self.nodes.loc[channel_indices, key] = val\n\n def _init_morph_for_debugging(self):\n \"\"\"Instandiates row and column inds which can be used to solve the voltage eqs.\n\n This is important only for expert users who try to modify the solver for the\n voltage equations. By default, this function is never run.\n\n This is useful for debugging the solver because one can use\n `scipy.linalg.sparse.spsolve` after every step of the solve.\n\n Here is the code snippet that can be used for debugging then (to be inserted in\n `solver_voltage`):\n ```python\n from scipy.sparse import csc_matrix\n from scipy.sparse.linalg import spsolve\n from jaxley.utils.debug_solver import build_voltage_matrix_elements\n\n elements, solve, num_entries, start_ind_for_branchpoints = (\n build_voltage_matrix_elements(\n uppers,\n lowers,\n diags,\n solves,\n branchpoint_conds_children[debug_states[\"child_inds\"]],\n branchpoint_conds_parents[debug_states[\"par_inds\"]],\n branchpoint_weights_children[debug_states[\"child_inds\"]],\n branchpoint_weights_parents[debug_states[\"par_inds\"]],\n branchpoint_diags,\n branchpoint_solves,\n debug_states[\"ncomp\"],\n nbranches,\n )\n )\n sparse_matrix = csc_matrix(\n (elements, (debug_states[\"row_inds\"], debug_states[\"col_inds\"])),\n shape=(num_entries, num_entries),\n )\n solution = spsolve(sparse_matrix, solve)\n solution = solution[:start_ind_for_branchpoints] # Delete branchpoint voltages.\n solves = jnp.reshape(solution, (debug_states[\"ncomp\"], nbranches))\n return solves\n ```\n \"\"\"\n # For scipy and jax.scipy.\n row_and_col_inds = compute_morphology_indices(\n len(self.base._par_inds),\n self.base._child_belongs_to_branchpoint,\n self.base._par_inds,\n self.base._child_inds,\n self.base.ncomp,\n self.base.total_nbranches,\n )\n\n num_elements = len(row_and_col_inds[\"row_inds\"])\n data_inds, indices, indptr = convert_to_csc(\n num_elements=num_elements,\n row_ind=row_and_col_inds[\"row_inds\"],\n col_ind=row_and_col_inds[\"col_inds\"],\n )\n self.base.debug_states[\"row_inds\"] = row_and_col_inds[\"row_inds\"]\n self.base.debug_states[\"col_inds\"] = row_and_col_inds[\"col_inds\"]\n self.base.debug_states[\"data_inds\"] = data_inds\n self.base.debug_states[\"indices\"] = indices\n self.base.debug_states[\"indptr\"] = indptr\n\n self.base.debug_states[\"ncomp\"] = self.base.ncomp\n self.base.debug_states[\"child_inds\"] = self.base._child_inds\n self.base.debug_states[\"par_inds\"] = self.base._par_inds\n\n def record(self, state: str = \"v\", verbose=True):\n comp_states, edge_states = self._get_state_names()\n if state not in comp_states + edge_states:\n raise KeyError(f\"{state} is not a recognized state in this module.\")\n in_view = self._nodes_in_view if state in comp_states else self._edges_in_view\n\n new_recs = pd.DataFrame(in_view, columns=[\"rec_index\"])\n new_recs[\"state\"] = state\n self.base.recordings = pd.concat([self.base.recordings, new_recs])\n has_duplicates = self.base.recordings.duplicated()\n self.base.recordings = self.base.recordings.loc[~has_duplicates]\n if verbose:\n print(\n f\"Added {len(in_view)-sum(has_duplicates)} recordings. See `.recordings` for details.\"\n )\n\n def _update_view(self):\n \"\"\"Update the attrs of the view after changes in the base module.\"\"\"\n if isinstance(self, View):\n scope = self._scope\n current_view = self._current_view\n # copy dict of new View. For some reason doing self = View(self)\n # did not work.\n self.__dict__ = View(\n self.base, self._nodes_in_view, self._edges_in_view\n ).__dict__\n\n # retain the scope and current_view of the previous view\n self._scope = scope\n self._current_view = current_view\n\n def delete_recordings(self):\n \"\"\"Removes all recordings from the module.\"\"\"\n if isinstance(self, View):\n base_recs = self.base.recordings\n self.base.recordings = base_recs[\n ~base_recs.isin(self.recordings).all(axis=1)\n ]\n self._update_view()\n else:\n self.base.recordings = pd.DataFrame().from_dict({})\n\n def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n \"\"\"Insert a stimulus into the compartment.\n\n current must be a 1d array or have batch dimension of size `(num_compartments, )`\n or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n it should only be used for static stimuli (i.e., stimuli that do not depend\n on the data and that should not be learned). For stimuli that depend on data\n (or that should be learned), please use `data_stimulate()`.\n\n Args:\n current: Current in `nA`.\n \"\"\"\n self._external_input(\"i\", current, verbose=verbose)\n\n def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n \"\"\"Clamp a state to a given value across specified compartments.\n\n Args:\n state_name: The name of the state to clamp.\n state_array (jnp.nd: Array of values to clamp the state to.\n verbose : If True, prints details about the clamping.\n\n This function sets external states for the compartments.\n \"\"\"\n self._external_input(state_name, state_array, verbose=verbose)\n\n def _external_input(\n self,\n key: str,\n values: Optional[jnp.ndarray],\n verbose: bool = True,\n ):\n comp_states, edge_states = self._get_state_names()\n if key not in comp_states + edge_states:\n raise KeyError(f\"{key} is not a recognized state in this module.\")\n values = values if values.ndim == 2 else jnp.expand_dims(values, axis=0)\n batch_size = values.shape[0]\n num_inserted = (\n len(self._nodes_in_view) if key in comp_states else len(self._edges_in_view)\n )\n is_multiple = num_inserted == batch_size\n values = values if is_multiple else jnp.repeat(values, num_inserted, axis=0)\n assert batch_size in [\n 1,\n num_inserted,\n ], \"Number of comps and stimuli do not match.\"\n\n if key in self.base.externals.keys():\n self.base.externals[key] = jnp.concatenate(\n [self.base.externals[key], values]\n )\n self.base.external_inds[key] = jnp.concatenate(\n [self.base.external_inds[key], self._nodes_in_view]\n )\n else:\n if key in comp_states:\n self.base.externals[key] = values\n self.base.external_inds[key] = self._nodes_in_view\n else:\n self.base.externals[key] = values\n self.base.external_inds[key] = self._edges_in_view\n if verbose:\n print(\n f\"Added {num_inserted} external_states. See `.externals` for details.\"\n )\n\n def data_stimulate(\n self,\n current: jnp.ndarray,\n data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n verbose: bool = False,\n ) -> Tuple[jnp.ndarray, pd.DataFrame]:\n \"\"\"Insert a stimulus into the module within jit (or grad).\n\n Args:\n current: Current in `nA`.\n verbose: Whether or not to print the number of inserted stimuli. `False`\n by default because this method is meant to be jitted.\n \"\"\"\n return self._data_external_input(\n \"i\", current, data_stimuli, self.nodes, verbose=verbose\n )\n\n def data_clamp(\n self,\n state_name: str,\n state_array: jnp.ndarray,\n data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n verbose: bool = False,\n ):\n \"\"\"Insert a clamp into the module within jit (or grad).\n\n Args:\n state_name: Name of the state variable to set.\n state_array: Time series of the state variable in the default Jaxley unit.\n State array should be of shape (num_clamps, simulation_time) or\n (simulation_time, ) for a single clamp.\n verbose: Whether or not to print the number of inserted clamps. `False`\n by default because this method is meant to be jitted.\n \"\"\"\n comp_states, edge_states = self._get_state_names()\n if state_name not in comp_states + edge_states:\n raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n data = self.nodes if state_name in comp_states else self.edges\n return self._data_external_input(\n state_name, state_array, data_clamps, data, verbose=verbose\n )\n\n def _data_external_input(\n self,\n state_name: str,\n state_array: jnp.ndarray,\n data_external_input: Optional[Tuple[jnp.ndarray, pd.DataFrame]],\n view: pd.DataFrame,\n verbose: bool = False,\n ):\n comp_states, edge_states = self._get_state_names()\n state_array = (\n state_array\n if state_array.ndim == 2\n else jnp.expand_dims(state_array, axis=0)\n )\n batch_size = state_array.shape[0]\n num_inserted = (\n len(self._nodes_in_view)\n if state_name in comp_states\n else len(self._edges_in_view)\n )\n is_multiple = num_inserted == batch_size\n state_array = (\n state_array\n if is_multiple\n else jnp.repeat(state_array, num_inserted, axis=0)\n )\n assert batch_size in [\n 1,\n num_inserted,\n ], \"Number of comps and clamps do not match.\"\n\n if data_external_input is not None:\n external_input = data_external_input[1]\n external_input = jnp.concatenate([external_input, state_array])\n inds = data_external_input[2]\n else:\n external_input = state_array\n inds = pd.DataFrame().from_dict({})\n\n inds = pd.concat([inds, view])\n\n if verbose:\n if state_name == \"i\":\n print(f\"Added {len(view)} stimuli.\")\n else:\n print(f\"Added {len(view)} clamps.\")\n\n return (state_name, external_input, inds)\n\n def delete_stimuli(self):\n \"\"\"Removes all stimuli from the module.\"\"\"\n self.delete_clamps(\"i\")\n\n def delete_clamps(self, state_name: Optional[str] = None):\n \"\"\"Removes all clamps of the given state from the module.\"\"\"\n all_externals = list(self.externals.keys())\n if \"i\" in all_externals:\n all_externals.remove(\"i\")\n state_names = all_externals if state_name is None else [state_name]\n for state_name in state_names:\n if state_name in self.externals:\n keep_inds = ~np.isin(\n self.base.external_inds[state_name], self._nodes_in_view\n )\n base_exts = self.base.externals\n base_exts_inds = self.base.external_inds\n if np.all(~keep_inds):\n base_exts.pop(state_name, None)\n base_exts_inds.pop(state_name, None)\n else:\n base_exts[state_name] = base_exts[state_name][keep_inds]\n base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]\n self._update_view()\n else:\n pass # does not have to be deleted if not in externals\n\n def insert(self, channel: Channel):\n \"\"\"Insert a channel into the module.\n\n Args:\n channel: The channel to insert.\"\"\"\n name = channel._name\n\n # Channel does not yet exist in the `jx.Module` at all.\n if name not in [c._name for c in self.base.channels]:\n self.base.channels.append(channel)\n self.base.nodes[name] = (\n False # Previous columns do not have the new channel.\n )\n\n if channel.current_name not in self.base.membrane_current_names:\n self.base.membrane_current_names.append(channel.current_name)\n\n # Add a binary column that indicates if a channel is present.\n self.base.nodes.loc[self._nodes_in_view, name] = True\n\n # Loop over all new parameters, e.g. gNa, eNa.\n for key in channel.channel_params:\n self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]\n\n # Loop over all new parameters, e.g. gNa, eNa.\n for key in channel.channel_states:\n self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]\n\n def delete_channel(self, channel: Channel):\n \"\"\"Remove a channel from the module.\n\n Args:\n channel: The channel to remove.\"\"\"\n name = channel._name\n channel_names = [c._name for c in self.channels]\n all_channel_names = [c._name for c in self.base.channels]\n if name in channel_names:\n channel_cols = list(channel.channel_params.keys())\n channel_cols += list(channel.channel_states.keys())\n self.base.nodes.loc[self._nodes_in_view, channel_cols] = float(\"nan\")\n self.base.nodes.loc[self._nodes_in_view, name] = False\n\n # only delete cols if no other comps in the module have the same channel\n if np.all(~self.base.nodes[name]):\n self.base.channels.pop(all_channel_names.index(name))\n self.base.membrane_current_names.remove(channel.current_name)\n self.base.nodes.drop(columns=channel_cols + [name], inplace=True)\n else:\n raise ValueError(f\"Channel {name} not found in the module.\")\n\n @only_allow_module\n def step(\n self,\n u: Dict[str, jnp.ndarray],\n delta_t: float,\n external_inds: Dict[str, jnp.ndarray],\n externals: Dict[str, jnp.ndarray],\n params: Dict[str, jnp.ndarray],\n solver: str = \"bwd_euler\",\n voltage_solver: str = \"jaxley.stone\",\n ) -> Dict[str, jnp.ndarray]:\n \"\"\"One step of solving the Ordinary Differential Equation.\n\n This function is called inside of `integrate` and increments the state of the\n module by one time step. Calls `_step_channels` and `_step_synapse` to update\n the states of the channels and synapses using fwd_euler.\n\n Args:\n u: The state of the module. voltages = u[\"v\"]\n delta_t: The time step.\n external_inds: The indices of the external inputs.\n externals: The external inputs.\n params: The parameters of the module.\n solver: The solver to use for the voltages. Either of [\"bwd_euler\",\n \"fwd_euler\", \"crank_nicolson\"].\n voltage_solver: The tridiagonal solver used to diagonalize the\n coefficient matrix of the ODE system. Either of [\"jaxley.thomas\",\n \"jaxley.stone\"].\n\n Returns:\n The updated state of the module.\n \"\"\"\n\n # Extract the voltages\n voltages = u[\"v\"]\n\n # Extract the external inputs\n if \"i\" in externals.keys():\n i_current = externals[\"i\"]\n i_inds = external_inds[\"i\"]\n i_ext = self._get_external_input(\n voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n )\n else:\n i_ext = 0.0\n\n # Step of the channels.\n u, (v_terms, const_terms) = self._step_channels(\n u, delta_t, self.channels, self.nodes, params\n )\n\n # Step of the synapse.\n u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n u,\n self.synapses,\n params,\n delta_t,\n self.edges,\n )\n\n # Clamp for channels and synapses.\n for key in externals.keys():\n if key not in [\"i\", \"v\"]:\n u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n # Voltage steps.\n cm = params[\"capacitance\"] # Abbreviation.\n\n # Arguments used by all solvers.\n solver_kwargs = {\n \"voltages\": voltages,\n \"voltage_terms\": (v_terms + syn_v_terms) / cm,\n \"constant_terms\": (const_terms + i_ext + syn_const_terms) / cm,\n \"axial_conductances\": params[\"axial_conductances\"],\n \"internal_node_inds\": self._internal_node_inds,\n }\n\n # Add solver specific arguments.\n if voltage_solver == \"jax.sparse\":\n solver_kwargs.update(\n {\n \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n \"data_inds\": self._data_inds,\n \"indices\": self._indices_jax_spsolve,\n \"indptr\": self._indptr_jax_spsolve,\n \"n_nodes\": self._n_nodes,\n }\n )\n # Only for `bwd_euler` and `cranck-nicolson`.\n step_voltage_implicit = step_voltage_implicit_with_jax_spsolve\n else:\n # Our custom sparse solver requires a different format of all conductance\n # values to perform triangulation and backsubstution optimally.\n #\n # Currently, the forward Euler solver also uses this format. However,\n # this is only for historical reasons and we are planning to change this in\n # the future.\n solver_kwargs.update(\n {\n \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n \"sources\": np.asarray(self._comp_edges[\"source\"].to_list()),\n \"types\": np.asarray(self._comp_edges[\"type\"].to_list()),\n \"ncomp_per_branch\": self.ncomp_per_branch,\n \"par_inds\": self._par_inds,\n \"child_inds\": self._child_inds,\n \"nbranches\": self.total_nbranches,\n \"solver\": voltage_solver,\n \"idx\": self._solve_indexer,\n \"debug_states\": self.debug_states,\n }\n )\n # Only for `bwd_euler` and `cranck-nicolson`.\n step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve\n\n if solver == \"bwd_euler\":\n u[\"v\"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)\n elif solver == \"crank_nicolson\":\n # Crank-Nicolson advances by half a step of backward and half a step of\n # forward Euler.\n half_step_delta_t = delta_t / 2\n half_step_voltages = step_voltage_implicit(\n **solver_kwargs, delta_t=half_step_delta_t\n )\n # The forward Euler step in Crank-Nicolson can be performed easily as\n # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.\n u[\"v\"] = 2 * half_step_voltages - voltages\n elif solver == \"fwd_euler\":\n u[\"v\"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)\n else:\n raise ValueError(\n f\"You specified `solver={solver}`. The only allowed solvers are \"\n \"['bwd_euler', 'fwd_euler', 'crank_nicolson'].\"\n )\n\n # Clamp for voltages.\n if \"v\" in externals.keys():\n u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n return u\n\n def _step_channels(\n self,\n states: Dict[str, jnp.ndarray],\n delta_t: float,\n channels: List[Channel],\n channel_nodes: pd.DataFrame,\n params: Dict[str, jnp.ndarray],\n ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"One step of integration of the channels and of computing their current.\"\"\"\n states = self._step_channels_state(\n states, delta_t, channels, channel_nodes, params\n )\n states, current_terms = self._channel_currents(\n states, delta_t, channels, channel_nodes, params\n )\n return states, current_terms\n\n def _step_channels_state(\n self,\n states,\n delta_t,\n channels: List[Channel],\n channel_nodes: pd.DataFrame,\n params: Dict[str, jnp.ndarray],\n ) -> Dict[str, jnp.ndarray]:\n \"\"\"One integration step of the channels.\"\"\"\n voltages = states[\"v\"]\n\n # Update states of the channels.\n indices = channel_nodes[\"global_comp_index\"].to_numpy()\n for channel in channels:\n channel_param_names = list(channel.channel_params)\n channel_param_names += [\n \"radius\",\n \"length\",\n \"axial_resistivity\",\n \"capacitance\",\n ]\n channel_state_names = list(channel.channel_states)\n channel_state_names += self.membrane_current_names\n channel_indices = indices[channel_nodes[channel._name].astype(bool)]\n\n channel_params = query_channel_states_and_params(\n params, channel_param_names, channel_indices\n )\n channel_states = query_channel_states_and_params(\n states, channel_state_names, channel_indices\n )\n\n states_updated = channel.update_states(\n channel_states, delta_t, voltages[channel_indices], channel_params\n )\n # Rebuild state. This has to be done within the loop over channels to allow\n # multiple channels which modify the same state.\n for key, val in states_updated.items():\n states[key] = states[key].at[channel_indices].set(val)\n\n return states\n\n def _channel_currents(\n self,\n states: Dict[str, jnp.ndarray],\n delta_t: float,\n channels: List[Channel],\n channel_nodes: pd.DataFrame,\n params: Dict[str, jnp.ndarray],\n ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"Return the current through each channel.\n\n This is also updates `state` because the `state` also contains the current.\n \"\"\"\n voltages = states[\"v\"]\n\n # Compute current through channels.\n voltage_terms = jnp.zeros_like(voltages)\n constant_terms = jnp.zeros_like(voltages)\n # Run with two different voltages that are `diff` apart to infer the slope and\n # offset.\n diff = 1e-3\n\n current_states = {}\n for name in self.membrane_current_names:\n current_states[name] = jnp.zeros_like(voltages)\n\n for channel in channels:\n name = channel._name\n channel_param_names = list(channel.channel_params.keys())\n channel_state_names = list(channel.channel_states.keys())\n indices = channel_nodes.loc[channel_nodes[name]][\n \"global_comp_index\"\n ].to_numpy()\n\n channel_params = {}\n for p in channel_param_names:\n channel_params[p] = params[p][indices]\n channel_params[\"radius\"] = params[\"radius\"][indices]\n channel_params[\"length\"] = params[\"length\"][indices]\n channel_params[\"axial_resistivity\"] = params[\"axial_resistivity\"][indices]\n\n channel_states = {}\n for s in channel_state_names:\n channel_states[s] = states[s][indices]\n\n v_and_perturbed = jnp.stack([voltages[indices], voltages[indices] + diff])\n membrane_currents = vmap(channel.compute_current, in_axes=(None, 0, None))(\n channel_states, v_and_perturbed, channel_params\n )\n voltage_term = (membrane_currents[1] - membrane_currents[0]) / diff\n constant_term = membrane_currents[0] - voltage_term * voltages[indices]\n\n # * 1000 to convert from mA/cm^2 to uA/cm^2.\n voltage_terms = voltage_terms.at[indices].add(voltage_term * 1000.0)\n constant_terms = constant_terms.at[indices].add(-constant_term * 1000.0)\n\n # Save the current (for the unperturbed voltage) as a state that will\n # also be passed to the state update.\n current_states[channel.current_name] = (\n current_states[channel.current_name]\n .at[indices]\n .add(membrane_currents[0])\n )\n\n # Copy the currents into the `state` dictionary such that they can be\n # recorded and used by `Channel.update_states()`.\n for name in self.membrane_current_names:\n states[name] = current_states[name]\n\n return states, (voltage_terms, constant_terms)\n\n def _step_synapse(\n self,\n u: Dict[str, jnp.ndarray],\n syn_channels: List[Channel],\n params: Dict[str, jnp.ndarray],\n delta_t: float,\n edges: pd.DataFrame,\n ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"One step of integration of the channels.\n\n `Network` overrides this method (because it actually has synapses), whereas\n `Compartment`, `Branch`, and `Cell` do not override this.\n \"\"\"\n voltages = u[\"v\"]\n return u, (jnp.zeros_like(voltages), jnp.zeros_like(voltages))\n\n def _synapse_currents(\n self, states, syn_channels, params, delta_t, edges: pd.DataFrame\n ) -> Tuple[Dict[str, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:\n return states, (None, None)\n\n @staticmethod\n def _get_external_input(\n voltages: jnp.ndarray,\n i_inds: jnp.ndarray,\n i_stim: jnp.ndarray,\n radius: float,\n length_single_compartment: float,\n ) -> jnp.ndarray:\n \"\"\"\n Return external input to each compartment in uA / cm^2.\n\n Args:\n voltages: mV.\n i_stim: nA.\n radius: um.\n length_single_compartment: um.\n \"\"\"\n zero_vec = jnp.zeros_like(voltages)\n current = convert_point_process_to_distributed(\n i_stim, radius[i_inds], length_single_compartment[i_inds]\n )\n\n dnums = ScatterDimensionNumbers(\n update_window_dims=(),\n inserted_window_dims=(0,),\n scatter_dims_to_operand_dims=(0,),\n )\n stim_at_timestep = scatter_add(zero_vec, i_inds[:, None], current, dnums)\n return stim_at_timestep\n\n def vis(\n self,\n ax: Optional[Axes] = None,\n col: str = \"k\",\n dims: Tuple[int] = (0, 1),\n type: str = \"line\",\n morph_plot_kwargs: Dict = {},\n ) -> Axes:\n \"\"\"Visualize the module.\n\n Modules can be visualized on one of the cardinal planes (xy, xz, yz) or\n even in 3D.\n\n Several options are available:\n - `line`: All points from the traced morphology (`xyzr`), are connected\n with a line plot.\n - `scatter`: All traced points, are plotted as scatter points.\n - `comp`: Plots the compartmentalized morphology, including radius\n and shape. (shows the true compartment lengths per default, but this can\n be changed via the `morph_plot_kwargs`, for details see\n `jaxley.utils.plot_utils.plot_comps`).\n - `morph`: Reconstructs the 3D shape of the traced morphology. For details see\n `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies\n with many traced points this can be very slow.\n\n Args:\n ax: An axis into which to plot.\n col: The color for all branches.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two of them.\n type: The type of plot. One of [\"line\", \"scatter\", \"comp\", \"morph\"].\n morph_plot_kwargs: Keyword arguments passed to the plotting function.\n \"\"\"\n if \"comp\" in type.lower():\n return plot_comps(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)\n if \"morph\" in type.lower():\n return plot_morph(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)\n\n assert not np.any(\n [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]\n ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n ax = plot_graph(\n self.xyzr,\n dims=dims,\n col=col,\n ax=ax,\n type=type,\n morph_plot_kwargs=morph_plot_kwargs,\n )\n\n return ax\n\n def compute_xyz(self):\n \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n This function should not be called if the morphology was read from an `.swc`\n file. However, for morphologies that were constructed from scratch, this\n function **must** be called before `.vis()`. The computed `xyz` coordinates\n are only used for plotting.\n \"\"\"\n max_y_multiplier = 5.0\n min_y_multiplier = 0.5\n\n parents = self.comb_parents\n num_children = _compute_num_children(parents)\n index_of_child = _compute_index_of_child(parents)\n levels = compute_levels(parents)\n\n # Extract branch.\n inds_branch = self.nodes.groupby(\"global_branch_index\")[\n \"global_comp_index\"\n ].apply(list)\n branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n endpoints = []\n\n # Different levels will get a different \"angle\" at which the children emerge from\n # the parents. This angle is defined by the `y_offset_multiplier`. This value\n # defines the range between y-location of the first and of the last child of a\n # parent.\n y_offset_multiplier = np.linspace(\n max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n )\n\n for b in range(self.total_nbranches):\n # For networks with mixed SWC and from-scatch neurons, only update those\n # branches that do not have coordingates yet.\n if np.any(np.isnan(self.xyzr[b])):\n if parents[b] > -1:\n start_point = endpoints[parents[b]]\n num_children_of_parent = num_children[parents[b]]\n if num_children_of_parent > 1:\n y_offset = (\n ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n ) * y_offset_multiplier[levels[b]]\n else:\n y_offset = 0.0\n else:\n start_point = [0, 0, 0]\n y_offset = 0.0\n\n len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n end_point = [\n start_point[0] + branch_lens[b] / len_of_path * 1.0,\n start_point[1] + branch_lens[b] / len_of_path * y_offset,\n start_point[2],\n ]\n endpoints.append(end_point)\n\n self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n else:\n # Dummy to keey the index `endpoints[parent[b]]` above working.\n endpoints.append(np.zeros((2,)))\n\n def move(\n self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = False\n ):\n \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n This function is used only for visualization. It does not affect the simulation.\n\n Args:\n x: The amount to move in the x direction in um.\n y: The amount to move in the y direction in um.\n z: The amount to move in the z direction in um.\n update_nodes: Whether `.nodes` should be updated or not. Setting this to\n `False` largely speeds up moving, especially for big networks, but\n `.nodes` or `.show` will not show the new xyz coordinates.\n \"\"\"\n for i in self._branches_in_view:\n self.base.xyzr[i][:, :3] += np.array([x, y, z])\n if update_nodes:\n self.compute_compartment_centers()\n\n def move_to(\n self,\n x: Union[float, np.ndarray] = 0.0,\n y: Union[float, np.ndarray] = 0.0,\n z: Union[float, np.ndarray] = 0.0,\n update_nodes: bool = False,\n ):\n \"\"\"Move cells or networks to a location (x, y, z).\n\n If x, y, and z are floats, then the first compartment of the first branch\n of the first cell is moved to that float coordinate, and everything else is\n shifted by the difference between that compartment's previous coordinate and\n the new float location.\n\n If x, y, and z are arrays, then they must each have a length equal to the number\n of cells being moved. Then the first compartment of the first branch of each\n cell is moved to the specified location.\n\n Args:\n update_nodes: Whether `.nodes` should be updated or not. Setting this to\n `False` largely speeds up moving, especially for big networks, but\n `.nodes` or `.show` will not show the new xyz coordinates.\n \"\"\"\n # Test if any coordinate values are NaN which would greatly affect moving\n if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n raise ValueError(\n \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n )\n\n # can only iterate over cells for networks\n # lambda makes sure that generator can be created multiple times\n base_is_net = self.base._current_view == \"network\"\n cells = lambda: (self.cells if base_is_net else [self])\n\n root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])\n root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells\n move_by = np.array([x, y, z]).T - root_xyz\n\n if len(move_by.shape) == 1:\n move_by = np.tile(move_by, (len(self._cells_in_view), 1))\n\n for cell, offset in zip(cells(), move_by):\n for idx in cell._branches_in_view:\n self.base.xyzr[idx][:, :3] += offset\n if update_nodes:\n self.compute_compartment_centers()\n\n def rotate(\n self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = False\n ):\n \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n This function is used only for visualization. It does not affect the simulation.\n\n Args:\n degrees: How many degrees to rotate the module by.\n rotation_axis: Either of {`xy` | `xz` | `yz`}.\n \"\"\"\n degrees = degrees / 180 * np.pi\n if rotation_axis == \"xy\":\n dims = [0, 1]\n elif rotation_axis == \"xz\":\n dims = [0, 2]\n elif rotation_axis == \"yz\":\n dims = [1, 2]\n else:\n raise ValueError\n\n rotation_matrix = np.asarray(\n [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n )\n for i in self._branches_in_view:\n rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T\n self.base.xyzr[i][:, dims] = rot\n if update_nodes:\n self.compute_compartment_centers()\n\n def copy_node_property_to_edges(\n self,\n properties_to_import: Union[str, List[str]],\n pre_or_post: Union[str, List[str]] = [\"pre\", \"post\"],\n ) -> Module:\n \"\"\"Copy a property that is in `node` over to `edges`.\n\n By default, `.edges` does not contain the properties (radius, length, cm,\n channel properties,...) of the pre- and post-synaptic compartments. This\n method allows to copy a property of the pre- and/or post-synaptic compartment\n to the edges. It is then accessible as `module.edges.pre_property_name` or\n `module.edges.post_property_name`.\n\n Note that, if you modify the node property _after_ having run\n `copy_node_property_to_edges`, it will not automatically update the value in\n `.edges`.\n\n Note that, if this method is called on a View (e.g.\n `net.cell(0).copy_node_property_to_edges`), then it will return a View, but\n it will _not_ modify the module itself.\n\n Args:\n properties_to_import: The name of the node properties that should be\n imported. To list all available properties, look at\n `module.nodes.columns`.\n pre_or_post: Whether to import only the pre-synaptic property ('pre'), only\n the post-synaptic property ('post'), or both (['pre', 'post']).\n\n Returns:\n A new module which has the property copied to the `nodes`.\n \"\"\"\n # If a string is passed, wrap it as a list.\n if isinstance(pre_or_post, str):\n pre_or_post = [pre_or_post]\n if isinstance(properties_to_import, str):\n properties_to_import = [properties_to_import]\n\n for pre_or_post_val in pre_or_post:\n assert pre_or_post_val in [\"pre\", \"post\"]\n for property_to_import in properties_to_import:\n # Delete the column if it already exists. Otherwise it would exist\n # twice.\n if f\"{pre_or_post_val}_{property_to_import}\" in self.edges.columns:\n self.edges.drop(\n columns=f\"{pre_or_post_val}_{property_to_import}\", inplace=True\n )\n\n self.edges = self.edges.join(\n self.nodes[[property_to_import, \"global_comp_index\"]].set_index(\n \"global_comp_index\"\n ),\n on=f\"{pre_or_post_val}_global_comp_index\",\n )\n self.edges = self.edges.rename(\n columns={\n property_to_import: f\"{pre_or_post_val}_{property_to_import}\"\n }\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.branches","title":"branches
property
","text":"Iterate over all branches in the module.
Returns a generator that yields a View of each branch.
"},{"location":"reference/modules/#jaxley.modules.base.Module.cells","title":"cells
property
","text":"Iterate over all cells in the module.
Returns a generator that yields a View of each cell.
"},{"location":"reference/modules/#jaxley.modules.base.Module.comps","title":"comps
property
","text":"Iterate over all compartments in the module. Can be called on any module, i.e. net.comps
, cell.comps
or branch.comps
. __iter__
does not allow for this.
Returns a generator that yields a View of each compartment.
"},{"location":"reference/modules/#jaxley.modules.base.Module.initialized","title":"initialized: bool
property
","text":"Whether the Module
is ready to be solved or not.
shape: Tuple[int]
property
","text":"Returns the number of submodules contained in a module.
.. code-block:: python
network.shape = (num_cells, num_branches, num_compartments)\ncell.shape = (num_branches, num_compartments)\nbranch.shape = (num_compartments,)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.view","title":"view
property
","text":"Return view of the module.
"},{"location":"reference/modules/#jaxley.modules.base.Module.__getitem__","title":"__getitem__(index)
","text":"Lazy indexing of the module.
Source code injaxley/modules/base.py
def __getitem__(self, index):\n \"\"\"Lazy indexing of the module.\"\"\"\n supported_parents = [\"network\", \"cell\", \"branch\"] # cannot index into comp\n\n not_group_view = self._current_view not in self.groups\n assert (\n self._current_view in supported_parents or not_group_view\n ), \"Lazy indexing is only supported for `Network`, `Cell`, `Branch` and Views thereof.\"\n index = index if isinstance(index, tuple) else (index,)\n\n child_views = self._childviews()\n assert len(index) <= len(child_views), \"Too many indices.\"\n view = self\n for i, child in zip(index, child_views):\n view = view._at_nodes(child, i)\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.__iter__","title":"__iter__()
","text":"Iterate over parts of the module.
Internally calls cells
, branches
, comps
at the appropriate level.
Example:
.. code-block:: python
for cell in network:\n for branch in cell:\n for comp in branch:\n print(comp.nodes.shape)\n
Source code in jaxley/modules/base.py
def __iter__(self):\n \"\"\"Iterate over parts of the module.\n\n Internally calls `cells`, `branches`, `comps` at the appropriate level.\n\n Example:\n\n .. code-block:: python\n\n for cell in network:\n for branch in cell:\n for comp in branch:\n print(comp.nodes.shape)\n \"\"\"\n next_level = self._childviews()[0]\n yield from self._iter_submodules(next_level)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.add_to_group","title":"add_to_group(group_name)
","text":"Add a view of the module to a group.
Groups can then be indexed. For example:
.. code-block:: python
net.cell(0).add_to_group(\"excitatory\")\nnet.excitatory.set(\"radius\", 0.1)\n
Parameters:
Name Type Description Defaultgroup_name
str
The name of the group.
required Source code injaxley/modules/base.py
def add_to_group(self, group_name: str):\n \"\"\"Add a view of the module to a group.\n\n Groups can then be indexed. For example:\n\n .. code-block:: python\n\n net.cell(0).add_to_group(\"excitatory\")\n net.excitatory.set(\"radius\", 0.1)\n\n Args:\n group_name: The name of the group.\n \"\"\"\n if group_name not in self.base.groups:\n self.base.groups[group_name] = self._nodes_in_view\n else:\n self.base.groups[group_name] = np.unique(\n np.concatenate([self.base.groups[group_name], self._nodes_in_view])\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.branch","title":"branch(idx)
","text":"Return a View of the module at the selected branches(s).
Parameters:
Name Type Description Defaultidx
Any
index of the branch to view.
requiredReturns:
Type DescriptionView
View of the module at the specified branch index.
Source code injaxley/modules/base.py
def branch(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected branches(s).\n\n Args:\n idx: index of the branch to view.\n\n Returns:\n View of the module at the specified branch index.\"\"\"\n return self._at_nodes(\"branch\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.cell","title":"cell(idx)
","text":"Return a View of the module at the selected cell(s).
Parameters:
Name Type Description Defaultidx
Any
index of the cell to view.
requiredReturns:
Type DescriptionView
View of the module at the specified cell index.
Source code injaxley/modules/base.py
def cell(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected cell(s).\n\n Args:\n idx: index of the cell to view.\n\n Returns:\n View of the module at the specified cell index.\"\"\"\n return self._at_nodes(\"cell\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.clamp","title":"clamp(state_name, state_array, verbose=True)
","text":"Clamp a state to a given value across specified compartments.
Parameters:
Name Type Description Defaultstate_name
str
The name of the state to clamp.
requiredstate_array
nd
Array of values to clamp the state to.
requiredverbose
If True, prints details about the clamping.
True
This function sets external states for the compartments.
Source code injaxley/modules/base.py
def clamp(self, state_name: str, state_array: jnp.ndarray, verbose: bool = True):\n \"\"\"Clamp a state to a given value across specified compartments.\n\n Args:\n state_name: The name of the state to clamp.\n state_array (jnp.nd: Array of values to clamp the state to.\n verbose : If True, prints details about the clamping.\n\n This function sets external states for the compartments.\n \"\"\"\n self._external_input(state_name, state_array, verbose=verbose)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.comp","title":"comp(idx)
","text":"Return a View of the module at the selected compartments(s).
Parameters:
Name Type Description Defaultidx
Any
index of the comp to view.
requiredReturns:
Type DescriptionView
View of the module at the specified compartment index.
Source code injaxley/modules/base.py
def comp(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected compartments(s).\n\n Args:\n idx: index of the comp to view.\n\n Returns:\n View of the module at the specified compartment index.\"\"\"\n return self._at_nodes(\"comp\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.compute_compartment_centers","title":"compute_compartment_centers()
","text":"Add compartment centers to nodes dataframe
Source code injaxley/modules/base.py
def compute_compartment_centers(self):\n \"\"\"Add compartment centers to nodes dataframe\"\"\"\n centers = self._compute_coords_of_comp_centers()\n self.base.nodes.loc[self._nodes_in_view, [\"x\", \"y\", \"z\"]] = centers\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.compute_xyz","title":"compute_xyz()
","text":"Return xyz coordinates of every branch, based on the branch length.
This function should not be called if the morphology was read from an .swc
file. However, for morphologies that were constructed from scratch, this function must be called before .vis()
. The computed xyz
coordinates are only used for plotting.
jaxley/modules/base.py
def compute_xyz(self):\n \"\"\"Return xyz coordinates of every branch, based on the branch length.\n\n This function should not be called if the morphology was read from an `.swc`\n file. However, for morphologies that were constructed from scratch, this\n function **must** be called before `.vis()`. The computed `xyz` coordinates\n are only used for plotting.\n \"\"\"\n max_y_multiplier = 5.0\n min_y_multiplier = 0.5\n\n parents = self.comb_parents\n num_children = _compute_num_children(parents)\n index_of_child = _compute_index_of_child(parents)\n levels = compute_levels(parents)\n\n # Extract branch.\n inds_branch = self.nodes.groupby(\"global_branch_index\")[\n \"global_comp_index\"\n ].apply(list)\n branch_lens = [np.sum(self.nodes[\"length\"][np.asarray(i)]) for i in inds_branch]\n endpoints = []\n\n # Different levels will get a different \"angle\" at which the children emerge from\n # the parents. This angle is defined by the `y_offset_multiplier`. This value\n # defines the range between y-location of the first and of the last child of a\n # parent.\n y_offset_multiplier = np.linspace(\n max_y_multiplier, min_y_multiplier, np.max(levels) + 1\n )\n\n for b in range(self.total_nbranches):\n # For networks with mixed SWC and from-scatch neurons, only update those\n # branches that do not have coordingates yet.\n if np.any(np.isnan(self.xyzr[b])):\n if parents[b] > -1:\n start_point = endpoints[parents[b]]\n num_children_of_parent = num_children[parents[b]]\n if num_children_of_parent > 1:\n y_offset = (\n ((index_of_child[b] / (num_children_of_parent - 1))) - 0.5\n ) * y_offset_multiplier[levels[b]]\n else:\n y_offset = 0.0\n else:\n start_point = [0, 0, 0]\n y_offset = 0.0\n\n len_of_path = np.sqrt(y_offset**2 + 1.0)\n\n end_point = [\n start_point[0] + branch_lens[b] / len_of_path * 1.0,\n start_point[1] + branch_lens[b] / len_of_path * y_offset,\n start_point[2],\n ]\n endpoints.append(end_point)\n\n self.xyzr[b][:, :3] = np.asarray([start_point, end_point])\n else:\n # Dummy to keey the index `endpoints[parent[b]]` above working.\n endpoints.append(np.zeros((2,)))\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.copy","title":"copy(reset_index=False, as_module=False)
","text":"Extract part of a module and return a copy of its View or a new module.
This can be used to call jx.integrate
on part of a Module.
Parameters:
Name Type Description Defaultreset_index
bool
if True, the indices of the new module are reset to start from 0.
False
as_module
bool
if True, a new module is returned instead of a View.
False
Returns:
Type DescriptionUnion[Module, View]
A part of the module or a copied view of it.
Source code injaxley/modules/base.py
def copy(\n self, reset_index: bool = False, as_module: bool = False\n) -> Union[Module, View]:\n \"\"\"Extract part of a module and return a copy of its View or a new module.\n\n This can be used to call `jx.integrate` on part of a Module.\n\n Args:\n reset_index: if True, the indices of the new module are reset to start from 0.\n as_module: if True, a new module is returned instead of a View.\n\n Returns:\n A part of the module or a copied view of it.\"\"\"\n view = deepcopy(self)\n warnings.warn(\"This method is experimental, use at your own risk.\")\n # TODO FROM #447: add reset_index, i.e. for parents, nodes, edges etc. such that they\n # start from 0/-1 and are contiguous\n if as_module:\n raise NotImplementedError(\"Not yet implemented.\")\n # initialize a new module with the same attributes\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.copy_node_property_to_edges","title":"copy_node_property_to_edges(properties_to_import, pre_or_post=['pre', 'post'])
","text":"Copy a property that is in node
over to edges
.
By default, .edges
does not contain the properties (radius, length, cm, channel properties,\u2026) of the pre- and post-synaptic compartments. This method allows to copy a property of the pre- and/or post-synaptic compartment to the edges. It is then accessible as module.edges.pre_property_name
or module.edges.post_property_name
.
Note that, if you modify the node property after having run copy_node_property_to_edges
, it will not automatically update the value in .edges
.
Note that, if this method is called on a View (e.g. net.cell(0).copy_node_property_to_edges
), then it will return a View, but it will not modify the module itself.
Parameters:
Name Type Description Defaultproperties_to_import
Union[str, List[str]]
The name of the node properties that should be imported. To list all available properties, look at module.nodes.columns
.
pre_or_post
Union[str, List[str]]
Whether to import only the pre-synaptic property (\u2018pre\u2019), only the post-synaptic property (\u2018post\u2019), or both ([\u2018pre\u2019, \u2018post\u2019]).
['pre', 'post']
Returns:
Type DescriptionModule
A new module which has the property copied to the nodes
.
jaxley/modules/base.py
def copy_node_property_to_edges(\n self,\n properties_to_import: Union[str, List[str]],\n pre_or_post: Union[str, List[str]] = [\"pre\", \"post\"],\n) -> Module:\n \"\"\"Copy a property that is in `node` over to `edges`.\n\n By default, `.edges` does not contain the properties (radius, length, cm,\n channel properties,...) of the pre- and post-synaptic compartments. This\n method allows to copy a property of the pre- and/or post-synaptic compartment\n to the edges. It is then accessible as `module.edges.pre_property_name` or\n `module.edges.post_property_name`.\n\n Note that, if you modify the node property _after_ having run\n `copy_node_property_to_edges`, it will not automatically update the value in\n `.edges`.\n\n Note that, if this method is called on a View (e.g.\n `net.cell(0).copy_node_property_to_edges`), then it will return a View, but\n it will _not_ modify the module itself.\n\n Args:\n properties_to_import: The name of the node properties that should be\n imported. To list all available properties, look at\n `module.nodes.columns`.\n pre_or_post: Whether to import only the pre-synaptic property ('pre'), only\n the post-synaptic property ('post'), or both (['pre', 'post']).\n\n Returns:\n A new module which has the property copied to the `nodes`.\n \"\"\"\n # If a string is passed, wrap it as a list.\n if isinstance(pre_or_post, str):\n pre_or_post = [pre_or_post]\n if isinstance(properties_to_import, str):\n properties_to_import = [properties_to_import]\n\n for pre_or_post_val in pre_or_post:\n assert pre_or_post_val in [\"pre\", \"post\"]\n for property_to_import in properties_to_import:\n # Delete the column if it already exists. Otherwise it would exist\n # twice.\n if f\"{pre_or_post_val}_{property_to_import}\" in self.edges.columns:\n self.edges.drop(\n columns=f\"{pre_or_post_val}_{property_to_import}\", inplace=True\n )\n\n self.edges = self.edges.join(\n self.nodes[[property_to_import, \"global_comp_index\"]].set_index(\n \"global_comp_index\"\n ),\n on=f\"{pre_or_post_val}_global_comp_index\",\n )\n self.edges = self.edges.rename(\n columns={\n property_to_import: f\"{pre_or_post_val}_{property_to_import}\"\n }\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_clamp","title":"data_clamp(state_name, state_array, data_clamps=None, verbose=False)
","text":"Insert a clamp into the module within jit (or grad).
Parameters:
Name Type Description Defaultstate_name
str
Name of the state variable to set.
requiredstate_array
ndarray
Time series of the state variable in the default Jaxley unit. State array should be of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.
requiredverbose
bool
Whether or not to print the number of inserted clamps. False
by default because this method is meant to be jitted.
False
Source code in jaxley/modules/base.py
def data_clamp(\n self,\n state_name: str,\n state_array: jnp.ndarray,\n data_clamps: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n verbose: bool = False,\n):\n \"\"\"Insert a clamp into the module within jit (or grad).\n\n Args:\n state_name: Name of the state variable to set.\n state_array: Time series of the state variable in the default Jaxley unit.\n State array should be of shape (num_clamps, simulation_time) or\n (simulation_time, ) for a single clamp.\n verbose: Whether or not to print the number of inserted clamps. `False`\n by default because this method is meant to be jitted.\n \"\"\"\n comp_states, edge_states = self._get_state_names()\n if state_name not in comp_states + edge_states:\n raise KeyError(f\"{state_name} is not a recognized state in this module.\")\n data = self.nodes if state_name in comp_states else self.edges\n return self._data_external_input(\n state_name, state_array, data_clamps, data, verbose=verbose\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_set","title":"data_set(key, val, param_state)
","text":"Set parameter of module (or its view) to a new value within jit
.
Parameters:
Name Type Description Defaultkey
str
The name of the parameter to set.
requiredval
Union[float, ndarray]
The value to set the parameter to. If it is jnp.ndarray
then it must be of shape (len(num_compartments))
.
param_state
Optional[List[Dict]]
State of the setted parameters, internally used such that this function does not modify global state.
required Source code injaxley/modules/base.py
def data_set(\n self,\n key: str,\n val: Union[float, jnp.ndarray],\n param_state: Optional[List[Dict]],\n):\n \"\"\"Set parameter of module (or its view) to a new value within `jit`.\n\n Args:\n key: The name of the parameter to set.\n val: The value to set the parameter to. If it is `jnp.ndarray` then it\n must be of shape `(len(num_compartments))`.\n param_state: State of the setted parameters, internally used such that this\n function does not modify global state.\n \"\"\"\n # Note: `data_set` does not support arrays for `val`.\n is_node_param = key in self.nodes.columns\n data = self.nodes if is_node_param else self.edges\n viewed_inds = self._nodes_in_view if is_node_param else self._edges_in_view\n if key in data.columns:\n not_nan = ~data[key].isna()\n added_param_state = [\n {\n \"indices\": np.atleast_2d(viewed_inds[not_nan]),\n \"key\": key,\n \"val\": jnp.atleast_1d(jnp.asarray(val)),\n }\n ]\n if param_state is not None:\n param_state += added_param_state\n else:\n param_state = added_param_state\n else:\n raise KeyError(\"Key not recognized.\")\n return param_state\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.data_stimulate","title":"data_stimulate(current, data_stimuli=None, verbose=False)
","text":"Insert a stimulus into the module within jit (or grad).
Parameters:
Name Type Description Defaultcurrent
ndarray
Current in nA
.
verbose
bool
Whether or not to print the number of inserted stimuli. False
by default because this method is meant to be jitted.
False
Source code in jaxley/modules/base.py
def data_stimulate(\n self,\n current: jnp.ndarray,\n data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None,\n verbose: bool = False,\n) -> Tuple[jnp.ndarray, pd.DataFrame]:\n \"\"\"Insert a stimulus into the module within jit (or grad).\n\n Args:\n current: Current in `nA`.\n verbose: Whether or not to print the number of inserted stimuli. `False`\n by default because this method is meant to be jitted.\n \"\"\"\n return self._data_external_input(\n \"i\", current, data_stimuli, self.nodes, verbose=verbose\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_channel","title":"delete_channel(channel)
","text":"Remove a channel from the module.
Parameters:
Name Type Description Defaultchannel
Channel
The channel to remove.
required Source code injaxley/modules/base.py
def delete_channel(self, channel: Channel):\n \"\"\"Remove a channel from the module.\n\n Args:\n channel: The channel to remove.\"\"\"\n name = channel._name\n channel_names = [c._name for c in self.channels]\n all_channel_names = [c._name for c in self.base.channels]\n if name in channel_names:\n channel_cols = list(channel.channel_params.keys())\n channel_cols += list(channel.channel_states.keys())\n self.base.nodes.loc[self._nodes_in_view, channel_cols] = float(\"nan\")\n self.base.nodes.loc[self._nodes_in_view, name] = False\n\n # only delete cols if no other comps in the module have the same channel\n if np.all(~self.base.nodes[name]):\n self.base.channels.pop(all_channel_names.index(name))\n self.base.membrane_current_names.remove(channel.current_name)\n self.base.nodes.drop(columns=channel_cols + [name], inplace=True)\n else:\n raise ValueError(f\"Channel {name} not found in the module.\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_clamps","title":"delete_clamps(state_name=None)
","text":"Removes all clamps of the given state from the module.
Source code injaxley/modules/base.py
def delete_clamps(self, state_name: Optional[str] = None):\n \"\"\"Removes all clamps of the given state from the module.\"\"\"\n all_externals = list(self.externals.keys())\n if \"i\" in all_externals:\n all_externals.remove(\"i\")\n state_names = all_externals if state_name is None else [state_name]\n for state_name in state_names:\n if state_name in self.externals:\n keep_inds = ~np.isin(\n self.base.external_inds[state_name], self._nodes_in_view\n )\n base_exts = self.base.externals\n base_exts_inds = self.base.external_inds\n if np.all(~keep_inds):\n base_exts.pop(state_name, None)\n base_exts_inds.pop(state_name, None)\n else:\n base_exts[state_name] = base_exts[state_name][keep_inds]\n base_exts_inds[state_name] = base_exts_inds[state_name][keep_inds]\n self._update_view()\n else:\n pass # does not have to be deleted if not in externals\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_recordings","title":"delete_recordings()
","text":"Removes all recordings from the module.
Source code injaxley/modules/base.py
def delete_recordings(self):\n \"\"\"Removes all recordings from the module.\"\"\"\n if isinstance(self, View):\n base_recs = self.base.recordings\n self.base.recordings = base_recs[\n ~base_recs.isin(self.recordings).all(axis=1)\n ]\n self._update_view()\n else:\n self.base.recordings = pd.DataFrame().from_dict({})\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_stimuli","title":"delete_stimuli()
","text":"Removes all stimuli from the module.
Source code injaxley/modules/base.py
def delete_stimuli(self):\n \"\"\"Removes all stimuli from the module.\"\"\"\n self.delete_clamps(\"i\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.delete_trainables","title":"delete_trainables()
","text":"Removes all trainable parameters from the module.
Source code injaxley/modules/base.py
def delete_trainables(self):\n \"\"\"Removes all trainable parameters from the module.\"\"\"\n\n if isinstance(self, View):\n trainables_and_inds = self._filter_trainables(is_viewed=False)\n self.base.indices_set_by_trainables = trainables_and_inds[0]\n self.base.trainable_params = trainables_and_inds[1]\n self.base.num_trainable_params -= self.num_trainable_params\n else:\n self.base.indices_set_by_trainables = []\n self.base.trainable_params = []\n self.base.num_trainable_params = 0\n self._update_view()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.distance","title":"distance(endpoint)
","text":"Return the direct distance between two compartments. This does not compute the pathwise distance (which is currently not implemented). Args: endpoint: The compartment to which to compute the distance to.
Source code injaxley/modules/base.py
def distance(self, endpoint: \"View\") -> float:\n \"\"\"Return the direct distance between two compartments.\n This does not compute the pathwise distance (which is currently not\n implemented).\n Args:\n endpoint: The compartment to which to compute the distance to.\n \"\"\"\n assert len(self.xyzr) == 1 and len(endpoint.xyzr) == 1\n start_xyz = np.mean(self.xyzr[0][:, :3], axis=0)\n end_xyz = np.mean(endpoint.xyzr[0][:, :3], axis=0)\n return np.sqrt(np.sum((start_xyz - end_xyz) ** 2))\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.edge","title":"edge(idx)
","text":"Return a View of the module at the selected synapse edges(s).
Parameters:
Name Type Description Defaultidx
Any
index of the edge to view.
requiredReturns:
Type DescriptionView
View of the module at the specified edge index.
Source code injaxley/modules/base.py
def edge(self, idx: Any) -> View:\n \"\"\"Return a View of the module at the selected synapse edges(s).\n\n Args:\n idx: index of the edge to view.\n\n Returns:\n View of the module at the specified edge index.\"\"\"\n return self._at_edges(\"edge\", idx)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_parameters","title":"get_all_parameters(pstate, voltage_solver)
","text":"Return all parameters (and coupling conductances) needed to simulate.
Runs _compute_axial_conductances()
and return every parameter that is needed to solve the ODE. This includes conductances, radiuses, lengths, axial_resistivities, but also coupling conductances.
This is done by first obtaining the current value of every parameter (not only the trainable ones) and then replacing the trainable ones with the value in trainable_params()
. This function is run within jx.integrate()
.
pstate can be obtained by calling params_to_pstate()
.
.. code-block:: python
params = module.get_parameters() # i.e. [0, 1, 2]\npstate = params_to_pstate(params, module.indices_set_by_trainables)\nmodule.to_jax() # needed for call to module.jaxnodes\n
Parameters:
Name Type Description Defaultpstate
List[Dict]
The state of the trainable parameters. pstate takes the form [{ \u201ckey\u201d: \u201cgNa\u201d, \u201cindices\u201d: jnp.array([0, 1, 2]), \u201cval\u201d: jnp.array([0.1, 0.2, 0.3]) }, \u2026].
requiredvoltage_solver
str
The voltage solver that is used. Since jax.sparse
and jaxley.xyz
require different formats of the axial conductances, this function will default to different building methods.
Returns:
Type DescriptionDict[str, ndarray]
A dictionary of all module parameters.
Source code injaxley/modules/base.py
@only_allow_module\ndef get_all_parameters(\n self, pstate: List[Dict], voltage_solver: str\n) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Return all parameters (and coupling conductances) needed to simulate.\n\n Runs `_compute_axial_conductances()` and return every parameter that is needed\n to solve the ODE. This includes conductances, radiuses, lengths,\n axial_resistivities, but also coupling conductances.\n\n This is done by first obtaining the current value of every parameter (not only\n the trainable ones) and then replacing the trainable ones with the value\n in `trainable_params()`. This function is run within `jx.integrate()`.\n\n pstate can be obtained by calling `params_to_pstate()`.\n\n .. code-block:: python\n\n params = module.get_parameters() # i.e. [0, 1, 2]\n pstate = params_to_pstate(params, module.indices_set_by_trainables)\n module.to_jax() # needed for call to module.jaxnodes\n\n Args:\n pstate: The state of the trainable parameters. pstate takes the form\n [{\n \"key\": \"gNa\", \"indices\": jnp.array([0, 1, 2]),\n \"val\": jnp.array([0.1, 0.2, 0.3])\n }, ...].\n voltage_solver: The voltage solver that is used. Since `jax.sparse` and\n `jaxley.xyz` require different formats of the axial conductances, this\n function will default to different building methods.\n\n Returns:\n A dictionary of all module parameters.\n \"\"\"\n params = {}\n for key in [\"radius\", \"length\", \"axial_resistivity\", \"capacitance\"]:\n params[key] = self.base.jaxnodes[key]\n\n for channel in self.base.channels:\n for channel_params in channel.channel_params:\n params[channel_params] = self.base.jaxnodes[channel_params]\n\n for synapse_params in self.base.synapse_param_names:\n params[synapse_params] = self.base.jaxedges[synapse_params]\n\n # Override with those parameters set by `.make_trainable()`.\n for parameter in pstate:\n key = parameter[\"key\"]\n inds = parameter[\"indices\"]\n set_param = parameter[\"val\"]\n\n # This is needed since SynapseViews worked differently before.\n # This mimics the old behaviour and tranformes the new indices\n # to the old indices.\n # TODO FROM #447: Longterm this should be gotten rid of.\n # Instead edges should work similar to nodes (would also allow for\n # param sharing).\n synapse_inds = self.base.edges.groupby(\"type\").rank()[\"global_edge_index\"]\n synapse_inds = (synapse_inds.astype(int) - 1).to_numpy()\n if key in self.base.synapse_param_names:\n inds = synapse_inds[inds]\n\n if key in params: # Only parameters, not initial states.\n # `inds` is of shape `(num_params, num_comps_per_param)`.\n # `set_param` is of shape `(num_params,)`\n # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n # `.set()` to work. This is done with `[:, None]`.\n params[key] = params[key].at[inds].set(set_param[:, None])\n\n # Compute conductance params and add them to the params dictionary.\n params[\"axial_conductances\"] = self.base._compute_axial_conductances(\n params=params\n )\n return params\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_all_states","title":"get_all_states(pstate, all_params, delta_t)
","text":"Get the full initial state of the module from jaxnodes and trainables.
Parameters:
Name Type Description Defaultpstate
List[Dict]
The state of the trainable parameters.
requiredall_params
All parameters of the module.
requireddelta_t
float
The time step.
requiredReturns:
Type DescriptionDict[str, ndarray]
A dictionary of all states of the module.
Source code injaxley/modules/base.py
@only_allow_module\ndef get_all_states(\n self, pstate: List[Dict], all_params, delta_t: float\n) -> Dict[str, jnp.ndarray]:\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Get the full initial state of the module from jaxnodes and trainables.\n\n Args:\n pstate: The state of the trainable parameters.\n all_params: All parameters of the module.\n delta_t: The time step.\n\n Returns:\n A dictionary of all states of the module.\n \"\"\"\n states = self.base._get_states_from_nodes_and_edges()\n\n # Override with the initial states set by `.make_trainable()`.\n for parameter in pstate:\n key = parameter[\"key\"]\n inds = parameter[\"indices\"]\n set_param = parameter[\"val\"]\n if key in states: # Only initial states, not parameters.\n # `inds` is of shape `(num_params, num_comps_per_param)`.\n # `set_param` is of shape `(num_params,)`\n # We need to unsqueeze `set_param` to make it `(num_params, 1)` for the\n # `.set()` to work. This is done with `[:, None]`.\n states[key] = states[key].at[inds].set(set_param[:, None])\n\n # Add to the states the initial current through every channel.\n states, _ = self.base._channel_currents(\n states, delta_t, self.channels, self.nodes, all_params\n )\n\n # Add to the states the initial current through every synapse.\n states, _ = self.base._synapse_currents(\n states, self.synapses, all_params, delta_t, self.edges\n )\n return states\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.get_parameters","title":"get_parameters()
","text":"Get all trainable parameters.
The returned parameters should be passed to `jx.integrate(\u2026, params=params).
Returns:
Type DescriptionList[Dict[str, ndarray]]
A list of all trainable parameters in the form of [{\u201cgNa\u201d: jnp.array([0.1, 0.2, 0.3])}, \u2026].
Source code injaxley/modules/base.py
def get_parameters(self) -> List[Dict[str, jnp.ndarray]]:\n \"\"\"Get all trainable parameters.\n\n The returned parameters should be passed to `jx.integrate(..., params=params).\n\n Returns:\n A list of all trainable parameters in the form of\n [{\"gNa\": jnp.array([0.1, 0.2, 0.3])}, ...].\n \"\"\"\n return self.trainable_params\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.init_states","title":"init_states(delta_t=0.025)
","text":"Initialize all mechanisms in their steady state.
This considers the voltages and parameters of each compartment.
Parameters:
Name Type Description Defaultdelta_t
float
Passed on to channel.init_state()
.
0.025
Source code in jaxley/modules/base.py
@only_allow_module\ndef init_states(self, delta_t: float = 0.025):\n # TODO FROM #447: MAKE THIS WORK FOR VIEW?\n \"\"\"Initialize all mechanisms in their steady state.\n\n This considers the voltages and parameters of each compartment.\n\n Args:\n delta_t: Passed on to `channel.init_state()`.\n \"\"\"\n # Update states of the channels.\n channel_nodes = self.base.nodes\n states = self.base._get_states_from_nodes_and_edges()\n\n # We do not use any `pstate` for initializing. In principle, we could change\n # that by allowing an input `params` and `pstate` to this function.\n # `voltage_solver` could also be `jax.sparse` here, because both of them\n # build the channel parameters in the same way.\n params = self.base.get_all_parameters([], voltage_solver=\"jaxley.thomas\")\n\n for channel in self.base.channels:\n name = channel._name\n channel_indices = channel_nodes.loc[channel_nodes[name]][\n \"global_comp_index\"\n ].to_numpy()\n voltages = channel_nodes.loc[channel_indices, \"v\"].to_numpy()\n\n channel_param_names = list(channel.channel_params.keys())\n channel_state_names = list(channel.channel_states.keys())\n channel_states = query_channel_states_and_params(\n states, channel_state_names, channel_indices\n )\n channel_params = query_channel_states_and_params(\n params, channel_param_names, channel_indices\n )\n\n init_state = channel.init_state(\n channel_states, voltages, channel_params, delta_t\n )\n\n # `init_state` might not return all channel states. Only the ones that are\n # returned are updated here.\n for key, val in init_state.items():\n # Note that we are overriding `self.nodes` here, but `self.nodes` is\n # not used above to actually compute the current states (so there are\n # no issues with overriding states).\n self.nodes.loc[channel_indices, key] = val\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.insert","title":"insert(channel)
","text":"Insert a channel into the module.
Parameters:
Name Type Description Defaultchannel
Channel
The channel to insert.
required Source code injaxley/modules/base.py
def insert(self, channel: Channel):\n \"\"\"Insert a channel into the module.\n\n Args:\n channel: The channel to insert.\"\"\"\n name = channel._name\n\n # Channel does not yet exist in the `jx.Module` at all.\n if name not in [c._name for c in self.base.channels]:\n self.base.channels.append(channel)\n self.base.nodes[name] = (\n False # Previous columns do not have the new channel.\n )\n\n if channel.current_name not in self.base.membrane_current_names:\n self.base.membrane_current_names.append(channel.current_name)\n\n # Add a binary column that indicates if a channel is present.\n self.base.nodes.loc[self._nodes_in_view, name] = True\n\n # Loop over all new parameters, e.g. gNa, eNa.\n for key in channel.channel_params:\n self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_params[key]\n\n # Loop over all new parameters, e.g. gNa, eNa.\n for key in channel.channel_states:\n self.base.nodes.loc[self._nodes_in_view, key] = channel.channel_states[key]\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.loc","title":"loc(at)
","text":"Return a View of the module at the selected branch location(s).
Parameters:
Name Type Description Defaultat
Any
location along the branch.
requiredReturns:
Type DescriptionView
View of the module at the specified branch location.
Source code injaxley/modules/base.py
def loc(self, at: Any) -> View:\n \"\"\"Return a View of the module at the selected branch location(s).\n\n Args:\n at: location along the branch.\n\n Returns:\n View of the module at the specified branch location.\"\"\"\n global_comp_idxs = []\n for i in self._branches_in_view:\n ncomp = self.base.ncomp_per_branch[i]\n comp_locs = np.linspace(0, 1, ncomp)\n at = comp_locs if is_str_all(at) else self._reformat_index(at, dtype=float)\n comp_edges = np.linspace(0, 1 + 1e-10, ncomp + 1)\n idx = np.digitize(at, comp_edges) - 1 + self.base.cumsum_ncomp[i]\n global_comp_idxs.append(idx)\n global_comp_idxs = np.concatenate(global_comp_idxs)\n orig_scope = self._scope\n # global scope needed to select correct comps, for i.e. branches w. ncomp=[1,2]\n # loc(0.9) will correspond to different local branches (0 vs 1).\n view = self.scope(\"global\").comp(global_comp_idxs).scope(orig_scope)\n view._current_view = \"loc\"\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.make_trainable","title":"make_trainable(key, init_val=None, verbose=True)
","text":"Make a parameter trainable.
If a parameter is made trainable, it will be returned by get_parameters()
and should then be passed to jx.integrate(..., params=params)
.
Parameters:
Name Type Description Defaultkey
str
Name of the parameter to make trainable.
requiredinit_val
Optional[Union[float, list]]
Initial value of the parameter. If float
, the same value is used for every created parameter. If list
, the length of the list has to match the number of created parameters. If None
, the current parameter value is used and if parameter sharing is performed that the current parameter value is averaged over all shared parameters.
None
verbose
bool
Whether to print the number of parameters that are added and the total number of parameters.
True
Source code in jaxley/modules/base.py
def make_trainable(\n self,\n key: str,\n init_val: Optional[Union[float, list]] = None,\n verbose: bool = True,\n):\n \"\"\"Make a parameter trainable.\n\n If a parameter is made trainable, it will be returned by `get_parameters()`\n and should then be passed to `jx.integrate(..., params=params)`.\n\n Args:\n key: Name of the parameter to make trainable.\n init_val: Initial value of the parameter. If `float`, the same value is\n used for every created parameter. If `list`, the length of the list has\n to match the number of created parameters. If `None`, the current\n parameter value is used and if parameter sharing is performed that the\n current parameter value is averaged over all shared parameters.\n verbose: Whether to print the number of parameters that are added and the\n total number of parameters.\n \"\"\"\n assert (\n self.allow_make_trainable\n ), \"network.cell('all').make_trainable() is not supported. Use a for-loop over cells.\"\n ncomps_per_branch = (\n self.base.nodes[\"global_branch_index\"].value_counts().to_numpy()\n )\n assert np.all(\n ncomps_per_branch == ncomps_per_branch[0]\n ), \"Parameter sharing is not allowed for modules containing branches with different numbers of compartments.\"\n\n data = self.nodes if key in self.nodes.columns else None\n data = self.edges if key in self.edges.columns else data\n\n assert data is not None, f\"Key '{key}' not found in nodes or edges\"\n not_nan = ~data[key].isna()\n data = data.loc[not_nan]\n assert (\n len(data) > 0\n ), \"No settable parameters found in the selected compartments.\"\n\n grouped_view = data.groupby(\"controlled_by_param\")\n # Because of this `x.index.values` we cannot support `make_trainable()` on\n # the module level for synapse parameters (but only for `SynapseView`).\n inds_of_comps = list(\n grouped_view.apply(lambda x: x.index.values, include_groups=False)\n )\n indices_per_param = jnp.stack(inds_of_comps)\n # Sorted inds are only used to infer the correct starting values.\n param_vals = jnp.asarray(\n [data.loc[inds, key].to_numpy() for inds in inds_of_comps]\n )\n\n # Set the value which the trainable parameter should take.\n num_created_parameters = len(indices_per_param)\n if init_val is not None:\n if isinstance(init_val, float):\n new_params = jnp.asarray([init_val] * num_created_parameters)\n elif isinstance(init_val, list):\n assert (\n len(init_val) == num_created_parameters\n ), f\"len(init_val)={len(init_val)}, but trying to create {num_created_parameters} parameters.\"\n new_params = jnp.asarray(init_val)\n else:\n raise ValueError(\n f\"init_val must a float, list, or None, but it is a {type(init_val).__name__}.\"\n )\n else:\n new_params = jnp.mean(param_vals, axis=1)\n self.base.trainable_params.append({key: new_params})\n self.base.indices_set_by_trainables.append(indices_per_param)\n self.base.num_trainable_params += num_created_parameters\n if verbose:\n print(\n f\"Number of newly added trainable parameters: {num_created_parameters}. Total number of trainable parameters: {self.base.num_trainable_params}\"\n )\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.move","title":"move(x=0.0, y=0.0, z=0.0, update_nodes=False)
","text":"Move cells or networks by adding to their (x, y, z) coordinates.
This function is used only for visualization. It does not affect the simulation.
Parameters:
Name Type Description Defaultx
float
The amount to move in the x direction in um.
0.0
y
float
The amount to move in the y direction in um.
0.0
z
float
The amount to move in the z direction in um.
0.0
update_nodes
bool
Whether .nodes
should be updated or not. Setting this to False
largely speeds up moving, especially for big networks, but .nodes
or .show
will not show the new xyz coordinates.
False
Source code in jaxley/modules/base.py
def move(\n self, x: float = 0.0, y: float = 0.0, z: float = 0.0, update_nodes: bool = False\n):\n \"\"\"Move cells or networks by adding to their (x, y, z) coordinates.\n\n This function is used only for visualization. It does not affect the simulation.\n\n Args:\n x: The amount to move in the x direction in um.\n y: The amount to move in the y direction in um.\n z: The amount to move in the z direction in um.\n update_nodes: Whether `.nodes` should be updated or not. Setting this to\n `False` largely speeds up moving, especially for big networks, but\n `.nodes` or `.show` will not show the new xyz coordinates.\n \"\"\"\n for i in self._branches_in_view:\n self.base.xyzr[i][:, :3] += np.array([x, y, z])\n if update_nodes:\n self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.move_to","title":"move_to(x=0.0, y=0.0, z=0.0, update_nodes=False)
","text":"Move cells or networks to a location (x, y, z).
If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment\u2019s previous coordinate and the new float location.
If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.
Parameters:
Name Type Description Defaultupdate_nodes
bool
Whether .nodes
should be updated or not. Setting this to False
largely speeds up moving, especially for big networks, but .nodes
or .show
will not show the new xyz coordinates.
False
Source code in jaxley/modules/base.py
def move_to(\n self,\n x: Union[float, np.ndarray] = 0.0,\n y: Union[float, np.ndarray] = 0.0,\n z: Union[float, np.ndarray] = 0.0,\n update_nodes: bool = False,\n):\n \"\"\"Move cells or networks to a location (x, y, z).\n\n If x, y, and z are floats, then the first compartment of the first branch\n of the first cell is moved to that float coordinate, and everything else is\n shifted by the difference between that compartment's previous coordinate and\n the new float location.\n\n If x, y, and z are arrays, then they must each have a length equal to the number\n of cells being moved. Then the first compartment of the first branch of each\n cell is moved to the specified location.\n\n Args:\n update_nodes: Whether `.nodes` should be updated or not. Setting this to\n `False` largely speeds up moving, especially for big networks, but\n `.nodes` or `.show` will not show the new xyz coordinates.\n \"\"\"\n # Test if any coordinate values are NaN which would greatly affect moving\n if np.any(np.concatenate(self.xyzr, axis=0)[:, :3] == np.nan):\n raise ValueError(\n \"NaN coordinate values detected. Shift amounts cannot be computed. Please run compute_xyzr() or assign initial coordinate values.\"\n )\n\n # can only iterate over cells for networks\n # lambda makes sure that generator can be created multiple times\n base_is_net = self.base._current_view == \"network\"\n cells = lambda: (self.cells if base_is_net else [self])\n\n root_xyz_cells = np.array([c.xyzr[0][0, :3] for c in cells()])\n root_xyz = root_xyz_cells[0] if isinstance(x, float) else root_xyz_cells\n move_by = np.array([x, y, z]).T - root_xyz\n\n if len(move_by.shape) == 1:\n move_by = np.tile(move_by, (len(self._cells_in_view), 1))\n\n for cell, offset in zip(cells(), move_by):\n for idx in cell._branches_in_view:\n self.base.xyzr[idx][:, :3] += offset\n if update_nodes:\n self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.rotate","title":"rotate(degrees, rotation_axis='xy', update_nodes=False)
","text":"Rotate jaxley modules clockwise. Used only for visualization.
This function is used only for visualization. It does not affect the simulation.
Parameters:
Name Type Description Defaultdegrees
float
How many degrees to rotate the module by.
requiredrotation_axis
str
Either of {xy
| xz
| yz
}.
'xy'
Source code in jaxley/modules/base.py
def rotate(\n self, degrees: float, rotation_axis: str = \"xy\", update_nodes: bool = False\n):\n \"\"\"Rotate jaxley modules clockwise. Used only for visualization.\n\n This function is used only for visualization. It does not affect the simulation.\n\n Args:\n degrees: How many degrees to rotate the module by.\n rotation_axis: Either of {`xy` | `xz` | `yz`}.\n \"\"\"\n degrees = degrees / 180 * np.pi\n if rotation_axis == \"xy\":\n dims = [0, 1]\n elif rotation_axis == \"xz\":\n dims = [0, 2]\n elif rotation_axis == \"yz\":\n dims = [1, 2]\n else:\n raise ValueError\n\n rotation_matrix = np.asarray(\n [[np.cos(degrees), np.sin(degrees)], [-np.sin(degrees), np.cos(degrees)]]\n )\n for i in self._branches_in_view:\n rot = np.dot(rotation_matrix, self.base.xyzr[i][:, dims].T).T\n self.base.xyzr[i][:, dims] = rot\n if update_nodes:\n self.compute_compartment_centers()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.scope","title":"scope(scope)
","text":"Return a View of the module with the specified scope.
For example cell.scope(\"global\").branch(2).scope(\"local\").comp(1)
will return the 1st compartment of branch 2.
Parameters:
Name Type Description Defaultscope
str
either \u201cglobal\u201d or \u201clocal\u201d.
requiredReturns:
Type DescriptionView
View with the specified scope.
Source code injaxley/modules/base.py
def scope(self, scope: str) -> View:\n \"\"\"Return a View of the module with the specified scope.\n\n For example `cell.scope(\"global\").branch(2).scope(\"local\").comp(1)`\n will return the 1st compartment of branch 2.\n\n Args:\n scope: either \"global\" or \"local\".\n\n Returns:\n View with the specified scope.\"\"\"\n view = self.view\n view.set_scope(scope)\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.select","title":"select(nodes=None, edges=None, sorted=False)
","text":"Return View of the module filtered by specific node or edges indices.
Parameters:
Name Type Description Defaultnodes
ndarray
indices of nodes to view. If None, all nodes are viewed.
None
edges
ndarray
indices of edges to view. If None, all edges are viewed.
None
sorted
bool
if True, nodes and edges are sorted.
False
Returns:
Type DescriptionView
View for subset of selected nodes and/or edges.
Source code injaxley/modules/base.py
def select(\n self, nodes: np.ndarray = None, edges: np.ndarray = None, sorted: bool = False\n) -> View:\n \"\"\"Return View of the module filtered by specific node or edges indices.\n\n Args:\n nodes: indices of nodes to view. If None, all nodes are viewed.\n edges: indices of edges to view. If None, all edges are viewed.\n sorted: if True, nodes and edges are sorted.\n\n Returns:\n View for subset of selected nodes and/or edges.\"\"\"\n\n nodes = self._reformat_index(nodes) if nodes is not None else None\n nodes = self._nodes_in_view if is_str_all(nodes) else nodes\n nodes = np.sort(nodes) if sorted else nodes\n\n edges = self._reformat_index(edges) if edges is not None else None\n edges = self._edges_in_view if is_str_all(edges) else edges\n edges = np.sort(edges) if sorted else edges\n\n view = View(self, nodes, edges)\n view._set_controlled_by_param(\"filter\")\n return view\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set","title":"set(key, val)
","text":"Set parameter of module (or its view) to a new value.
Note that this function can not be called within jax.jit
or jax.grad
. Instead, it should be used set the parameters of the module before the simulation. Use .data_set()
to set parameters during jax.jit
or jax.grad
.
Parameters:
Name Type Description Defaultkey
str
The name of the parameter to set.
requiredval
Union[float, ndarray]
The value to set the parameter to. If it is jnp.ndarray
then it must be of shape (len(num_compartments))
.
jaxley/modules/base.py
def set(self, key: str, val: Union[float, jnp.ndarray]):\n \"\"\"Set parameter of module (or its view) to a new value.\n\n Note that this function can not be called within `jax.jit` or `jax.grad`.\n Instead, it should be used set the parameters of the module **before** the\n simulation. Use `.data_set()` to set parameters during `jax.jit` or\n `jax.grad`.\n\n Args:\n key: The name of the parameter to set.\n val: The value to set the parameter to. If it is `jnp.ndarray` then it\n must be of shape `(len(num_compartments))`.\n \"\"\"\n if key in self.nodes.columns:\n not_nan = ~self.nodes[key].isna().to_numpy()\n self.base.nodes.loc[self._nodes_in_view[not_nan], key] = val\n elif key in self.edges.columns:\n not_nan = ~self.edges[key].isna().to_numpy()\n self.base.edges.loc[self._edges_in_view[not_nan], key] = val\n else:\n raise KeyError(f\"Key '{key}' not found in nodes or edges\")\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set_ncomp","title":"set_ncomp(ncomp, min_radius=None)
","text":"Set the number of compartments with which the branch is discretized.
Parameters:
Name Type Description Defaultncomp
int
The number of compartments that the branch should be discretized into.
requiredmin_radius
Optional[float]
Only used if the morphology was read from an SWC file. If passed the radius is capped to be at least this value.
None
Source code in jaxley/modules/base.py
def set_ncomp(\n self,\n ncomp: int,\n min_radius: Optional[float] = None,\n):\n \"\"\"Set the number of compartments with which the branch is discretized.\n\n Args:\n ncomp: The number of compartments that the branch should be discretized\n into.\n min_radius: Only used if the morphology was read from an SWC file. If passed\n the radius is capped to be at least this value.\n\n Raises:\n - When there are stimuli in any compartment in the module.\n - When there are recordings in any compartment in the module.\n - When the channels of the compartments are not the same within the branch\n that is modified.\n - When the lengths of the compartments are not the same within the branch\n that is modified.\n - Unless the morphology was read from an SWC file, when the radiuses of the\n compartments are not the same within the branch that is modified.\n \"\"\"\n assert len(self.base.externals) == 0, \"No stimuli allowed!\"\n assert len(self.base.recordings) == 0, \"No recordings allowed!\"\n assert len(self.base.trainable_params) == 0, \"No trainables allowed!\"\n\n assert self.base._module_type != \"network\", \"This is not allowed for networks.\"\n assert not (\n self.base._module_type == \"cell\"\n and len(self._branches_in_view) == len(self.base._branches_in_view)\n ), \"This is not allowed for cells.\"\n\n # Update all attributes that are affected by compartment structure.\n view = self.nodes.copy()\n all_nodes = self.base.nodes\n start_idx = self.nodes[\"global_comp_index\"].to_numpy()[0]\n ncomp_per_branch = self.base.ncomp_per_branch\n channel_names = [c._name for c in self.base.channels]\n channel_param_names = list(\n chain(*[c.channel_params for c in self.base.channels])\n )\n channel_state_names = list(\n chain(*[c.channel_states for c in self.base.channels])\n )\n radius_generating_fns = self.base._radius_generating_fns\n\n within_branch_radiuses = view[\"radius\"].to_numpy()\n compartment_lengths = view[\"length\"].to_numpy()\n num_previous_ncomp = len(within_branch_radiuses)\n branch_indices = pd.unique(view[\"global_branch_index\"])\n\n error_msg = lambda name: (\n f\"You previously modified the {name} of individual compartments, but \"\n f\"now you are modifying the number of compartments in this branch. \"\n f\"This is not allowed. First build the morphology with `set_ncomp()` and \"\n f\"then modify the radiuses and lengths of compartments.\"\n )\n\n if (\n ~np.all(within_branch_radiuses == within_branch_radiuses[0])\n and radius_generating_fns is None\n ):\n raise ValueError(error_msg(\"radius\"))\n\n for property_name in [\"length\", \"capacitance\", \"axial_resistivity\"]:\n compartment_properties = view[property_name].to_numpy()\n if ~np.all(compartment_properties == compartment_properties[0]):\n raise ValueError(error_msg(property_name))\n\n if not (self.nodes[channel_names].var() == 0.0).all():\n raise ValueError(\n \"Some channel exists only in some compartments of the branch which you\"\n \"are trying to modify. This is not allowed. First specify the number\"\n \"of compartments with `.set_ncomp()` and then insert the channels\"\n \"accordingly.\"\n )\n\n if not (\n self.nodes[channel_param_names + channel_state_names].var() == 0.0\n ).all():\n raise ValueError(\n \"Some channel has different parameters or states between the \"\n \"different compartments of the branch which you are trying to modify. \"\n \"This is not allowed. First specify the number of compartments with \"\n \"`.set_ncomp()` and then insert the channels accordingly.\"\n )\n\n # Add new rows as the average of all rows. Special case for the length is below.\n average_row = self.nodes.mean(skipna=False)\n average_row = average_row.to_frame().T\n view = pd.concat([*[average_row] * ncomp], axis=\"rows\")\n\n # Set the correct datatype after having performed an average which cast\n # everything to float.\n integer_cols = [\"global_cell_index\", \"global_branch_index\", \"global_comp_index\"]\n view[integer_cols] = view[integer_cols].astype(int)\n\n # Whether or not a channel exists in a compartment is a boolean.\n boolean_cols = channel_names\n view[boolean_cols] = view[boolean_cols].astype(bool)\n\n # Special treatment for the lengths and radiuses. These are not being set as\n # the average because we:\n # 1) Want to maintain the total length of a branch.\n # 2) Want to use the SWC inferred radius.\n #\n # Compute new compartment lengths.\n comp_lengths = np.sum(compartment_lengths) / ncomp\n view[\"length\"] = comp_lengths\n\n # Compute new compartment radiuses.\n if radius_generating_fns is not None:\n view[\"radius\"] = build_radiuses_from_xyzr(\n radius_fns=radius_generating_fns,\n branch_indices=branch_indices,\n min_radius=min_radius,\n ncomp=ncomp,\n )\n else:\n view[\"radius\"] = within_branch_radiuses[0] * np.ones(ncomp)\n\n # Update `.nodes`.\n # 1) Delete N rows starting from start_idx\n number_deleted = num_previous_ncomp\n all_nodes = all_nodes.drop(index=range(start_idx, start_idx + number_deleted))\n\n # 2) Insert M new rows at the same location\n df1 = all_nodes.iloc[:start_idx] # Rows before the insertion point\n df2 = all_nodes.iloc[start_idx:] # Rows after the insertion point\n\n # 3) Combine the parts: before, new rows, and after\n all_nodes = pd.concat([df1, view, df2]).reset_index(drop=True)\n\n # Override `comp_index` to just be a consecutive list.\n all_nodes[\"global_comp_index\"] = np.arange(len(all_nodes))\n\n # Update compartment structure arguments.\n ncomp_per_branch[branch_indices] = ncomp\n ncomp = int(np.max(ncomp_per_branch))\n cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n internal_node_inds = np.arange(cumsum_ncomp[-1])\n\n self.base.nodes = all_nodes\n self.base.ncomp_per_branch = ncomp_per_branch\n self.base.ncomp = ncomp\n self.base.cumsum_ncomp = cumsum_ncomp\n self.base._internal_node_inds = internal_node_inds\n\n # Update the morphology indexing (e.g., `.comp_edges`).\n self.base._initialize()\n self.base._init_view()\n self.base._update_local_indices()\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.set_scope","title":"set_scope(scope)
","text":"Toggle between \u201cglobal\u201d or \u201clocal\u201d scope.
Determines if global or local indices are used for viewing the module.
Parameters:
Name Type Description Defaultscope
str
either \u201cglobal\u201d or \u201clocal\u201d.
required Source code injaxley/modules/base.py
def set_scope(self, scope: str):\n \"\"\"Toggle between \"global\" or \"local\" scope.\n\n Determines if global or local indices are used for viewing the module.\n\n Args:\n scope: either \"global\" or \"local\".\"\"\"\n assert scope in [\"global\", \"local\"], \"Invalid scope.\"\n self._scope = scope\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.show","title":"show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)
","text":"Print detailed information about the Module or a view of it.
Parameters:
Name Type Description Defaultparam_names
Optional[Union[str, List[str]]]
The names of the parameters to show. If None
, all parameters are shown.
None
indices
bool
Whether to show the indices of the compartments.
True
params
bool
Whether to show the parameters of the compartments.
True
states
bool
Whether to show the states of the compartments.
True
channel_names
Optional[List[str]]
The names of the channels to show. If None
, all channels are shown.
None
Returns:
Type DescriptionDataFrame
A pd.DataFrame
with the requested information.
jaxley/modules/base.py
def show(\n self,\n param_names: Optional[Union[str, List[str]]] = None,\n *,\n indices: bool = True,\n params: bool = True,\n states: bool = True,\n channel_names: Optional[List[str]] = None,\n) -> pd.DataFrame:\n \"\"\"Print detailed information about the Module or a view of it.\n\n Args:\n param_names: The names of the parameters to show. If `None`, all parameters\n are shown.\n indices: Whether to show the indices of the compartments.\n params: Whether to show the parameters of the compartments.\n states: Whether to show the states of the compartments.\n channel_names: The names of the channels to show. If `None`, all channels are\n shown.\n\n Returns:\n A `pd.DataFrame` with the requested information.\n \"\"\"\n nodes = self.nodes.copy() # prevents this from being edited\n\n cols = []\n inds = [\"comp_index\", \"branch_index\", \"cell_index\"]\n scopes = [\"local\", \"global\"]\n inds = [f\"{s}_{i}\" for i in inds for s in scopes] if indices else []\n cols += inds\n cols += [ch._name for ch in self.channels] if channel_names else []\n cols += (\n sum([list(ch.channel_params) for ch in self.channels], []) if params else []\n )\n cols += (\n sum([list(ch.channel_states) for ch in self.channels], []) if states else []\n )\n\n if not param_names is None:\n cols = (\n inds + [c for c in cols if c in param_names]\n if params\n else list(param_names)\n )\n\n return nodes[cols]\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.step","title":"step(u, delta_t, external_inds, externals, params, solver='bwd_euler', voltage_solver='jaxley.stone')
","text":"One step of solving the Ordinary Differential Equation.
This function is called inside of integrate
and increments the state of the module by one time step. Calls _step_channels
and _step_synapse
to update the states of the channels and synapses using fwd_euler.
Parameters:
Name Type Description Defaultu
Dict[str, ndarray]
The state of the module. voltages = u[\u201cv\u201d]
requireddelta_t
float
The time step.
requiredexternal_inds
Dict[str, ndarray]
The indices of the external inputs.
requiredexternals
Dict[str, ndarray]
The external inputs.
requiredparams
Dict[str, ndarray]
The parameters of the module.
requiredsolver
str
The solver to use for the voltages. Either of [\u201cbwd_euler\u201d, \u201cfwd_euler\u201d, \u201ccrank_nicolson\u201d].
'bwd_euler'
voltage_solver
str
The tridiagonal solver used to diagonalize the coefficient matrix of the ODE system. Either of [\u201cjaxley.thomas\u201d, \u201cjaxley.stone\u201d].
'jaxley.stone'
Returns:
Type DescriptionDict[str, ndarray]
The updated state of the module.
Source code injaxley/modules/base.py
@only_allow_module\ndef step(\n self,\n u: Dict[str, jnp.ndarray],\n delta_t: float,\n external_inds: Dict[str, jnp.ndarray],\n externals: Dict[str, jnp.ndarray],\n params: Dict[str, jnp.ndarray],\n solver: str = \"bwd_euler\",\n voltage_solver: str = \"jaxley.stone\",\n) -> Dict[str, jnp.ndarray]:\n \"\"\"One step of solving the Ordinary Differential Equation.\n\n This function is called inside of `integrate` and increments the state of the\n module by one time step. Calls `_step_channels` and `_step_synapse` to update\n the states of the channels and synapses using fwd_euler.\n\n Args:\n u: The state of the module. voltages = u[\"v\"]\n delta_t: The time step.\n external_inds: The indices of the external inputs.\n externals: The external inputs.\n params: The parameters of the module.\n solver: The solver to use for the voltages. Either of [\"bwd_euler\",\n \"fwd_euler\", \"crank_nicolson\"].\n voltage_solver: The tridiagonal solver used to diagonalize the\n coefficient matrix of the ODE system. Either of [\"jaxley.thomas\",\n \"jaxley.stone\"].\n\n Returns:\n The updated state of the module.\n \"\"\"\n\n # Extract the voltages\n voltages = u[\"v\"]\n\n # Extract the external inputs\n if \"i\" in externals.keys():\n i_current = externals[\"i\"]\n i_inds = external_inds[\"i\"]\n i_ext = self._get_external_input(\n voltages, i_inds, i_current, params[\"radius\"], params[\"length\"]\n )\n else:\n i_ext = 0.0\n\n # Step of the channels.\n u, (v_terms, const_terms) = self._step_channels(\n u, delta_t, self.channels, self.nodes, params\n )\n\n # Step of the synapse.\n u, (syn_v_terms, syn_const_terms) = self._step_synapse(\n u,\n self.synapses,\n params,\n delta_t,\n self.edges,\n )\n\n # Clamp for channels and synapses.\n for key in externals.keys():\n if key not in [\"i\", \"v\"]:\n u[key] = u[key].at[external_inds[key]].set(externals[key])\n\n # Voltage steps.\n cm = params[\"capacitance\"] # Abbreviation.\n\n # Arguments used by all solvers.\n solver_kwargs = {\n \"voltages\": voltages,\n \"voltage_terms\": (v_terms + syn_v_terms) / cm,\n \"constant_terms\": (const_terms + i_ext + syn_const_terms) / cm,\n \"axial_conductances\": params[\"axial_conductances\"],\n \"internal_node_inds\": self._internal_node_inds,\n }\n\n # Add solver specific arguments.\n if voltage_solver == \"jax.sparse\":\n solver_kwargs.update(\n {\n \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n \"data_inds\": self._data_inds,\n \"indices\": self._indices_jax_spsolve,\n \"indptr\": self._indptr_jax_spsolve,\n \"n_nodes\": self._n_nodes,\n }\n )\n # Only for `bwd_euler` and `cranck-nicolson`.\n step_voltage_implicit = step_voltage_implicit_with_jax_spsolve\n else:\n # Our custom sparse solver requires a different format of all conductance\n # values to perform triangulation and backsubstution optimally.\n #\n # Currently, the forward Euler solver also uses this format. However,\n # this is only for historical reasons and we are planning to change this in\n # the future.\n solver_kwargs.update(\n {\n \"sinks\": np.asarray(self._comp_edges[\"sink\"].to_list()),\n \"sources\": np.asarray(self._comp_edges[\"source\"].to_list()),\n \"types\": np.asarray(self._comp_edges[\"type\"].to_list()),\n \"ncomp_per_branch\": self.ncomp_per_branch,\n \"par_inds\": self._par_inds,\n \"child_inds\": self._child_inds,\n \"nbranches\": self.total_nbranches,\n \"solver\": voltage_solver,\n \"idx\": self._solve_indexer,\n \"debug_states\": self.debug_states,\n }\n )\n # Only for `bwd_euler` and `cranck-nicolson`.\n step_voltage_implicit = step_voltage_implicit_with_jaxley_spsolve\n\n if solver == \"bwd_euler\":\n u[\"v\"] = step_voltage_implicit(**solver_kwargs, delta_t=delta_t)\n elif solver == \"crank_nicolson\":\n # Crank-Nicolson advances by half a step of backward and half a step of\n # forward Euler.\n half_step_delta_t = delta_t / 2\n half_step_voltages = step_voltage_implicit(\n **solver_kwargs, delta_t=half_step_delta_t\n )\n # The forward Euler step in Crank-Nicolson can be performed easily as\n # `V_{n+1} = 2 * V_{n+1/2} - V_n`. See also NEURON book Chapter 4.\n u[\"v\"] = 2 * half_step_voltages - voltages\n elif solver == \"fwd_euler\":\n u[\"v\"] = step_voltage_explicit(**solver_kwargs, delta_t=delta_t)\n else:\n raise ValueError(\n f\"You specified `solver={solver}`. The only allowed solvers are \"\n \"['bwd_euler', 'fwd_euler', 'crank_nicolson'].\"\n )\n\n # Clamp for voltages.\n if \"v\" in externals.keys():\n u[\"v\"] = u[\"v\"].at[external_inds[\"v\"]].set(externals[\"v\"])\n\n return u\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.stimulate","title":"stimulate(current=None, verbose=True)
","text":"Insert a stimulus into the compartment.
current must be a 1d array or have batch dimension of size (num_compartments, )
or (1, )
. If 1d, the same stimulus is added to all compartments.
This function cannot be run during jax.jit
and jax.grad
. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate()
.
Parameters:
Name Type Description Defaultcurrent
Optional[ndarray]
Current in nA
.
None
Source code in jaxley/modules/base.py
def stimulate(self, current: Optional[jnp.ndarray] = None, verbose: bool = True):\n \"\"\"Insert a stimulus into the compartment.\n\n current must be a 1d array or have batch dimension of size `(num_compartments, )`\n or `(1, )`. If 1d, the same stimulus is added to all compartments.\n\n This function cannot be run during `jax.jit` and `jax.grad`. Because of this,\n it should only be used for static stimuli (i.e., stimuli that do not depend\n on the data and that should not be learned). For stimuli that depend on data\n (or that should be learned), please use `data_stimulate()`.\n\n Args:\n current: Current in `nA`.\n \"\"\"\n self._external_input(\"i\", current, verbose=verbose)\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.to_jax","title":"to_jax()
","text":"Move .nodes
to .jaxnodes
.
Before the actual simulation is run (via jx.integrate
), all parameters of the jx.Module
are stored in .nodes
(a pd.DataFrame
). However, for simulation, these parameters have to be moved to be jnp.ndarrays
such that they can be processed on GPU/TPU and such that the simulation can be differentiated. .to_jax()
copies the .nodes
to .jaxnodes
.
jaxley/modules/base.py
@only_allow_module\ndef to_jax(self):\n # TODO FROM #447: Make this work for View?\n \"\"\"Move `.nodes` to `.jaxnodes`.\n\n Before the actual simulation is run (via `jx.integrate`), all parameters of\n the `jx.Module` are stored in `.nodes` (a `pd.DataFrame`). However, for\n simulation, these parameters have to be moved to be `jnp.ndarrays` such that\n they can be processed on GPU/TPU and such that the simulation can be\n differentiated. `.to_jax()` copies the `.nodes` to `.jaxnodes`.\n \"\"\"\n self.base.jaxnodes = {}\n for key, value in self.base.nodes.to_dict(orient=\"list\").items():\n inds = jnp.arange(len(value))\n self.base.jaxnodes[key] = jnp.asarray(value)[inds]\n\n # `jaxedges` contains only parameters (no indices).\n # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n # we allow parameter sharing.\n self.base.jaxedges = {}\n edges = self.base.edges.to_dict(orient=\"list\")\n for i, synapse in enumerate(self.base.synapses):\n condition = np.asarray(edges[\"type_ind\"]) == i\n for key in synapse.synapse_params:\n self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n for key in synapse.synapse_states:\n self.base.jaxedges[key] = jnp.asarray(np.asarray(edges[key])[condition])\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.vis","title":"vis(ax=None, col='k', dims=(0, 1), type='line', morph_plot_kwargs={})
","text":"Visualize the module.
Modules can be visualized on one of the cardinal planes (xy, xz, yz) or even in 3D.
Several options are available: - line
: All points from the traced morphology (xyzr
), are connected with a line plot. - scatter
: All traced points, are plotted as scatter points. - comp
: Plots the compartmentalized morphology, including radius and shape. (shows the true compartment lengths per default, but this can be changed via the morph_plot_kwargs
, for details see jaxley.utils.plot_utils.plot_comps
). - morph
: Reconstructs the 3D shape of the traced morphology. For details see jaxley.utils.plot_utils.plot_morph
. Warning: For 3D plots and morphologies with many traced points this can be very slow.
Parameters:
Name Type Description Defaultax
Optional[Axes]
An axis into which to plot.
None
col
str
The color for all branches.
'k'
dims
Tuple[int]
Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.
(0, 1)
type
str
The type of plot. One of [\u201cline\u201d, \u201cscatter\u201d, \u201ccomp\u201d, \u201cmorph\u201d].
'line'
morph_plot_kwargs
Dict
Keyword arguments passed to the plotting function.
{}
Source code in jaxley/modules/base.py
def vis(\n self,\n ax: Optional[Axes] = None,\n col: str = \"k\",\n dims: Tuple[int] = (0, 1),\n type: str = \"line\",\n morph_plot_kwargs: Dict = {},\n) -> Axes:\n \"\"\"Visualize the module.\n\n Modules can be visualized on one of the cardinal planes (xy, xz, yz) or\n even in 3D.\n\n Several options are available:\n - `line`: All points from the traced morphology (`xyzr`), are connected\n with a line plot.\n - `scatter`: All traced points, are plotted as scatter points.\n - `comp`: Plots the compartmentalized morphology, including radius\n and shape. (shows the true compartment lengths per default, but this can\n be changed via the `morph_plot_kwargs`, for details see\n `jaxley.utils.plot_utils.plot_comps`).\n - `morph`: Reconstructs the 3D shape of the traced morphology. For details see\n `jaxley.utils.plot_utils.plot_morph`. Warning: For 3D plots and morphologies\n with many traced points this can be very slow.\n\n Args:\n ax: An axis into which to plot.\n col: The color for all branches.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two of them.\n type: The type of plot. One of [\"line\", \"scatter\", \"comp\", \"morph\"].\n morph_plot_kwargs: Keyword arguments passed to the plotting function.\n \"\"\"\n if \"comp\" in type.lower():\n return plot_comps(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)\n if \"morph\" in type.lower():\n return plot_morph(self, dims=dims, ax=ax, col=col, **morph_plot_kwargs)\n\n assert not np.any(\n [np.isnan(xyzr[:, dims]).all() for xyzr in self.xyzr]\n ), \"No coordinates available. Use `vis(detail='point')` or run `.compute_xyz()` before running `.vis()`.\"\n\n ax = plot_graph(\n self.xyzr,\n dims=dims,\n col=col,\n ax=ax,\n type=type,\n morph_plot_kwargs=morph_plot_kwargs,\n )\n\n return ax\n
"},{"location":"reference/modules/#jaxley.modules.base.Module.write_trainables","title":"write_trainables(trainable_params)
","text":"Write the trainables into .nodes
and .edges
.
This allows to, e.g., visualize trained networks with .vis()
.
Parameters:
Name Type Description Defaulttrainable_params
List[Dict[str, ndarray]]
The trainable parameters returned by get_parameters()
.
jaxley/modules/base.py
def write_trainables(self, trainable_params: List[Dict[str, jnp.ndarray]]):\n \"\"\"Write the trainables into `.nodes` and `.edges`.\n\n This allows to, e.g., visualize trained networks with `.vis()`.\n\n Args:\n trainable_params: The trainable parameters returned by `get_parameters()`.\n \"\"\"\n # We do not support views. Why? `jaxedges` does not have any NaN\n # elements, whereas edges does. Because of this, we already need special\n # treatment to make this function work, and it would be an even bigger hassle\n # if we wanted to support this.\n assert self.__class__.__name__ in [\n \"Compartment\",\n \"Branch\",\n \"Cell\",\n \"Network\",\n ], \"Only supports modules.\"\n\n # We could also implement this without casting the module to jax.\n # However, I think it allows us to reuse as much code as possible and it avoids\n # any kind of issues with indexing or parameter sharing (as this is fully\n # taken care of by `get_all_parameters()`).\n self.base.to_jax()\n pstate = params_to_pstate(trainable_params, self.base.indices_set_by_trainables)\n all_params = self.base.get_all_parameters(pstate, voltage_solver=\"jaxley.stone\")\n\n # The value for `delta_t` does not matter here because it is only used to\n # compute the initial current. However, the initial current cannot be made\n # trainable and so its value never gets used below.\n all_states = self.base.get_all_states(pstate, all_params, delta_t=0.025)\n\n # Loop only over the keys in `pstate` to avoid unnecessary computation.\n for parameter in pstate:\n key = parameter[\"key\"]\n if key in self.base.nodes.columns:\n vals_to_set = all_params if key in all_params.keys() else all_states\n self.base.nodes[key] = vals_to_set[key]\n\n # `jaxedges` contains only non-Nan elements. This is unlike the channels where\n # we allow parameter sharing.\n edges = self.base.edges.to_dict(orient=\"list\")\n for i, synapse in enumerate(self.base.synapses):\n condition = np.asarray(edges[\"type_ind\"]) == i\n for key in list(synapse.synapse_params.keys()):\n self.base.edges.loc[condition, key] = all_params[key]\n for key in list(synapse.synapse_states.keys()):\n self.base.edges.loc[condition, key] = all_states[key]\n
"},{"location":"reference/modules/#compartment","title":"Compartment","text":" Bases: Module
Compartment class.
This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.
Source code injaxley/modules/compartment.py
class Compartment(Module):\n \"\"\"Compartment class.\n\n This class defines a single compartment that can be simulated by itself or\n connected up into branches. It is the basic building block of a neuron model.\n \"\"\"\n\n compartment_params: Dict = {\n \"length\": 10.0, # um\n \"radius\": 1.0, # um\n \"axial_resistivity\": 5_000.0, # ohm cm\n \"capacitance\": 1.0, # uF/cm^2\n }\n compartment_states: Dict = {\"v\": -70.0}\n\n def __init__(self):\n super().__init__()\n\n self.ncomp = 1\n self.ncomp_per_branch = np.asarray([1])\n self.total_nbranches = 1\n self.nbranches_per_cell = [1]\n self._cumsum_nbranches = np.asarray([0, 1])\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n # Setting up the `nodes` for indexing.\n self.nodes = pd.DataFrame(\n dict(global_cell_index=[0], global_branch_index=[0], global_comp_index=[0])\n )\n self._append_params_and_states(self.compartment_params, self.compartment_states)\n self._update_local_indices()\n self._init_view()\n\n # Synapses.\n self.branch_edges = pd.DataFrame(\n dict(parent_branch_index=[], child_branch_index=[])\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n self._internal_node_inds = jnp.asarray([0])\n\n # Initialize the module.\n self._initialize()\n\n # Coordinates.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n def _init_morph_jaxley_spsolve(self):\n self._solve_indexer = JaxleySolveIndexer(\n cumsum_ncomp=self.cumsum_ncomp,\n branchpoint_group_inds=np.asarray([]).astype(int),\n children_in_level=[],\n parents_in_level=[],\n root_inds=np.asarray([0]),\n remapped_node_indices=self._internal_node_inds,\n )\n\n def _init_morph_jax_spsolve(self):\n \"\"\"Initialize morphology for the jax sparse voltage solver.\n\n Explanation of `self._comp_eges['type']`:\n `type == 0`: compartment <--> compartment (within branch)\n `type == 1`: branchpoint --> parent-compartment\n `type == 2`: branchpoint --> child-compartment\n `type == 3`: parent-compartment --> branchpoint\n `type == 4`: child-compartment --> branchpoint\n \"\"\"\n self._comp_edges = pd.DataFrame().from_dict(\n {\"source\": [], \"sink\": [], \"type\": []}\n )\n n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n self._n_nodes = n_nodes\n self._data_inds = data_inds\n self._indices_jax_spsolve = indices\n self._indptr_jax_spsolve = indptr\n
"},{"location":"reference/modules/#branch","title":"Branch","text":" Bases: Module
Branch class.
This class defines a single branch that can be simulated by itself or connected to build a cell. A branch is linear segment of several compartments and can be connected to no, one or more other branches at each end to build more intricate cell morphologies.
Source code injaxley/modules/branch.py
class Branch(Module):\n \"\"\"Branch class.\n\n This class defines a single branch that can be simulated by itself or\n connected to build a cell. A branch is linear segment of several compartments\n and can be connected to no, one or more other branches at each end to build more\n intricate cell morphologies.\n \"\"\"\n\n branch_params: Dict = {}\n branch_states: Dict = {}\n\n @deprecated_kwargs(\"0.6.0\", [\"nseg\"])\n def __init__(\n self,\n compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n ncomp: Optional[int] = None,\n nseg: Optional[int] = None,\n ):\n \"\"\"\n Args:\n compartments: A single compartment or a list of compartments that make up the\n branch.\n ncomp: Number of segments to divide the branch into. If `compartments` is an\n a single compartment, than the compartment is repeated `ncomp` times to\n create the branch.\n \"\"\"\n # Warnings and errors that deal with the change from `nseg` to `ncomp` change\n # in Jaxley v0.5.0.\n if ncomp is not None and nseg is not None:\n raise ValueError(\"You passed `ncomp` and `nseg`. Please pass only `ncomp`.\")\n if ncomp is None and nseg is not None:\n ncomp = nseg\n\n super().__init__()\n assert (\n isinstance(compartments, (Compartment, List)) or compartments is None\n ), \"Only Compartment or List[Compartment] is allowed.\"\n if isinstance(compartments, Compartment):\n assert (\n ncomp is not None\n ), \"If `compartments` is not a list then you have to set `ncomp`.\"\n compartments = Compartment() if compartments is None else compartments\n ncomp = 1 if ncomp is None else ncomp\n\n if isinstance(compartments, Compartment):\n compartment_list = [compartments] * ncomp\n else:\n compartment_list = compartments\n\n self.ncomp = len(compartment_list)\n self.ncomp_per_branch = np.asarray([self.ncomp])\n self.total_nbranches = 1\n self.nbranches_per_cell = [1]\n self._cumsum_nbranches = jnp.asarray([0, 1])\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n # Indexing.\n self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n self._append_params_and_states(self.branch_params, self.branch_states)\n self.nodes[\"global_comp_index\"] = np.arange(self.ncomp).tolist()\n self.nodes[\"global_branch_index\"] = [0] * self.ncomp\n self.nodes[\"global_cell_index\"] = [0] * self.ncomp\n self._update_local_indices()\n self._init_view()\n\n # Channels.\n self._gather_channels_from_constituents(compartment_list)\n\n self.branch_edges = pd.DataFrame(\n dict(parent_branch_index=[], child_branch_index=[])\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n self._internal_node_inds = jnp.arange(self.ncomp)\n\n self._initialize()\n\n # Coordinates.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n\n def _init_morph_jaxley_spsolve(self):\n self._solve_indexer = JaxleySolveIndexer(\n cumsum_ncomp=self.cumsum_ncomp,\n branchpoint_group_inds=np.asarray([]).astype(int),\n remapped_node_indices=self._internal_node_inds,\n children_in_level=[],\n parents_in_level=[],\n root_inds=np.asarray([0]),\n )\n\n def _init_morph_jax_spsolve(self):\n \"\"\"Initialize morphology for the jax sparse voltage solver.\n\n Explanation of `self._comp_eges['type']`:\n `type == 0`: compartment <--> compartment (within branch)\n `type == 1`: branchpoint --> parent-compartment\n `type == 2`: branchpoint --> child-compartment\n `type == 3`: parent-compartment --> branchpoint\n `type == 4`: child-compartment --> branchpoint\n \"\"\"\n self._comp_edges = pd.DataFrame().from_dict(\n {\n \"source\": list(range(self.ncomp - 1)) + list(range(1, self.ncomp)),\n \"sink\": list(range(1, self.ncomp)) + list(range(self.ncomp - 1)),\n }\n )\n self._comp_edges[\"type\"] = 0\n n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n self._n_nodes = n_nodes\n self._data_inds = data_inds\n self._indices_jax_spsolve = indices\n self._indptr_jax_spsolve = indptr\n\n def __len__(self) -> int:\n return self.ncomp\n
"},{"location":"reference/modules/#jaxley.modules.branch.Branch.__init__","title":"__init__(compartments=None, ncomp=None, nseg=None)
","text":"Parameters:
Name Type Description Defaultcompartments
Optional[Union[Compartment, List[Compartment]]]
A single compartment or a list of compartments that make up the branch.
None
ncomp
Optional[int]
Number of segments to divide the branch into. If compartments
is an a single compartment, than the compartment is repeated ncomp
times to create the branch.
None
Source code in jaxley/modules/branch.py
@deprecated_kwargs(\"0.6.0\", [\"nseg\"])\ndef __init__(\n self,\n compartments: Optional[Union[Compartment, List[Compartment]]] = None,\n ncomp: Optional[int] = None,\n nseg: Optional[int] = None,\n):\n \"\"\"\n Args:\n compartments: A single compartment or a list of compartments that make up the\n branch.\n ncomp: Number of segments to divide the branch into. If `compartments` is an\n a single compartment, than the compartment is repeated `ncomp` times to\n create the branch.\n \"\"\"\n # Warnings and errors that deal with the change from `nseg` to `ncomp` change\n # in Jaxley v0.5.0.\n if ncomp is not None and nseg is not None:\n raise ValueError(\"You passed `ncomp` and `nseg`. Please pass only `ncomp`.\")\n if ncomp is None and nseg is not None:\n ncomp = nseg\n\n super().__init__()\n assert (\n isinstance(compartments, (Compartment, List)) or compartments is None\n ), \"Only Compartment or List[Compartment] is allowed.\"\n if isinstance(compartments, Compartment):\n assert (\n ncomp is not None\n ), \"If `compartments` is not a list then you have to set `ncomp`.\"\n compartments = Compartment() if compartments is None else compartments\n ncomp = 1 if ncomp is None else ncomp\n\n if isinstance(compartments, Compartment):\n compartment_list = [compartments] * ncomp\n else:\n compartment_list = compartments\n\n self.ncomp = len(compartment_list)\n self.ncomp_per_branch = np.asarray([self.ncomp])\n self.total_nbranches = 1\n self.nbranches_per_cell = [1]\n self._cumsum_nbranches = jnp.asarray([0, 1])\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n\n # Indexing.\n self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)\n self._append_params_and_states(self.branch_params, self.branch_states)\n self.nodes[\"global_comp_index\"] = np.arange(self.ncomp).tolist()\n self.nodes[\"global_branch_index\"] = [0] * self.ncomp\n self.nodes[\"global_cell_index\"] = [0] * self.ncomp\n self._update_local_indices()\n self._init_view()\n\n # Channels.\n self._gather_channels_from_constituents(compartment_list)\n\n self.branch_edges = pd.DataFrame(\n dict(parent_branch_index=[], child_branch_index=[])\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n self._internal_node_inds = jnp.arange(self.ncomp)\n\n self._initialize()\n\n # Coordinates.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4))]\n
"},{"location":"reference/modules/#cell","title":"Cell","text":" Bases: Module
Cell class.
This class defines a single cell that can be simulated by itself or connected with synapses to build a network. A cell is made up of several branches and supports intricate cell morphologies.
Source code injaxley/modules/cell.py
class Cell(Module):\n \"\"\"Cell class.\n\n This class defines a single cell that can be simulated by itself or\n connected with synapses to build a network. A cell is made up of several branches\n and supports intricate cell morphologies.\n \"\"\"\n\n cell_params: Dict = {}\n cell_states: Dict = {}\n\n def __init__(\n self,\n branches: Optional[Union[Branch, List[Branch]]] = None,\n parents: Optional[List[int]] = None,\n xyzr: Optional[List[np.ndarray]] = None,\n ):\n \"\"\"Initialize a cell.\n\n Args:\n branches: A single branch or a list of branches that make up the cell.\n If a single branch is provided, then the branch is repeated `len(parents)`\n times to create the cell.\n parents: The parent branch index for each branch. The first branch has no\n parent and is therefore set to -1.\n xyzr: For every branch, the x, y, and z coordinates and the radius at the\n traced coordinates. Note that this is the full tracing (from SWC), not\n the stick representation coordinates.\n \"\"\"\n super().__init__()\n assert (\n isinstance(branches, (Branch, List)) or branches is None\n ), \"Only Branch or List[Branch] is allowed.\"\n if branches is not None:\n assert (\n parents is not None\n ), \"If `branches` is not a list then you have to set `parents`.\"\n if isinstance(branches, List):\n assert len(parents) == len(\n branches\n ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n branches = Branch() if branches is None else branches\n parents = [-1] if parents is None else parents\n\n if isinstance(branches, Branch):\n branch_list = [branches for _ in range(len(parents))]\n else:\n branch_list = branches\n\n if xyzr is not None:\n assert len(xyzr) == len(parents)\n self.xyzr = xyzr\n else:\n # For every branch (`len(parents)`), we have a start and end point (`2`) and\n # a (x,y,z,r) coordinate for each of them (`4`).\n # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n # (potentially learned) length of every compartment, we only populate\n # self.xyzr at `.vis()`.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n self.total_nbranches = len(branch_list)\n self.nbranches_per_cell = [len(branch_list)]\n self.comb_parents = jnp.asarray(parents)\n self.comb_children = compute_children_indices(self.comb_parents)\n self._cumsum_nbranches = np.asarray([0, len(branch_list)])\n\n # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`\n # is run.\n self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])\n self.ncomp = int(np.max(self.ncomp_per_branch))\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n\n # Build nodes. Has to be changed when `.set_ncomp()` is run.\n self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n self.nodes[\"global_branch_index\"] = np.repeat(\n np.arange(self.total_nbranches), self.ncomp_per_branch\n ).tolist()\n self.nodes[\"global_cell_index\"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()\n self._update_local_indices()\n self._init_view()\n\n # Appending general parameters (radius, length, r_a, cm) and channel parameters,\n # as well as the states (v, and channel states).\n self._append_params_and_states(self.cell_params, self.cell_states)\n\n # Channels.\n self._gather_channels_from_constituents(branch_list)\n\n self.branch_edges = pd.DataFrame(\n dict(\n parent_branch_index=self.comb_parents[1:],\n child_branch_index=np.arange(1, self.total_nbranches),\n )\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n\n self._initialize()\n\n def _init_morph_jaxley_spsolve(self):\n \"\"\"Initialize morphology for the custom sparse solver.\n\n Running this function is only required for custom Jaxley solvers, i.e., for\n `voltage_solver={'jaxley.stone', 'jaxley.thomas'}`. However, because at\n `.__init__()` (when the function is run), we do not yet know which solver the\n user will use. Therefore, we always run this function at `.__init__()`.\n \"\"\"\n children_and_parents = compute_morphology_indices_in_levels(\n len(self._par_inds),\n self._child_belongs_to_branchpoint,\n self._par_inds,\n self._child_inds,\n )\n branchpoint_group_inds = build_branchpoint_group_inds(\n len(self._par_inds),\n self._child_belongs_to_branchpoint,\n self.cumsum_ncomp[-1],\n )\n parents = self.comb_parents\n children_inds = children_and_parents[\"children\"]\n parents_inds = children_and_parents[\"parents\"]\n\n levels = compute_levels(parents)\n children_in_level = compute_children_in_level(levels, children_inds)\n parents_in_level = compute_parents_in_level(\n levels, self._par_inds, parents_inds\n )\n levels_and_ncomp = pd.DataFrame().from_dict(\n {\n \"levels\": levels,\n \"ncomps\": self.ncomp_per_branch,\n }\n )\n levels_and_ncomp[\"max_ncomp_in_level\"] = levels_and_ncomp.groupby(\"levels\")[\n \"ncomps\"\n ].transform(\"max\")\n padded_cumsum_ncomp = cumsum_leading_zero(\n levels_and_ncomp[\"max_ncomp_in_level\"].to_numpy()\n )\n\n # Generate mapping to deal with the masking which allows using the custom\n # sparse solver to deal with different ncomp per branch.\n remapped_node_indices = remap_index_to_masked(\n self._internal_node_inds,\n self.nodes,\n padded_cumsum_ncomp,\n self.ncomp_per_branch,\n )\n self._solve_indexer = JaxleySolveIndexer(\n cumsum_ncomp=padded_cumsum_ncomp,\n branchpoint_group_inds=branchpoint_group_inds,\n children_in_level=children_in_level,\n parents_in_level=parents_in_level,\n root_inds=np.asarray([0]),\n remapped_node_indices=remapped_node_indices,\n )\n\n def _init_morph_jax_spsolve(self):\n \"\"\"For morphology indexing with the `jax.sparse` voltage volver.\n\n Explanation of `self._comp_eges['type']`:\n `type == 0`: compartment <--> compartment (within branch)\n `type == 1`: branchpoint --> parent-compartment\n `type == 2`: branchpoint --> child-compartment\n `type == 3`: parent-compartment --> branchpoint\n `type == 4`: child-compartment --> branchpoint\n\n Running this function is only required for generic sparse solvers, i.e., for\n `voltage_solver='jax.sparse'`.\n \"\"\"\n\n # Edges between compartments within the branches.\n self._comp_edges = pd.concat(\n [\n pd.DataFrame()\n .from_dict(\n {\n \"source\": list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp))\n + list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp)),\n \"sink\": list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp))\n + list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp)),\n }\n )\n .astype(int)\n for ncomp, cumsum_ncomp in zip(self.ncomp_per_branch, self.cumsum_ncomp)\n ]\n )\n self._comp_edges[\"type\"] = 0\n\n # Edges from branchpoints to compartments.\n branchpoint_to_parent_edges = pd.DataFrame().from_dict(\n {\n \"source\": np.arange(len(self._par_inds)) + self.cumsum_ncomp[-1],\n \"sink\": self.cumsum_ncomp[self._par_inds + 1] - 1,\n \"type\": 1,\n }\n )\n branchpoint_to_child_edges = pd.DataFrame().from_dict(\n {\n \"source\": self._child_belongs_to_branchpoint + self.cumsum_ncomp[-1],\n \"sink\": self.cumsum_ncomp[self._child_inds],\n \"type\": 2,\n }\n )\n self._comp_edges = pd.concat(\n [\n self._comp_edges,\n branchpoint_to_parent_edges,\n branchpoint_to_child_edges,\n ],\n ignore_index=True,\n )\n\n # Edges from compartments to branchpoints.\n parent_to_branchpoint_edges = branchpoint_to_parent_edges.rename(\n columns={\"sink\": \"source\", \"source\": \"sink\"}\n )\n parent_to_branchpoint_edges[\"type\"] = 3\n child_to_branchpoint_edges = branchpoint_to_child_edges.rename(\n columns={\"sink\": \"source\", \"source\": \"sink\"}\n )\n child_to_branchpoint_edges[\"type\"] = 4\n\n self._comp_edges = pd.concat(\n [\n self._comp_edges,\n parent_to_branchpoint_edges,\n child_to_branchpoint_edges,\n ],\n ignore_index=True,\n )\n\n n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n self._n_nodes = n_nodes\n self._data_inds = data_inds\n self._indices_jax_spsolve = indices\n self._indptr_jax_spsolve = indptr\n
"},{"location":"reference/modules/#jaxley.modules.cell.Cell.__init__","title":"__init__(branches=None, parents=None, xyzr=None)
","text":"Initialize a cell.
Parameters:
Name Type Description Defaultbranches
Optional[Union[Branch, List[Branch]]]
A single branch or a list of branches that make up the cell. If a single branch is provided, then the branch is repeated len(parents)
times to create the cell.
None
parents
Optional[List[int]]
The parent branch index for each branch. The first branch has no parent and is therefore set to -1.
None
xyzr
Optional[List[ndarray]]
For every branch, the x, y, and z coordinates and the radius at the traced coordinates. Note that this is the full tracing (from SWC), not the stick representation coordinates.
None
Source code in jaxley/modules/cell.py
def __init__(\n self,\n branches: Optional[Union[Branch, List[Branch]]] = None,\n parents: Optional[List[int]] = None,\n xyzr: Optional[List[np.ndarray]] = None,\n):\n \"\"\"Initialize a cell.\n\n Args:\n branches: A single branch or a list of branches that make up the cell.\n If a single branch is provided, then the branch is repeated `len(parents)`\n times to create the cell.\n parents: The parent branch index for each branch. The first branch has no\n parent and is therefore set to -1.\n xyzr: For every branch, the x, y, and z coordinates and the radius at the\n traced coordinates. Note that this is the full tracing (from SWC), not\n the stick representation coordinates.\n \"\"\"\n super().__init__()\n assert (\n isinstance(branches, (Branch, List)) or branches is None\n ), \"Only Branch or List[Branch] is allowed.\"\n if branches is not None:\n assert (\n parents is not None\n ), \"If `branches` is not a list then you have to set `parents`.\"\n if isinstance(branches, List):\n assert len(parents) == len(\n branches\n ), \"Ensure equally many parents, i.e. len(branches) == len(parents).\"\n\n branches = Branch() if branches is None else branches\n parents = [-1] if parents is None else parents\n\n if isinstance(branches, Branch):\n branch_list = [branches for _ in range(len(parents))]\n else:\n branch_list = branches\n\n if xyzr is not None:\n assert len(xyzr) == len(parents)\n self.xyzr = xyzr\n else:\n # For every branch (`len(parents)`), we have a start and end point (`2`) and\n # a (x,y,z,r) coordinate for each of them (`4`).\n # Since `xyzr` is only inspected at `.vis()` and because it depends on the\n # (potentially learned) length of every compartment, we only populate\n # self.xyzr at `.vis()`.\n self.xyzr = [float(\"NaN\") * np.zeros((2, 4)) for _ in range(len(parents))]\n\n self.total_nbranches = len(branch_list)\n self.nbranches_per_cell = [len(branch_list)]\n self.comb_parents = jnp.asarray(parents)\n self.comb_children = compute_children_indices(self.comb_parents)\n self._cumsum_nbranches = np.asarray([0, len(branch_list)])\n\n # Compartment structure. These arguments have to be rebuilt when `.set_ncomp()`\n # is run.\n self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])\n self.ncomp = int(np.max(self.ncomp_per_branch))\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n\n # Build nodes. Has to be changed when `.set_ncomp()` is run.\n self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)\n self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n self.nodes[\"global_branch_index\"] = np.repeat(\n np.arange(self.total_nbranches), self.ncomp_per_branch\n ).tolist()\n self.nodes[\"global_cell_index\"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()\n self._update_local_indices()\n self._init_view()\n\n # Appending general parameters (radius, length, r_a, cm) and channel parameters,\n # as well as the states (v, and channel states).\n self._append_params_and_states(self.cell_params, self.cell_states)\n\n # Channels.\n self._gather_channels_from_constituents(branch_list)\n\n self.branch_edges = pd.DataFrame(\n dict(\n parent_branch_index=self.comb_parents[1:],\n child_branch_index=np.arange(1, self.total_nbranches),\n )\n )\n\n # For morphology indexing.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n\n self._initialize()\n
"},{"location":"reference/modules/#network","title":"Network","text":" Bases: Module
Network class.
This class defines a network of cells that can be connected with synapses.
Source code injaxley/modules/network.py
class Network(Module):\n \"\"\"Network class.\n\n This class defines a network of cells that can be connected with synapses.\n \"\"\"\n\n network_params: Dict = {}\n network_states: Dict = {}\n\n def __init__(\n self,\n cells: List[Cell],\n ):\n \"\"\"Initialize network of cells and synapses.\n\n Args:\n cells: A list of cells that make up the network.\n \"\"\"\n super().__init__()\n for cell in cells:\n self.xyzr += deepcopy(cell.xyzr)\n\n self._cells_list = cells\n self.ncomp_per_branch = np.concatenate(\n [cell.ncomp_per_branch for cell in cells]\n )\n self.ncomp = int(np.max(self.ncomp_per_branch))\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n self._append_params_and_states(self.network_params, self.network_states)\n\n self.nbranches_per_cell = [cell.total_nbranches for cell in cells]\n self.total_nbranches = sum(self.nbranches_per_cell)\n self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)\n\n self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n self.nodes[\"global_branch_index\"] = np.repeat(\n np.arange(self.total_nbranches), self.ncomp_per_branch\n ).tolist()\n self.nodes[\"global_cell_index\"] = list(\n itertools.chain(\n *[[i] * int(cell.cumsum_ncomp[-1]) for i, cell in enumerate(cells)]\n )\n )\n self._update_local_indices()\n self._init_view()\n\n parents = [cell.comb_parents for cell in cells]\n self.comb_parents = jnp.concatenate(\n [p.at[1:].add(self._cumsum_nbranches[i]) for i, p in enumerate(parents)]\n )\n\n # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n # branch, apart from those branches which do not have a parent (i.e.\n # -1 in parents). For every branch, tracks the global index of that branch\n # (`child_branch_index`) and the global index of its parent\n # (`parent_branch_index`).\n self.branch_edges = pd.DataFrame(\n dict(\n parent_branch_index=self.comb_parents[self.comb_parents != -1],\n child_branch_index=np.where(self.comb_parents != -1)[0],\n )\n )\n\n # For morphology indexing of both `jax.sparse` and the custom `jaxley` solvers.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n\n # `nbranchpoints` in each cell == cell._par_inds (because `par_inds` are unique).\n nbranchpoints = jnp.asarray([len(cell._par_inds) for cell in cells])\n self._cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints)\n\n # Channels.\n self._gather_channels_from_constituents(cells)\n\n self._initialize()\n del self._cells_list\n\n def __repr__(self):\n return f\"{type(self).__name__} with {len(self.channels)} different channels and {len(self.synapses)} synapses. Use `.nodes` or `.edges` for details.\"\n\n def _init_morph_jaxley_spsolve(self):\n branchpoint_group_inds = build_branchpoint_group_inds(\n len(self._par_inds),\n self._child_belongs_to_branchpoint,\n self.cumsum_ncomp[-1],\n )\n children_in_level = merge_cells(\n self._cumsum_nbranches,\n self._cumsum_nbranchpoints_per_cell,\n [cell._solve_indexer.children_in_level for cell in self._cells_list],\n exclude_first=False,\n )\n parents_in_level = merge_cells(\n self._cumsum_nbranches,\n self._cumsum_nbranchpoints_per_cell,\n [cell._solve_indexer.parents_in_level for cell in self._cells_list],\n exclude_first=False,\n )\n padded_cumsum_ncomp = cumsum_leading_zero(\n np.concatenate(\n [np.diff(cell._solve_indexer.cumsum_ncomp) for cell in self._cells_list]\n )\n )\n\n # Generate mapping to dealing with the masking which allows using the custom\n # sparse solver to deal with different ncomp per branch.\n remapped_node_indices = remap_index_to_masked(\n self._internal_node_inds,\n self.nodes,\n padded_cumsum_ncomp,\n self.ncomp_per_branch,\n )\n self._solve_indexer = JaxleySolveIndexer(\n cumsum_ncomp=padded_cumsum_ncomp,\n branchpoint_group_inds=branchpoint_group_inds,\n children_in_level=children_in_level,\n parents_in_level=parents_in_level,\n root_inds=self._cumsum_nbranches[:-1],\n remapped_node_indices=remapped_node_indices,\n )\n\n def _init_morph_jax_spsolve(self):\n \"\"\"Initialize the morphology for networks.\n\n The reason that this function is a bit involved for a `Network` is that Jaxley\n considers branchpoint nodes to be at the very end of __all__ nodes (i.e. the\n branchpoints of the first cell are even after the compartments of the second\n cell. The reason for this is that, otherwise, `cumsum_ncomp` becomes tricky).\n\n To achieve this, we first loop over all compartments and append them, and then\n loop over all branchpoints and append those. The code for building the indices\n from the `comp_edges` is identical to `jx.Cell`.\n\n Explanation of `self._comp_eges['type']`:\n `type == 0`: compartment <--> compartment (within branch)\n `type == 1`: branchpoint --> parent-compartment\n `type == 2`: branchpoint --> child-compartment\n `type == 3`: parent-compartment --> branchpoint\n `type == 4`: child-compartment --> branchpoint\n \"\"\"\n self._cumsum_ncomp_per_cell = cumsum_leading_zero(\n jnp.asarray([cell.cumsum_ncomp[-1] for cell in self.cells])\n )\n self._comp_edges = pd.DataFrame()\n\n # Add all the internal nodes.\n for offset, cell in zip(self._cumsum_ncomp_per_cell, self._cells_list):\n condition = cell._comp_edges[\"type\"].to_numpy() == 0\n rows = cell._comp_edges[condition]\n self._comp_edges = pd.concat(\n [self._comp_edges, [offset, offset, 0] + rows], ignore_index=True\n )\n\n # All branchpoint-to-compartment nodes.\n start_branchpoints = self.cumsum_ncomp[-1] # Index of the first branchpoint.\n for offset, offset_branchpoints, cell in zip(\n self._cumsum_ncomp_per_cell,\n self._cumsum_nbranchpoints_per_cell,\n self._cells_list,\n ):\n offset_within_cell = cell.cumsum_ncomp[-1]\n condition = cell._comp_edges[\"type\"].isin([1, 2])\n rows = cell._comp_edges[condition]\n self._comp_edges = pd.concat(\n [\n self._comp_edges,\n [\n start_branchpoints - offset_within_cell + offset_branchpoints,\n offset,\n 0,\n ]\n + rows,\n ],\n ignore_index=True,\n )\n\n # All compartment-to-branchpoint nodes.\n for offset, offset_branchpoints, cell in zip(\n self._cumsum_ncomp_per_cell,\n self._cumsum_nbranchpoints_per_cell,\n self._cells_list,\n ):\n offset_within_cell = cell.cumsum_ncomp[-1]\n condition = cell._comp_edges[\"type\"].isin([3, 4])\n rows = cell._comp_edges[condition]\n self._comp_edges = pd.concat(\n [\n self._comp_edges,\n [\n offset,\n start_branchpoints - offset_within_cell + offset_branchpoints,\n 0,\n ]\n + rows,\n ],\n ignore_index=True,\n )\n\n # Convert comp_edges to the index format required for `jax.sparse` solvers.\n n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)\n self._n_nodes = n_nodes\n self._data_inds = data_inds\n self._indices_jax_spsolve = indices\n self._indptr_jax_spsolve = indptr\n\n def _step_synapse(\n self,\n states: Dict,\n syn_channels: List,\n params: Dict,\n delta_t: float,\n edges: pd.DataFrame,\n ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n \"\"\"Perform one step of the synapses and obtain their currents.\"\"\"\n states = self._step_synapse_state(states, syn_channels, params, delta_t, edges)\n states, current_terms = self._synapse_currents(\n states, syn_channels, params, delta_t, edges\n )\n return states, current_terms\n\n def _step_synapse_state(\n self,\n states: Dict,\n syn_channels: List,\n params: Dict,\n delta_t: float,\n edges: pd.DataFrame,\n ) -> Dict:\n voltages = states[\"v\"]\n\n grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n pre_syn_inds = grouped_syns[\"pre_global_comp_index\"].apply(list)\n post_syn_inds = grouped_syns[\"post_global_comp_index\"].apply(list)\n synapse_names = list(grouped_syns.indices.keys())\n\n for i, synapse_type in enumerate(syn_channels):\n assert (\n synapse_names[i] == synapse_type._name\n ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n synapse_param_names = list(synapse_type.synapse_params.keys())\n synapse_state_names = list(synapse_type.synapse_states.keys())\n\n synapse_params = {}\n for p in synapse_param_names:\n synapse_params[p] = params[p]\n synapse_states = {}\n for s in synapse_state_names:\n synapse_states[s] = states[s]\n\n pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n # State updates.\n states_updated = synapse_type.update_states(\n synapse_states,\n delta_t,\n voltages[pre_inds],\n voltages[post_inds],\n synapse_params,\n )\n\n # Rebuild state.\n for key, val in states_updated.items():\n states[key] = val\n\n return states\n\n def _synapse_currents(\n self,\n states: Dict,\n syn_channels: List,\n params: Dict,\n delta_t: float,\n edges: pd.DataFrame,\n ) -> Tuple[Dict, Tuple[jnp.ndarray, jnp.ndarray]]:\n voltages = states[\"v\"]\n\n grouped_syns = edges.groupby(\"type\", sort=False, group_keys=False)\n pre_syn_inds = grouped_syns[\"pre_global_comp_index\"].apply(list)\n post_syn_inds = grouped_syns[\"post_global_comp_index\"].apply(list)\n synapse_names = list(grouped_syns.indices.keys())\n\n syn_voltage_terms = jnp.zeros_like(voltages)\n syn_constant_terms = jnp.zeros_like(voltages)\n # Run with two different voltages that are `diff` apart to infer the slope and\n # offset.\n diff = 1e-3\n for i, synapse_type in enumerate(syn_channels):\n assert (\n synapse_names[i] == synapse_type._name\n ), \"Mixup in the ordering of synapses. Please create an issue on Github.\"\n synapse_param_names = list(synapse_type.synapse_params.keys())\n synapse_state_names = list(synapse_type.synapse_states.keys())\n\n synapse_params = {}\n for p in synapse_param_names:\n synapse_params[p] = params[p]\n synapse_states = {}\n for s in synapse_state_names:\n synapse_states[s] = states[s]\n\n # Get pre and post indexes of the current synapse type.\n pre_inds = np.asarray(pre_syn_inds[synapse_names[i]])\n post_inds = np.asarray(post_syn_inds[synapse_names[i]])\n\n # Compute slope and offset of the current through every synapse.\n pre_v_and_perturbed = jnp.stack(\n [voltages[pre_inds], voltages[pre_inds] + diff]\n )\n post_v_and_perturbed = jnp.stack(\n [voltages[post_inds], voltages[post_inds] + diff]\n )\n synapse_currents = vmap(\n synapse_type.compute_current, in_axes=(None, 0, 0, None)\n )(\n synapse_states,\n pre_v_and_perturbed,\n post_v_and_perturbed,\n synapse_params,\n )\n synapse_currents_dist = convert_point_process_to_distributed(\n synapse_currents,\n params[\"radius\"][post_inds],\n params[\"length\"][post_inds],\n )\n\n # Split into voltage and constant terms.\n voltage_term = (synapse_currents_dist[1] - synapse_currents_dist[0]) / diff\n constant_term = (\n synapse_currents_dist[0] - voltage_term * voltages[post_inds]\n )\n\n # Gather slope and offset for every postsynaptic compartment.\n gathered_syn_currents = gather_synapes(\n len(voltages),\n post_inds,\n voltage_term,\n constant_term,\n )\n syn_voltage_terms += gathered_syn_currents[0]\n syn_constant_terms -= gathered_syn_currents[1]\n\n # Add the synaptic currents through every compartment as state.\n # `post_syn_currents` is a `jnp.ndarray` of as many elements as there are\n # compartments in the network.\n # `[0]` because we only use the non-perturbed voltage.\n states[f\"{synapse_type._name}_current\"] = synapse_currents[0]\n\n return states, (syn_voltage_terms, syn_constant_terms)\n\n def vis(\n self,\n detail: str = \"full\",\n ax: Optional[Axes] = None,\n col: str = \"k\",\n synapse_col: str = \"b\",\n dims: Tuple[int] = (0, 1),\n type: str = \"line\",\n layers: Optional[List] = None,\n morph_plot_kwargs: Dict = {},\n synapse_plot_kwargs: Dict = {},\n synapse_scatter_kwargs: Dict = {},\n networkx_options: Dict = {},\n layer_kwargs: Dict = {},\n ) -> Axes:\n \"\"\"Visualize the module.\n\n Args:\n detail: Either of [point, full]. `point` visualizes every neuron in the\n network as a dot (and it uses `networkx` to obtain cell positions).\n `full` plots the full morphology of every neuron. It requires that\n `compute_xyz()` has been run and allows for indivual neurons to be\n moved with `.move()`.\n col: The color in which cells are plotted. Only takes effect if\n `detail='full'`.\n type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n synapse_col: The color in which synapses are plotted. Only takes effect if\n `detail='full'`.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two of them.\n layers: Allows to plot the network in layers. Should provide the number of\n neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input\n neurons, 10 hidden layer neurons, and 1 output neuron.\n morph_plot_kwargs: Keyword arguments passed to the plotting function for\n cell morphologies. Only takes effect for `detail='full'`.\n synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n syanpses. Only takes effect for `detail='full'`.\n synapse_scatter_kwargs: Keyword arguments passed to the scatter function\n for the end point of synapses. Only takes effect for `detail='full'`.\n networkx_options: Options passed to `networkx.draw()`. Only takes effect if\n `detail='point'`.\n layer_kwargs: Only used if `layers` is specified and if `detail='full'`.\n Can have the following entries: `within_layer_offset` (float),\n `between_layer_offset` (float), `vertical_layers` (bool).\n \"\"\"\n if detail == \"point\":\n graph = self._build_graph(layers)\n\n if layers is not None:\n pos = nx.multipartite_layout(graph, subset_key=\"layer\")\n nx.draw(graph, pos, with_labels=True, **networkx_options)\n else:\n nx.draw(graph, with_labels=True, **networkx_options)\n elif detail == \"full\":\n if layers is not None:\n # Assemble cells in the network into layers.\n global_counter = 0\n layers_config = {\n \"within_layer_offset\": 500.0,\n \"between_layer_offset\": 1500.0,\n \"vertical_layers\": False,\n }\n layers_config.update(layer_kwargs)\n for layer_ind, num_in_layer in enumerate(layers):\n for ind_within_layer in range(num_in_layer):\n if layers_config[\"vertical_layers\"]:\n x_offset = (\n ind_within_layer - (num_in_layer - 1) / 2\n ) * layers_config[\"within_layer_offset\"]\n y_offset = (len(layers) - 1 - layer_ind) * layers_config[\n \"between_layer_offset\"\n ]\n else:\n x_offset = layer_ind * layers_config[\"between_layer_offset\"]\n y_offset = (\n ind_within_layer - (num_in_layer - 1) / 2\n ) * layers_config[\"within_layer_offset\"]\n\n self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)\n global_counter += 1\n ax = super().vis(\n dims=dims,\n col=col,\n ax=ax,\n type=type,\n morph_plot_kwargs=morph_plot_kwargs,\n )\n\n pre_locs = self.edges[\"pre_locs\"].to_numpy()\n post_locs = self.edges[\"post_locs\"].to_numpy()\n pre_comp = self.edges[\"pre_global_comp_index\"].to_numpy()\n nodes = self.nodes.set_index(\"global_comp_index\")\n pre_branch = nodes.loc[pre_comp, \"global_branch_index\"].to_numpy()\n post_comp = self.edges[\"post_global_comp_index\"].to_numpy()\n post_branch = nodes.loc[post_comp, \"global_branch_index\"].to_numpy()\n\n dims_np = np.asarray(dims)\n\n for pre_loc, post_loc, pre_b, post_b in zip(\n pre_locs, post_locs, pre_branch, post_branch\n ):\n pre_coord = self.xyzr[pre_b]\n if len(pre_coord) == 2:\n # If only start and end point of a branch are traced, perform a\n # linear interpolation to get the synpase location.\n pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc\n else:\n # If densely traced, use intermediate trace values for synapse loc.\n middle_ind = int((len(pre_coord) - 1) * pre_loc)\n pre_coord = pre_coord[middle_ind]\n\n post_coord = self.xyzr[post_b]\n if len(post_coord) == 2:\n # If only start and end point of a branch are traced, perform a\n # linear interpolation to get the synpase location.\n post_coord = (\n post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc\n )\n else:\n # If densely traced, use intermediate trace values for synapse loc.\n middle_ind = int((len(post_coord) - 1) * post_loc)\n post_coord = post_coord[middle_ind]\n\n coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T\n ax.plot(\n coords[0],\n coords[1],\n c=synapse_col,\n **synapse_plot_kwargs,\n )\n ax.scatter(\n post_coord[dims_np[0]],\n post_coord[dims_np[1]],\n c=synapse_col,\n **synapse_scatter_kwargs,\n )\n else:\n raise ValueError(\"detail must be in {full, point}.\")\n\n return ax\n\n def _build_graph(self, layers: Optional[List] = None, **options):\n graph = nx.DiGraph()\n\n def build_extents(*subset_sizes):\n return nx.utils.pairwise(itertools.accumulate((0,) + subset_sizes))\n\n if layers is not None:\n extents = build_extents(*layers)\n layers = [range(start, end) for start, end in extents]\n for i, layer in enumerate(layers):\n graph.add_nodes_from(layer, layer=i)\n else:\n graph.add_nodes_from(range(len(self._cells_in_view)))\n\n pre_comp = self.edges[\"pre_global_comp_index\"].to_numpy()\n nodes = self.nodes.set_index(\"global_comp_index\")\n pre_cell = nodes.loc[pre_comp, \"global_cell_index\"].to_numpy()\n post_comp = self.edges[\"post_global_comp_index\"].to_numpy()\n post_cell = nodes.loc[post_comp, \"global_cell_index\"].to_numpy()\n\n inds = np.stack([pre_cell, post_cell]).T\n graph.add_edges_from(inds)\n\n return graph\n\n def _infer_synapse_type_ind(self, synapse_name):\n syn_names = self.base.synapse_names\n is_new_type = False if synapse_name in syn_names else True\n type_ind = len(syn_names) if is_new_type else syn_names.index(synapse_name)\n return type_ind, is_new_type\n\n def _update_synapse_state_names(self, synapse_type):\n # (Potentially) update variables that track meta information about synapses.\n self.base.synapse_names.append(synapse_type._name)\n self.base.synapse_param_names += list(synapse_type.synapse_params.keys())\n self.base.synapse_state_names += list(synapse_type.synapse_states.keys())\n self.base.synapses.append(synapse_type)\n\n def _append_multiple_synapses(self, pre_nodes, post_nodes, synapse_type):\n # Add synapse types to the module and infer their unique identifier.\n synapse_name = synapse_type._name\n type_ind, is_new = self._infer_synapse_type_ind(synapse_name)\n if is_new: # synapse is not known\n self._update_synapse_state_names(synapse_type)\n\n index = len(self.base.edges)\n indices = [idx for idx in range(index, index + len(pre_nodes))]\n global_edge_index = pd.DataFrame({\"global_edge_index\": indices})\n post_loc = loc_of_index(\n post_nodes[\"global_comp_index\"].to_numpy(),\n post_nodes[\"global_branch_index\"].to_numpy(),\n self.ncomp_per_branch,\n )\n pre_loc = loc_of_index(\n pre_nodes[\"global_comp_index\"].to_numpy(),\n pre_nodes[\"global_branch_index\"].to_numpy(),\n self.ncomp_per_branch,\n )\n\n # Define new synapses. Each row is one synapse.\n pre_nodes = pre_nodes[[\"global_comp_index\"]]\n pre_nodes.columns = [\"pre_global_comp_index\"]\n post_nodes = post_nodes[[\"global_comp_index\"]]\n post_nodes.columns = [\"post_global_comp_index\"]\n new_rows = pd.concat(\n [\n global_edge_index,\n pre_nodes.reset_index(drop=True),\n post_nodes.reset_index(drop=True),\n ],\n axis=1,\n )\n new_rows[\"type\"] = synapse_name\n new_rows[\"type_ind\"] = type_ind\n new_rows[\"pre_locs\"] = pre_loc\n new_rows[\"post_locs\"] = post_loc\n self.base.edges = concat_and_ignore_empty(\n [self.base.edges, new_rows], ignore_index=True, axis=0\n )\n self._add_params_to_edges(synapse_type, indices)\n self.base.edges[\"controlled_by_param\"] = 0\n self._edges_in_view = self.edges.index.to_numpy()\n\n def _add_params_to_edges(self, synapse_type, indices):\n # Add parameters and states to the `.edges` table.\n for key, param_val in synapse_type.synapse_params.items():\n self.base.edges.loc[indices, key] = param_val\n\n # Update synaptic state array.\n for key, state_val in synapse_type.synapse_states.items():\n self.base.edges.loc[indices, key] = state_val\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.__init__","title":"__init__(cells)
","text":"Initialize network of cells and synapses.
Parameters:
Name Type Description Defaultcells
List[Cell]
A list of cells that make up the network.
required Source code injaxley/modules/network.py
def __init__(\n self,\n cells: List[Cell],\n):\n \"\"\"Initialize network of cells and synapses.\n\n Args:\n cells: A list of cells that make up the network.\n \"\"\"\n super().__init__()\n for cell in cells:\n self.xyzr += deepcopy(cell.xyzr)\n\n self._cells_list = cells\n self.ncomp_per_branch = np.concatenate(\n [cell.ncomp_per_branch for cell in cells]\n )\n self.ncomp = int(np.max(self.ncomp_per_branch))\n self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)\n self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])\n self._append_params_and_states(self.network_params, self.network_states)\n\n self.nbranches_per_cell = [cell.total_nbranches for cell in cells]\n self.total_nbranches = sum(self.nbranches_per_cell)\n self._cumsum_nbranches = cumsum_leading_zero(self.nbranches_per_cell)\n\n self.nodes = pd.concat([c.nodes for c in cells], ignore_index=True)\n self.nodes[\"global_comp_index\"] = np.arange(self.cumsum_ncomp[-1])\n self.nodes[\"global_branch_index\"] = np.repeat(\n np.arange(self.total_nbranches), self.ncomp_per_branch\n ).tolist()\n self.nodes[\"global_cell_index\"] = list(\n itertools.chain(\n *[[i] * int(cell.cumsum_ncomp[-1]) for i, cell in enumerate(cells)]\n )\n )\n self._update_local_indices()\n self._init_view()\n\n parents = [cell.comb_parents for cell in cells]\n self.comb_parents = jnp.concatenate(\n [p.at[1:].add(self._cumsum_nbranches[i]) for i, p in enumerate(parents)]\n )\n\n # Two columns: `parent_branch_index` and `child_branch_index`. One row per\n # branch, apart from those branches which do not have a parent (i.e.\n # -1 in parents). For every branch, tracks the global index of that branch\n # (`child_branch_index`) and the global index of its parent\n # (`parent_branch_index`).\n self.branch_edges = pd.DataFrame(\n dict(\n parent_branch_index=self.comb_parents[self.comb_parents != -1],\n child_branch_index=np.where(self.comb_parents != -1)[0],\n )\n )\n\n # For morphology indexing of both `jax.sparse` and the custom `jaxley` solvers.\n self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (\n compute_children_and_parents(self.branch_edges)\n )\n\n # `nbranchpoints` in each cell == cell._par_inds (because `par_inds` are unique).\n nbranchpoints = jnp.asarray([len(cell._par_inds) for cell in cells])\n self._cumsum_nbranchpoints_per_cell = cumsum_leading_zero(nbranchpoints)\n\n # Channels.\n self._gather_channels_from_constituents(cells)\n\n self._initialize()\n del self._cells_list\n
"},{"location":"reference/modules/#jaxley.modules.network.Network.vis","title":"vis(detail='full', ax=None, col='k', synapse_col='b', dims=(0, 1), type='line', layers=None, morph_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, networkx_options={}, layer_kwargs={})
","text":"Visualize the module.
Parameters:
Name Type Description Defaultdetail
str
Either of [point, full]. point
visualizes every neuron in the network as a dot (and it uses networkx
to obtain cell positions). full
plots the full morphology of every neuron. It requires that compute_xyz()
has been run and allows for indivual neurons to be moved with .move()
.
'full'
col
str
The color in which cells are plotted. Only takes effect if detail='full'
.
'k'
type
str
Either line
or scatter
. Only takes effect if detail='full'
.
'line'
synapse_col
str
The color in which synapses are plotted. Only takes effect if detail='full'
.
'b'
dims
Tuple[int]
Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two of them.
(0, 1)
layers
Optional[List]
Allows to plot the network in layers. Should provide the number of neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input neurons, 10 hidden layer neurons, and 1 output neuron.
None
morph_plot_kwargs
Dict
Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail='full'
.
{}
synapse_plot_kwargs
Dict
Keyword arguments passed to the plotting function for syanpses. Only takes effect for detail='full'
.
{}
synapse_scatter_kwargs
Dict
Keyword arguments passed to the scatter function for the end point of synapses. Only takes effect for detail='full'
.
{}
networkx_options
Dict
Options passed to networkx.draw()
. Only takes effect if detail='point'
.
{}
layer_kwargs
Dict
Only used if layers
is specified and if detail='full'
. Can have the following entries: within_layer_offset
(float), between_layer_offset
(float), vertical_layers
(bool).
{}
Source code in jaxley/modules/network.py
def vis(\n self,\n detail: str = \"full\",\n ax: Optional[Axes] = None,\n col: str = \"k\",\n synapse_col: str = \"b\",\n dims: Tuple[int] = (0, 1),\n type: str = \"line\",\n layers: Optional[List] = None,\n morph_plot_kwargs: Dict = {},\n synapse_plot_kwargs: Dict = {},\n synapse_scatter_kwargs: Dict = {},\n networkx_options: Dict = {},\n layer_kwargs: Dict = {},\n) -> Axes:\n \"\"\"Visualize the module.\n\n Args:\n detail: Either of [point, full]. `point` visualizes every neuron in the\n network as a dot (and it uses `networkx` to obtain cell positions).\n `full` plots the full morphology of every neuron. It requires that\n `compute_xyz()` has been run and allows for indivual neurons to be\n moved with `.move()`.\n col: The color in which cells are plotted. Only takes effect if\n `detail='full'`.\n type: Either `line` or `scatter`. Only takes effect if `detail='full'`.\n synapse_col: The color in which synapses are plotted. Only takes effect if\n `detail='full'`.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two of them.\n layers: Allows to plot the network in layers. Should provide the number of\n neurons in each layer, e.g., [5, 10, 1] would be a network with 5 input\n neurons, 10 hidden layer neurons, and 1 output neuron.\n morph_plot_kwargs: Keyword arguments passed to the plotting function for\n cell morphologies. Only takes effect for `detail='full'`.\n synapse_plot_kwargs: Keyword arguments passed to the plotting function for\n syanpses. Only takes effect for `detail='full'`.\n synapse_scatter_kwargs: Keyword arguments passed to the scatter function\n for the end point of synapses. Only takes effect for `detail='full'`.\n networkx_options: Options passed to `networkx.draw()`. Only takes effect if\n `detail='point'`.\n layer_kwargs: Only used if `layers` is specified and if `detail='full'`.\n Can have the following entries: `within_layer_offset` (float),\n `between_layer_offset` (float), `vertical_layers` (bool).\n \"\"\"\n if detail == \"point\":\n graph = self._build_graph(layers)\n\n if layers is not None:\n pos = nx.multipartite_layout(graph, subset_key=\"layer\")\n nx.draw(graph, pos, with_labels=True, **networkx_options)\n else:\n nx.draw(graph, with_labels=True, **networkx_options)\n elif detail == \"full\":\n if layers is not None:\n # Assemble cells in the network into layers.\n global_counter = 0\n layers_config = {\n \"within_layer_offset\": 500.0,\n \"between_layer_offset\": 1500.0,\n \"vertical_layers\": False,\n }\n layers_config.update(layer_kwargs)\n for layer_ind, num_in_layer in enumerate(layers):\n for ind_within_layer in range(num_in_layer):\n if layers_config[\"vertical_layers\"]:\n x_offset = (\n ind_within_layer - (num_in_layer - 1) / 2\n ) * layers_config[\"within_layer_offset\"]\n y_offset = (len(layers) - 1 - layer_ind) * layers_config[\n \"between_layer_offset\"\n ]\n else:\n x_offset = layer_ind * layers_config[\"between_layer_offset\"]\n y_offset = (\n ind_within_layer - (num_in_layer - 1) / 2\n ) * layers_config[\"within_layer_offset\"]\n\n self.cell(global_counter).move_to(x=x_offset, y=y_offset, z=0)\n global_counter += 1\n ax = super().vis(\n dims=dims,\n col=col,\n ax=ax,\n type=type,\n morph_plot_kwargs=morph_plot_kwargs,\n )\n\n pre_locs = self.edges[\"pre_locs\"].to_numpy()\n post_locs = self.edges[\"post_locs\"].to_numpy()\n pre_comp = self.edges[\"pre_global_comp_index\"].to_numpy()\n nodes = self.nodes.set_index(\"global_comp_index\")\n pre_branch = nodes.loc[pre_comp, \"global_branch_index\"].to_numpy()\n post_comp = self.edges[\"post_global_comp_index\"].to_numpy()\n post_branch = nodes.loc[post_comp, \"global_branch_index\"].to_numpy()\n\n dims_np = np.asarray(dims)\n\n for pre_loc, post_loc, pre_b, post_b in zip(\n pre_locs, post_locs, pre_branch, post_branch\n ):\n pre_coord = self.xyzr[pre_b]\n if len(pre_coord) == 2:\n # If only start and end point of a branch are traced, perform a\n # linear interpolation to get the synpase location.\n pre_coord = pre_coord[0] + (pre_coord[1] - pre_coord[0]) * pre_loc\n else:\n # If densely traced, use intermediate trace values for synapse loc.\n middle_ind = int((len(pre_coord) - 1) * pre_loc)\n pre_coord = pre_coord[middle_ind]\n\n post_coord = self.xyzr[post_b]\n if len(post_coord) == 2:\n # If only start and end point of a branch are traced, perform a\n # linear interpolation to get the synpase location.\n post_coord = (\n post_coord[0] + (post_coord[1] - post_coord[0]) * post_loc\n )\n else:\n # If densely traced, use intermediate trace values for synapse loc.\n middle_ind = int((len(post_coord) - 1) * post_loc)\n post_coord = post_coord[middle_ind]\n\n coords = np.stack([pre_coord[dims_np], post_coord[dims_np]]).T\n ax.plot(\n coords[0],\n coords[1],\n c=synapse_col,\n **synapse_plot_kwargs,\n )\n ax.scatter(\n post_coord[dims_np[0]],\n post_coord[dims_np[1]],\n c=synapse_col,\n **synapse_scatter_kwargs,\n )\n else:\n raise ValueError(\"detail must be in {full, point}.\")\n\n return ax\n
"},{"location":"reference/optimize/","title":"Optimization","text":""},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer","title":"TypeOptimizer
","text":"optax
wrapper which allows different argument values for different params.
jaxley/optimize/optimizer.py
class TypeOptimizer:\n \"\"\"`optax` wrapper which allows different argument values for different params.\"\"\"\n\n def __init__(\n self,\n optimizer: Callable,\n optimizer_args: Dict[str, Any],\n opt_params: List[Dict[str, jnp.ndarray]],\n ):\n \"\"\"Create the optimizers.\n\n This requires access to `opt_params` in order to know how many optimizers\n should be created. It creates `len(opt_params)` optimizers.\n\n Example usage:\n ```\n lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n opt_state = optimizer.init(opt_params)\n ```\n\n ```\n optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n optimizer = TypeOptimizer(\n lambda args: optax.sgd(args[0], momentum=args[1]),\n optimizer_args,\n opt_params\n )\n opt_state = optimizer.init(opt_params)\n ```\n\n Args:\n optimizer: A Callable that takes the learning rate and returns the\n `optax.optimizer` which should be used.\n optimizer_args: The arguments for different kinds of parameters.\n Each item of the dictionary will be passed to the `Callable` passed to\n `optimizer`.\n opt_params: The parameters to be optimized. The exact values are not used,\n only the number of elements in the list and the key of each dict.\n \"\"\"\n self.base_optimizer = optimizer\n\n self.optimizers = []\n for params in opt_params:\n names = list(params.keys())\n assert len(names) == 1, \"Multiple parameters were added at once.\"\n name = names[0]\n optimizer = self.base_optimizer(optimizer_args[name])\n self.optimizers.append({name: optimizer})\n\n def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n opt_states = []\n for params, optimizer in zip(opt_params, self.optimizers):\n name = list(optimizer.keys())[0]\n opt_state = optimizer[name].init(params)\n opt_states.append(opt_state)\n return opt_states\n\n def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n all_updates = []\n new_opt_states = []\n for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n name = list(opt.keys())[0]\n updates, new_opt_state = opt[name].update(grad, state)\n all_updates.append(updates)\n new_opt_states.append(new_opt_state)\n return all_updates, new_opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.__init__","title":"__init__(optimizer, optimizer_args, opt_params)
","text":"Create the optimizers.
This requires access to opt_params
in order to know how many optimizers should be created. It creates len(opt_params)
optimizers.
Example usage:
lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\noptimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\nopt_state = optimizer.init(opt_params)\n
optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\noptimizer = TypeOptimizer(\n lambda args: optax.sgd(args[0], momentum=args[1]),\n optimizer_args,\n opt_params\n)\nopt_state = optimizer.init(opt_params)\n
Parameters:
Name Type Description Defaultoptimizer
Callable
A Callable that takes the learning rate and returns the optax.optimizer
which should be used.
optimizer_args
Dict[str, Any]
The arguments for different kinds of parameters. Each item of the dictionary will be passed to the Callable
passed to optimizer
.
opt_params
List[Dict[str, ndarray]]
The parameters to be optimized. The exact values are not used, only the number of elements in the list and the key of each dict.
required Source code injaxley/optimize/optimizer.py
def __init__(\n self,\n optimizer: Callable,\n optimizer_args: Dict[str, Any],\n opt_params: List[Dict[str, jnp.ndarray]],\n):\n \"\"\"Create the optimizers.\n\n This requires access to `opt_params` in order to know how many optimizers\n should be created. It creates `len(opt_params)` optimizers.\n\n Example usage:\n ```\n lrs = {\"HH_gNa\": 0.01, \"radius\": 1.0}\n optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)\n opt_state = optimizer.init(opt_params)\n ```\n\n ```\n optimizer_args = {\"HH_gNa\": [0.01, 0.4], \"radius\": [1.0, 0.8]}\n optimizer = TypeOptimizer(\n lambda args: optax.sgd(args[0], momentum=args[1]),\n optimizer_args,\n opt_params\n )\n opt_state = optimizer.init(opt_params)\n ```\n\n Args:\n optimizer: A Callable that takes the learning rate and returns the\n `optax.optimizer` which should be used.\n optimizer_args: The arguments for different kinds of parameters.\n Each item of the dictionary will be passed to the `Callable` passed to\n `optimizer`.\n opt_params: The parameters to be optimized. The exact values are not used,\n only the number of elements in the list and the key of each dict.\n \"\"\"\n self.base_optimizer = optimizer\n\n self.optimizers = []\n for params in opt_params:\n names = list(params.keys())\n assert len(names) == 1, \"Multiple parameters were added at once.\"\n name = names[0]\n optimizer = self.base_optimizer(optimizer_args[name])\n self.optimizers.append({name: optimizer})\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.init","title":"init(opt_params)
","text":"Initialize the optimizers. Equivalent to optax.optimizers.init()
.
jaxley/optimize/optimizer.py
def init(self, opt_params: List[Dict[str, jnp.ndarray]]) -> List:\n \"\"\"Initialize the optimizers. Equivalent to `optax.optimizers.init()`.\"\"\"\n opt_states = []\n for params, optimizer in zip(opt_params, self.optimizers):\n name = list(optimizer.keys())[0]\n opt_state = optimizer[name].init(params)\n opt_states.append(opt_state)\n return opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.optimizer.TypeOptimizer.update","title":"update(gradient, opt_state)
","text":"Update the optimizers. Equivalent to optax.optimizers.update()
.
jaxley/optimize/optimizer.py
def update(self, gradient: jnp.ndarray, opt_state: List) -> Tuple[List, List]:\n \"\"\"Update the optimizers. Equivalent to `optax.optimizers.update()`.\"\"\"\n all_updates = []\n new_opt_states = []\n for grad, state, opt in zip(gradient, opt_state, self.optimizers):\n name = list(opt.keys())[0]\n updates, new_opt_state = opt[name].update(grad, state)\n all_updates.append(updates)\n new_opt_states.append(new_opt_state)\n return all_updates, new_opt_states\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.AffineTransform","title":"AffineTransform
","text":" Bases: Transform
jaxley/optimize/transforms.py
class AffineTransform(Transform):\n def __init__(self, scale: ArrayLike, shift: ArrayLike):\n \"\"\"This transform rescales and shifts the input.\n\n Args:\n scale (ArrayLike): Scaling factor.\n shift (ArrayLike): Additive shift.\n\n Raises:\n ValueError: Scale needs to be larger than 0\n \"\"\"\n if jnp.allclose(scale, 0):\n raise ValueError(\"a cannot be zero, must be invertible\")\n self.a = scale\n self.b = shift\n\n def forward(self, x: ArrayLike) -> Array:\n return self.a * x + self.b\n\n def inverse(self, x: ArrayLike) -> Array:\n return (x - self.b) / self.a\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.AffineTransform.__init__","title":"__init__(scale, shift)
","text":"This transform rescales and shifts the input.
Parameters:
Name Type Description Defaultscale
ArrayLike
Scaling factor.
requiredshift
ArrayLike
Additive shift.
requiredRaises:
Type DescriptionValueError
Scale needs to be larger than 0
Source code injaxley/optimize/transforms.py
def __init__(self, scale: ArrayLike, shift: ArrayLike):\n \"\"\"This transform rescales and shifts the input.\n\n Args:\n scale (ArrayLike): Scaling factor.\n shift (ArrayLike): Additive shift.\n\n Raises:\n ValueError: Scale needs to be larger than 0\n \"\"\"\n if jnp.allclose(scale, 0):\n raise ValueError(\"a cannot be zero, must be invertible\")\n self.a = scale\n self.b = shift\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ChainTransform","title":"ChainTransform
","text":" Bases: Transform
Chaining together multiple transformations
Source code injaxley/optimize/transforms.py
class ChainTransform(Transform):\n \"\"\"Chaining together multiple transformations\"\"\"\n\n def __init__(self, transforms: Sequence[Transform]) -> None:\n \"\"\"A chain of transformations\n\n Args:\n transforms (Sequence[Transform]): Transforms to apply\n \"\"\"\n super().__init__()\n self.transforms = transforms\n\n def forward(self, x: ArrayLike) -> Array:\n for transform in self.transforms:\n x = transform(x)\n return x\n\n def inverse(self, y: ArrayLike) -> Array:\n for transform in reversed(self.transforms):\n y = transform.inverse(y)\n return y\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ChainTransform.__init__","title":"__init__(transforms)
","text":"A chain of transformations
Parameters:
Name Type Description Defaulttransforms
Sequence[Transform]
Transforms to apply
required Source code injaxley/optimize/transforms.py
def __init__(self, transforms: Sequence[Transform]) -> None:\n \"\"\"A chain of transformations\n\n Args:\n transforms (Sequence[Transform]): Transforms to apply\n \"\"\"\n super().__init__()\n self.transforms = transforms\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.CustomTransform","title":"CustomTransform
","text":" Bases: Transform
Custom transformation
Source code injaxley/optimize/transforms.py
class CustomTransform(Transform):\n \"\"\"Custom transformation\"\"\"\n\n def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:\n \"\"\"A custom transformation using a user-defined froward and\n inverse function\n\n Args:\n forward_fn (Callable): Forward transformation\n inverse_fn (Callable): Inverse transformation\n \"\"\"\n super().__init__()\n self.forward_fn = forward_fn\n self.inverse_fn = inverse_fn\n\n def forward(self, x: ArrayLike) -> Array:\n return self.forward_fn(x)\n\n def inverse(self, y: ArrayLike) -> Array:\n return self.inverse_fn(y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.CustomTransform.__init__","title":"__init__(forward_fn, inverse_fn)
","text":"A custom transformation using a user-defined froward and inverse function
Parameters:
Name Type Description Defaultforward_fn
Callable
Forward transformation
requiredinverse_fn
Callable
Inverse transformation
required Source code injaxley/optimize/transforms.py
def __init__(self, forward_fn: Callable, inverse_fn: Callable) -> None:\n \"\"\"A custom transformation using a user-defined froward and\n inverse function\n\n Args:\n forward_fn (Callable): Forward transformation\n inverse_fn (Callable): Inverse transformation\n \"\"\"\n super().__init__()\n self.forward_fn = forward_fn\n self.inverse_fn = inverse_fn\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.MaskedTransform","title":"MaskedTransform
","text":" Bases: Transform
jaxley/optimize/transforms.py
class MaskedTransform(Transform):\n def __init__(self, mask: ArrayLike, transform: Transform) -> None:\n \"\"\"A masked transformation\n\n Args:\n mask (ArrayLike): Which elements to transform\n transform (Transform): Transformation to apply\n \"\"\"\n super().__init__()\n self.mask = mask\n self.transform = transform\n\n def forward(self, x: ArrayLike) -> Array:\n return jnp.where(self.mask, self.transform.forward(x), x)\n\n def inverse(self, y: ArrayLike) -> Array:\n return jnp.where(self.mask, self.transform.inverse(y), y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.MaskedTransform.__init__","title":"__init__(mask, transform)
","text":"A masked transformation
Parameters:
Name Type Description Defaultmask
ArrayLike
Which elements to transform
requiredtransform
Transform
Transformation to apply
required Source code injaxley/optimize/transforms.py
def __init__(self, mask: ArrayLike, transform: Transform) -> None:\n \"\"\"A masked transformation\n\n Args:\n mask (ArrayLike): Which elements to transform\n transform (Transform): Transformation to apply\n \"\"\"\n super().__init__()\n self.mask = mask\n self.transform = transform\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.NegSoftplusTransform","title":"NegSoftplusTransform
","text":" Bases: SoftplusTransform
Negative softplus transformation.
Source code injaxley/optimize/transforms.py
class NegSoftplusTransform(SoftplusTransform):\n \"\"\"Negative softplus transformation.\"\"\"\n\n def __init__(self, upper: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval (-inf, upper].\n\n Args:\n upper (ArrayLike): Upper bound of the interval.\n \"\"\"\n super().__init__(upper)\n\n def forward(self, x: ArrayLike) -> Array:\n return -super().forward(-x)\n\n def inverse(self, y: ArrayLike) -> Array:\n return -super().inverse(-y)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.NegSoftplusTransform.__init__","title":"__init__(upper)
","text":"This transform maps any value bijectively to the interval (-inf, upper].
Parameters:
Name Type Description Defaultupper
ArrayLike
Upper bound of the interval.
required Source code injaxley/optimize/transforms.py
def __init__(self, upper: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval (-inf, upper].\n\n Args:\n upper (ArrayLike): Upper bound of the interval.\n \"\"\"\n super().__init__(upper)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform","title":"ParamTransform
","text":"Parameter transformation utility.
This class is used to transform parameters usually from an unconstrained space to a constrained space and back (bacause most biophysical parameter are bounded). The user can specify a PyTree of transforms that are applied to the parameters.
Attributes:
Name Type Descriptiontf_dict
A PyTree of transforms for each parameter.
Source code injaxley/optimize/transforms.py
class ParamTransform:\n \"\"\"Parameter transformation utility.\n\n This class is used to transform parameters usually from an unconstrained space to a constrained space\n and back (bacause most biophysical parameter are bounded). The user can specify a PyTree of transforms\n that are applied to the parameters.\n\n Attributes:\n tf_dict: A PyTree of transforms for each parameter.\n\n \"\"\"\n\n def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:\n \"\"\"Creates a new ParamTransform object.\n\n Args:\n tf_dict: A PyTree of transforms for each parameter.\n \"\"\"\n\n self.tf_dict = tf_dict\n\n def forward(\n self, params: List[Dict[str, ArrayLike]] | ArrayLike\n ) -> Dict[str, Array]:\n \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n Args:\n params: A list of dictionaries (or any PyTree) with unconstrained parameters.\n\n Returns:\n A list of dictionaries (or any PyTree) with transformed parameters.\n\n \"\"\"\n\n return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)\n\n def inverse(\n self, params: List[Dict[str, ArrayLike]] | ArrayLike\n ) -> Dict[str, Array]:\n \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n Args:\n params: A list of dictionaries (or any PyTree) with transformed parameters.\n\n Returns:\n A list of dictionaries (or any PyTree) with unconstrained parameters.\n \"\"\"\n\n return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.__init__","title":"__init__(tf_dict)
","text":"Creates a new ParamTransform object.
Parameters:
Name Type Description Defaulttf_dict
List[Dict[str, Transform]] | Transform
A PyTree of transforms for each parameter.
required Source code injaxley/optimize/transforms.py
def __init__(self, tf_dict: List[Dict[str, Transform]] | Transform) -> None:\n \"\"\"Creates a new ParamTransform object.\n\n Args:\n tf_dict: A PyTree of transforms for each parameter.\n \"\"\"\n\n self.tf_dict = tf_dict\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.forward","title":"forward(params)
","text":"Pushes unconstrained parameters through a tf such that they fit the interval.
Parameters:
Name Type Description Defaultparams
List[Dict[str, ArrayLike]] | ArrayLike
A list of dictionaries (or any PyTree) with unconstrained parameters.
requiredReturns:
Type DescriptionDict[str, Array]
A list of dictionaries (or any PyTree) with transformed parameters.
Source code injaxley/optimize/transforms.py
def forward(\n self, params: List[Dict[str, ArrayLike]] | ArrayLike\n) -> Dict[str, Array]:\n \"\"\"Pushes unconstrained parameters through a tf such that they fit the interval.\n\n Args:\n params: A list of dictionaries (or any PyTree) with unconstrained parameters.\n\n Returns:\n A list of dictionaries (or any PyTree) with transformed parameters.\n\n \"\"\"\n\n return jax.tree_util.tree_map(lambda x, tf: tf.forward(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.ParamTransform.inverse","title":"inverse(params)
","text":"Takes parameters from within the interval and makes them unconstrained.
Parameters:
Name Type Description Defaultparams
List[Dict[str, ArrayLike]] | ArrayLike
A list of dictionaries (or any PyTree) with transformed parameters.
requiredReturns:
Type DescriptionDict[str, Array]
A list of dictionaries (or any PyTree) with unconstrained parameters.
Source code injaxley/optimize/transforms.py
def inverse(\n self, params: List[Dict[str, ArrayLike]] | ArrayLike\n) -> Dict[str, Array]:\n \"\"\"Takes parameters from within the interval and makes them unconstrained.\n\n Args:\n params: A list of dictionaries (or any PyTree) with transformed parameters.\n\n Returns:\n A list of dictionaries (or any PyTree) with unconstrained parameters.\n \"\"\"\n\n return jax.tree_util.tree_map(lambda x, tf: tf.inverse(x), params, self.tf_dict)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SigmoidTransform","title":"SigmoidTransform
","text":" Bases: Transform
Sigmoid transformation.
Source code injaxley/optimize/transforms.py
class SigmoidTransform(Transform):\n \"\"\"Sigmoid transformation.\"\"\"\n\n def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval [lower, upper].\n\n Args:\n lower (ArrayLike): Lower bound of the interval.\n upper (ArrayLike): Upper bound of the interval.\n \"\"\"\n super().__init__()\n self.lower = lower\n self.width = upper - lower\n\n def forward(self, x: ArrayLike) -> Array:\n y = 1.0 / (1.0 + save_exp(-x))\n return self.lower + self.width * y\n\n def inverse(self, y: ArrayLike) -> Array:\n x = (y - self.lower) / self.width\n x = -jnp.log((1.0 / x) - 1.0)\n return x\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SigmoidTransform.__init__","title":"__init__(lower, upper)
","text":"This transform maps any value bijectively to the interval [lower, upper].
Parameters:
Name Type Description Defaultlower
ArrayLike
Lower bound of the interval.
requiredupper
ArrayLike
Upper bound of the interval.
required Source code injaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike, upper: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval [lower, upper].\n\n Args:\n lower (ArrayLike): Lower bound of the interval.\n upper (ArrayLike): Upper bound of the interval.\n \"\"\"\n super().__init__()\n self.lower = lower\n self.width = upper - lower\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SoftplusTransform","title":"SoftplusTransform
","text":" Bases: Transform
Softplus transformation.
Source code injaxley/optimize/transforms.py
class SoftplusTransform(Transform):\n \"\"\"Softplus transformation.\"\"\"\n\n def __init__(self, lower: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval [lower, inf).\n\n Args:\n lower (ArrayLike): Lower bound of the interval.\n \"\"\"\n super().__init__()\n self.lower = lower\n\n def forward(self, x: ArrayLike) -> Array:\n return jnp.log1p(save_exp(x)) + self.lower\n\n def inverse(self, y: ArrayLike) -> Array:\n return jnp.log(save_exp(y - self.lower) - 1.0)\n
"},{"location":"reference/optimize/#jaxley.optimize.transforms.SoftplusTransform.__init__","title":"__init__(lower)
","text":"This transform maps any value bijectively to the interval [lower, inf).
Parameters:
Name Type Description Defaultlower
ArrayLike
Lower bound of the interval.
required Source code injaxley/optimize/transforms.py
def __init__(self, lower: ArrayLike) -> None:\n \"\"\"This transform maps any value bijectively to the interval [lower, inf).\n\n Args:\n lower (ArrayLike): Lower bound of the interval.\n \"\"\"\n super().__init__()\n self.lower = lower\n
"},{"location":"reference/utils/","title":"Utils","text":""},{"location":"reference/utils/#jaxley.utils.cell_utils.build_radiuses_from_xyzr","title":"build_radiuses_from_xyzr(radius_fns, branch_indices, min_radius, ncomp)
","text":"Return the radiuses of branches given SWC file xyzr.
Returns an array of shape (num_branches, ncomp)
.
Parameters:
Name Type Description Defaultradius_fns
List[Callable]
Functions which, given compartment locations return the radius.
requiredbranch_indices
List[int]
The indices of the branches for which to return the radiuses.
requiredmin_radius
Optional[float]
If passed, the radiuses are clipped to be at least as large.
requiredncomp
int
The number of compartments that every branch is discretized into.
required Source code injaxley/utils/cell_utils.py
def build_radiuses_from_xyzr(\n radius_fns: List[Callable],\n branch_indices: List[int],\n min_radius: Optional[float],\n ncomp: int,\n) -> jnp.ndarray:\n \"\"\"Return the radiuses of branches given SWC file xyzr.\n\n Returns an array of shape `(num_branches, ncomp)`.\n\n Args:\n radius_fns: Functions which, given compartment locations return the radius.\n branch_indices: The indices of the branches for which to return the radiuses.\n min_radius: If passed, the radiuses are clipped to be at least as large.\n ncomp: The number of compartments that every branch is discretized into.\n \"\"\"\n # Compartment locations are at the center of the internal nodes.\n non_split = 1 / ncomp\n range_ = np.linspace(non_split / 2, 1 - non_split / 2, ncomp)\n\n # Build radiuses.\n radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])\n radiuses_each = radiuses.ravel(order=\"C\")\n if min_radius is None:\n assert np.all(\n radiuses_each > 0.0\n ), \"Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`.\"\n else:\n radiuses_each[radiuses_each < min_radius] = min_radius\n\n return radiuses_each\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_axial_conductances","title":"compute_axial_conductances(comp_edges, params)
","text":"Given comp_edges
, radius, length, r_a, cm, compute the axial conductances.
Note that the resulting axial conductances will already by divided by the capacitance cm
.
jaxley/utils/cell_utils.py
def compute_axial_conductances(\n comp_edges: pd.DataFrame, params: Dict[str, jnp.ndarray]\n) -> jnp.ndarray:\n \"\"\"Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances.\n\n Note that the resulting axial conductances will already by divided by the\n capacitance `cm`.\n \"\"\"\n # `Compartment-to-compartment` (c2c) axial coupling conductances.\n condition = comp_edges[\"type\"].to_numpy() == 0\n source_comp_inds = np.asarray(comp_edges[condition][\"source\"].to_list())\n sink_comp_inds = np.asarray(comp_edges[condition][\"sink\"].to_list())\n\n if len(sink_comp_inds) > 0:\n conds_c2c = (\n vmap(compute_coupling_cond, in_axes=(0, 0, 0, 0, 0, 0))(\n params[\"radius\"][sink_comp_inds],\n params[\"radius\"][source_comp_inds],\n params[\"axial_resistivity\"][sink_comp_inds],\n params[\"axial_resistivity\"][source_comp_inds],\n params[\"length\"][sink_comp_inds],\n params[\"length\"][source_comp_inds],\n )\n / params[\"capacitance\"][sink_comp_inds]\n )\n else:\n conds_c2c = jnp.asarray([])\n\n # `branchpoint-to-compartment` (bp2c) axial coupling conductances.\n condition = comp_edges[\"type\"].isin([1, 2])\n sink_comp_inds = np.asarray(comp_edges[condition][\"sink\"].to_list())\n\n if len(sink_comp_inds) > 0:\n conds_bp2c = (\n vmap(compute_coupling_cond_branchpoint, in_axes=(0, 0, 0))(\n params[\"radius\"][sink_comp_inds],\n params[\"axial_resistivity\"][sink_comp_inds],\n params[\"length\"][sink_comp_inds],\n )\n / params[\"capacitance\"][sink_comp_inds]\n )\n else:\n conds_bp2c = jnp.asarray([])\n\n # `compartment-to-branchpoint` (c2bp) axial coupling conductances.\n condition = comp_edges[\"type\"].isin([3, 4])\n source_comp_inds = np.asarray(comp_edges[condition][\"source\"].to_list())\n\n if len(source_comp_inds) > 0:\n conds_c2bp = vmap(compute_impact_on_node, in_axes=(0, 0, 0))(\n params[\"radius\"][source_comp_inds],\n params[\"axial_resistivity\"][source_comp_inds],\n params[\"length\"][source_comp_inds],\n )\n # For numerical stability. These values are very small, but their scale\n # does not matter.\n conds_c2bp *= 1_000\n else:\n conds_c2bp = jnp.asarray([])\n\n # All axial coupling conductances.\n return jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp])\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_and_parents","title":"compute_children_and_parents(branch_edges)
","text":"Build indices used during `._init_morph_custom_spsolve().
Source code injaxley/utils/cell_utils.py
def compute_children_and_parents(\n branch_edges: pd.DataFrame,\n) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]:\n \"\"\"Build indices used during `._init_morph_custom_spsolve().\"\"\"\n par_inds = branch_edges[\"parent_branch_index\"].to_numpy()\n child_inds = branch_edges[\"child_branch_index\"].to_numpy()\n child_belongs_to_branchpoint = remap_to_consecutive(par_inds)\n par_inds = np.unique(par_inds)\n return par_inds, child_inds, child_belongs_to_branchpoint\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_children_indices","title":"compute_children_indices(parents)
","text":"Return all children indices of every branch.
Example:
parents = [-1, 0, 0]\ncompute_children_indices(parents) -> [[1, 2], [], []]\n
Source code in jaxley/utils/cell_utils.py
def compute_children_indices(parents) -> List[jnp.ndarray]:\n \"\"\"Return all children indices of every branch.\n\n Example:\n ```\n parents = [-1, 0, 0]\n compute_children_indices(parents) -> [[1, 2], [], []]\n ```\n \"\"\"\n num_branches = len(parents)\n child_indices = []\n for b in range(num_branches):\n child_indices.append(np.where(parents == b)[0])\n return child_indices\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond","title":"compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)
","text":"Return the coupling conductance between two compartments.
Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models
.
radius
: um r_a
: ohm cm length_single_compartment
: um coupling_conds
: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2
jaxley/utils/cell_utils.py
def compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2):\n \"\"\"Return the coupling conductance between two compartments.\n\n Equations taken from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.\n\n `radius`: um\n `r_a`: ohm cm\n `length_single_compartment`: um\n `coupling_conds`: S * um / cm / um^2 = S / cm / um -> *10**7 -> mS / cm^2\n \"\"\"\n # Multiply by 10**7 to convert (S / cm / um) -> (mS / cm^2).\n return rad1 * rad2**2 / (r_a1 * rad2**2 * l1 + r_a2 * rad1**2 * l2) / l1 * 10**7\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_coupling_cond_branchpoint","title":"compute_coupling_cond_branchpoint(rad, r_a, l)
","text":"Return the coupling conductance between one compartment and a comp with l=0.
From https://en.wikipedia.org/wiki/Compartmental_neuron_models
If one compartment has l=0.0 then the equations simplify.
R_long = \\sum_i r_a * L_i/2 / crosssection_i
with crosssection = pi * r**2
For a single compartment with L>0, this turns into: R_long = r_a * L/2 / crosssection
Then, g_long = crosssection * 2 / L / r_a
Then, the effective conductance is g_long / zylinder_area. So: g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L g = r / r_a / L**2
Source code injaxley/utils/cell_utils.py
def compute_coupling_cond_branchpoint(rad, r_a, l):\n r\"\"\"Return the coupling conductance between one compartment and a comp with l=0.\n\n From https://en.wikipedia.org/wiki/Compartmental_neuron_models\n\n If one compartment has l=0.0 then the equations simplify.\n\n R_long = \\sum_i r_a * L_i/2 / crosssection_i\n\n with crosssection = pi * r**2\n\n For a single compartment with L>0, this turns into:\n R_long = r_a * L/2 / crosssection\n\n Then, g_long = crosssection * 2 / L / r_a\n\n Then, the effective conductance is g_long / zylinder_area. So:\n g = pi * r**2 * 2 / L / r_a / 2 / pi / r / L\n g = r / r_a / L**2\n \"\"\"\n return rad / r_a / l**2 * 10**7 # Convert (S / cm / um) -> (mS / cm^2)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_impact_on_node","title":"compute_impact_on_node(rad, r_a, l)
","text":"Compute the weight with which a compartment influences its node.
In order to satisfy Kirchhoffs current law, the current at a branch point must be proportional to the crosssection of the compartment. We only require proportionality here because the branch point equation reads: g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0
Because R_long = r_a * L/2 / crosssection, we get g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a
This equation can be multiplied by any constant.
Source code injaxley/utils/cell_utils.py
def compute_impact_on_node(rad, r_a, l):\n r\"\"\"Compute the weight with which a compartment influences its node.\n\n In order to satisfy Kirchhoffs current law, the current at a branch point must be\n proportional to the crosssection of the compartment. We only require proportionality\n here because the branch point equation reads:\n `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`\n\n Because R_long = r_a * L/2 / crosssection, we get\n g_long = crosssection * 2 / L / r_a \\propto rad**2 / L / r_a\n\n This equation can be multiplied by any constant.\"\"\"\n return rad**2 / r_a / l\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.compute_morphology_indices_in_levels","title":"compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)
","text":"Return (row, col) to build the sparse matrix defining the voltage eqs.
This is run at init
, not during runtime.
jaxley/utils/cell_utils.py
def compute_morphology_indices_in_levels(\n num_branchpoints,\n child_belongs_to_branchpoint,\n par_inds,\n child_inds,\n):\n \"\"\"Return (row, col) to build the sparse matrix defining the voltage eqs.\n\n This is run at `init`, not during runtime.\n \"\"\"\n branchpoint_inds_parents = jnp.arange(num_branchpoints)\n branchpoint_inds_children = child_belongs_to_branchpoint\n branch_inds_parents = par_inds\n branch_inds_children = child_inds\n\n children = jnp.stack([branch_inds_children, branchpoint_inds_children])\n parents = jnp.stack([branch_inds_parents, branchpoint_inds_parents])\n\n return {\"children\": children.T, \"parents\": parents.T}\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.convert_point_process_to_distributed","title":"convert_point_process_to_distributed(current, radius, length)
","text":"Convert current point process (nA) to distributed current (uA/cm2).
This function gets called for synapses and for external stimuli.
Parameters:
Name Type Description Defaultcurrent
ndarray
Current in nA
.
radius
ndarray
Compartment radius in um
.
length
ndarray
Compartment length in um
.
Current in uA/cm2
.
jaxley/utils/cell_utils.py
def convert_point_process_to_distributed(\n current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray\n) -> jnp.ndarray:\n \"\"\"Convert current point process (nA) to distributed current (uA/cm2).\n\n This function gets called for synapses and for external stimuli.\n\n Args:\n current: Current in `nA`.\n radius: Compartment radius in `um`.\n length: Compartment length in `um`.\n\n Return:\n Current in `uA/cm2`.\n \"\"\"\n area = 2 * pi * radius * length\n current /= area # nA / um^2\n return current * 100_000 # Convert (nA / um^2) to (uA / cm^2)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.equal_segments","title":"equal_segments(branch_property, ncomp_per_branch)
","text":"Generates segments where some property is the same in each segment.
Parameters:
Name Type Description Defaultbranch_property
list
List of values of the property in each branch. Should have len(branch_property) == num_branches
.
jaxley/utils/cell_utils.py
def equal_segments(branch_property: list, ncomp_per_branch: int):\n \"\"\"Generates segments where some property is the same in each segment.\n\n Args:\n branch_property: List of values of the property in each branch. Should have\n `len(branch_property) == num_branches`.\n \"\"\"\n assert isinstance(branch_property, list), \"branch_property must be a list.\"\n return jnp.asarray([branch_property] * ncomp_per_branch).T\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.get_num_neighbours","title":"get_num_neighbours(num_children, ncomp_per_branch, num_branches)
","text":"Number of neighbours of each compartment.
Source code injaxley/utils/cell_utils.py
def get_num_neighbours(\n num_children: jnp.ndarray,\n ncomp_per_branch: int,\n num_branches: int,\n):\n \"\"\"\n Number of neighbours of each compartment.\n \"\"\"\n num_neighbours = 2 * jnp.ones((num_branches * ncomp_per_branch))\n num_neighbours = num_neighbours.at[ncomp_per_branch - 1].set(1.0)\n num_neighbours = num_neighbours.at[jnp.arange(num_branches) * ncomp_per_branch].set(\n num_children + 1.0\n )\n return num_neighbours\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.group_and_sum","title":"group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)
","text":"Group values by whether they have the same integer and sum values within group.
This is used to construct the last diagonals at the branch points.
Written by ChatGPT.
Source code injaxley/utils/cell_utils.py
def group_and_sum(\n values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int\n) -> jnp.ndarray:\n \"\"\"Group values by whether they have the same integer and sum values within group.\n\n This is used to construct the last diagonals at the branch points.\n\n Written by ChatGPT.\n \"\"\"\n # Initialize an array to hold the sum of each group\n group_sums = jnp.zeros(num_branchpoints)\n\n # `.at[inds]` requires that `inds` is not empty, so we need an if-case here.\n # `len(inds) == 0` is the case for branches and compartments.\n if num_branchpoints > 0:\n group_sums = group_sums.at[inds_to_group_by].add(values_to_sum)\n\n return group_sums\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.interpolate_xyzr","title":"interpolate_xyzr(loc, coords)
","text":"Perform a linear interpolation between xyz-coordinates.
Parameters:
Name Type Description Defaultloc
float
The location in [0,1] along the branch.
requiredcoords
ndarray
Array containing the reconstructed xyzr points of the branch.
required ReturnInterpolated xyz coordinate at loc
, shape `(3,).
jaxley/utils/cell_utils.py
def interpolate_xyzr(loc: float, coords: np.ndarray):\n \"\"\"Perform a linear interpolation between xyz-coordinates.\n\n Args:\n loc: The location in [0,1] along the branch.\n coords: Array containing the reconstructed xyzr points of the branch.\n\n Return:\n Interpolated xyz coordinate at `loc`, shape `(3,).\n \"\"\"\n dl = np.sqrt(np.sum(np.diff(coords[:, :3], axis=0) ** 2, axis=1))\n pathlens = np.insert(np.cumsum(dl), 0, 0) # cummulative length of sections\n norm_pathlens = pathlens / np.maximum(1e-8, pathlens[-1]) # norm lengths to [0,1].\n\n return v_interp(loc, norm_pathlens, coords)\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.linear_segments","title":"linear_segments(initial_val, endpoint_vals, parents, ncomp_per_branch)
","text":"Generates segments where some property is linearly interpolated.
Parameters:
Name Type Description Defaultinitial_val
float
The value at the tip of the soma.
requiredendpoint_vals
list
The value at the endpoints of each branch.
required Source code injaxley/utils/cell_utils.py
def linear_segments(\n initial_val: float, endpoint_vals: list, parents: jnp.ndarray, ncomp_per_branch: int\n):\n \"\"\"Generates segments where some property is linearly interpolated.\n\n Args:\n initial_val: The value at the tip of the soma.\n endpoint_vals: The value at the endpoints of each branch.\n \"\"\"\n branch_property = endpoint_vals + [initial_val]\n num_branches = len(parents)\n # Compute radiuses by linear interpolation.\n endpoint_radiuses = jnp.asarray(branch_property)\n\n def compute_rad(branch_ind, loc):\n start = endpoint_radiuses[parents[branch_ind]]\n end = endpoint_radiuses[branch_ind]\n return (end - start) * loc + start\n\n branch_inds_of_each_comp = jnp.tile(jnp.arange(num_branches), ncomp_per_branch)\n locs_of_each_comp = jnp.linspace(1, 0, ncomp_per_branch).repeat(num_branches)\n rad_of_each_comp = compute_rad(branch_inds_of_each_comp, locs_of_each_comp)\n\n return jnp.reshape(rad_of_each_comp, (ncomp_per_branch, num_branches)).T\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.loc_of_index","title":"loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch)
","text":"Return location corresponding to global compartment index.
Source code injaxley/utils/cell_utils.py
def loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch):\n \"\"\"Return location corresponding to global compartment index.\"\"\"\n cumsum_ncomp = cumsum_leading_zero(ncomp_per_branch)\n index = global_comp_index - cumsum_ncomp[global_branch_index]\n ncomp = ncomp_per_branch[global_branch_index]\n return (0.5 + index) / ncomp\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.local_index_of_loc","title":"local_index_of_loc(loc, global_branch_ind, ncomp_per_branch)
","text":"Returns the local index of a comp given a loc [0, 1] and the index of a branch.
This is used because we specify locations such as synapses as a value between 0 and 1. We have to convert this onto a discrete segment here.
Parameters:
Name Type Description Defaultbranch_ind
Index of the branch.
requiredloc
float
Location (in [0, 1]) along that branch.
requiredncomp_per_branch
int
Number of segments of each branch.
requiredReturns:
Type Descriptionint
The local index of the compartment.
Source code injaxley/utils/cell_utils.py
def local_index_of_loc(\n loc: float, global_branch_ind: int, ncomp_per_branch: int\n) -> int:\n \"\"\"Returns the local index of a comp given a loc [0, 1] and the index of a branch.\n\n This is used because we specify locations such as synapses as a value between 0 and\n 1. We have to convert this onto a discrete segment here.\n\n Args:\n branch_ind: Index of the branch.\n loc: Location (in [0, 1]) along that branch.\n ncomp_per_branch: Number of segments of each branch.\n\n Returns:\n The local index of the compartment.\n \"\"\"\n ncomp = ncomp_per_branch[global_branch_ind] # only for convenience.\n possible_locs = np.linspace(0.5 / ncomp, 1 - 0.5 / ncomp, ncomp)\n ind_along_branch = np.argmin(np.abs(possible_locs - loc))\n return ind_along_branch\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.merge_cells","title":"merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)
","text":"Build full list of which branches are solved in which iteration.
From the branching pattern of single cells, this \u201cmerges\u201d them into a single ordering of branches.
Parameters:
Name Type Description Defaultcumsum_num_branches
List[int]
cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this will should be a list containing [0, 10, 25, 30]
.
arrs
List[List[ndarray]]
A list of a list of arrays that should be merged.
requiredexclude_first
bool
If True
, the first element of each list in arrs
will remain unchanged. Useful if a -1
(which indicates \u201cno parent\u201d) entry should not be changed.
True
Returns:
Type Descriptionndarray
A list of arrays which contain the branch indices that are computed at each
ndarray
level (i.e., iteration).
Source code injaxley/utils/cell_utils.py
def merge_cells(\n cumsum_num_branches: List[int],\n cumsum_num_branchpoints: List[int],\n arrs: List[List[np.ndarray]],\n exclude_first: bool = True,\n) -> np.ndarray:\n \"\"\"\n Build full list of which branches are solved in which iteration.\n\n From the branching pattern of single cells, this \"merges\" them into a single\n ordering of branches.\n\n Args:\n cumsum_num_branches: cumulative number of branches. E.g., for three cells with\n 10, 15, and 5 branches respectively, this will should be a list containing\n `[0, 10, 25, 30]`.\n arrs: A list of a list of arrays that should be merged.\n exclude_first: If `True`, the first element of each list in `arrs` will remain\n unchanged. Useful if a `-1` (which indicates \"no parent\") entry should not\n be changed.\n\n Returns:\n A list of arrays which contain the branch indices that are computed at each\n level (i.e., iteration).\n \"\"\"\n ps = []\n for i, att in enumerate(arrs):\n p = att\n if exclude_first:\n raise NotImplementedError\n p = [p[0]] + [p_in_level + cumsum_num_branches[i] for p_in_level in p[1:]]\n else:\n p = [\n p_in_level\n + np.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]])\n for p_in_level in p\n ]\n ps.append(p)\n\n max_len = max([len(att) for att in arrs])\n combined_parents_in_level = []\n for i in range(max_len):\n current_ps = []\n for p in ps:\n if len(p) > i:\n current_ps.append(p[i])\n combined_parents_in_level.append(np.concatenate(current_ps))\n\n return combined_parents_in_level\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.params_to_pstate","title":"params_to_pstate(params, indices_set_by_trainables)
","text":"Make outputs get_parameters()
conform with outputs of .data_set()
.
make_trainable()
followed by params=get_parameters()
does not return indices because these indices would also be differentiated by jax.grad
(as soon as the params
are passed to def simulate(params)
. Therefore, in jx.integrate
, we run the function to add indices to the dict. The outputs of params_to_pstate
are of the same shape as the outputs of .data_set()
.
jaxley/utils/cell_utils.py
def params_to_pstate(\n params: List[Dict[str, jnp.ndarray]],\n indices_set_by_trainables: List[jnp.ndarray],\n):\n \"\"\"Make outputs `get_parameters()` conform with outputs of `.data_set()`.\n\n `make_trainable()` followed by `params=get_parameters()` does not return indices\n because these indices would also be differentiated by `jax.grad` (as soon as\n the `params` are passed to `def simulate(params)`. Therefore, in `jx.integrate`,\n we run the function to add indices to the dict. The outputs of `params_to_pstate`\n are of the same shape as the outputs of `.data_set()`.\"\"\"\n return [\n {\"key\": list(p.keys())[0], \"val\": list(p.values())[0], \"indices\": i}\n for p, i in zip(params, indices_set_by_trainables)\n ]\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.query_channel_states_and_params","title":"query_channel_states_and_params(d, keys, idcs)
","text":"Get dict with subset of keys and values from d.
This is used to restrict a dict where every item contains all states to only the ones that are relevant for the channel. E.g.
states = {'eCa': Array([ 0., 0., nan]}
will be states = {'eCa': Array([ 0., 0.]}
Only loops over necessary keys, as opposed to looping over d.items()
.
jaxley/utils/cell_utils.py
def query_channel_states_and_params(d, keys, idcs):\n \"\"\"Get dict with subset of keys and values from d.\n\n This is used to restrict a dict where every item contains __all__ states to only\n the ones that are relevant for the channel. E.g.\n\n ```states = {'eCa': Array([ 0., 0., nan]}```\n\n will be\n ```states = {'eCa': Array([ 0., 0.]}```\n\n Only loops over necessary keys, as opposed to looping over `d.items()`.\"\"\"\n return dict(zip(keys, (v[idcs] for v in map(d.get, keys))))\n
"},{"location":"reference/utils/#jaxley.utils.cell_utils.remap_to_consecutive","title":"remap_to_consecutive(arr)
","text":"Maps an array of integers to an array of consecutive integers.
E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]
jaxley/utils/cell_utils.py
def remap_to_consecutive(arr):\n \"\"\"Maps an array of integers to an array of consecutive integers.\n\n E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`\n \"\"\"\n _, inverse_indices = jnp.unique(arr, return_inverse=True)\n return inverse_indices\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.compute_rotation_matrix","title":"compute_rotation_matrix(axis, angle)
","text":"Return the rotation matrix associated with counterclockwise rotation about the given axis by the given angle.
Can be used to rotate a coordinate vector by multiplying it with the rotation matrix.
Parameters:
Name Type Description Defaultaxis
ndarray
The axis of rotation.
requiredangle
float
The angle of rotation in radians.
requiredReturns:
Type Descriptionndarray
A 3x3 rotation matrix.
Source code injaxley/utils/plot_utils.py
def compute_rotation_matrix(axis: ndarray, angle: float) -> ndarray:\n \"\"\"\n Return the rotation matrix associated with counterclockwise rotation about\n the given axis by the given angle.\n\n Can be used to rotate a coordinate vector by multiplying it with the rotation\n matrix.\n\n Args:\n axis: The axis of rotation.\n angle: The angle of rotation in radians.\n\n Returns:\n A 3x3 rotation matrix.\n \"\"\"\n axis = axis / np.sqrt(np.dot(axis, axis))\n a = np.cos(angle / 2.0)\n b, c, d = -axis * np.sin(angle / 2.0)\n aa, bb, cc, dd = a * a, b * b, c * c, d * d\n bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d\n return np.array(\n [\n [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],\n [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],\n [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],\n ]\n )\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_cone_frustum_mesh","title":"create_cone_frustum_mesh(length, radius_bottom, radius_top, bottom_dome=False, top_dome=False, resolution=100)
","text":"Generates mesh points for a cone frustum, with optional domes at either end.
This is used to render the traced morphology in 3D (and to project it to 2D) as part of plot_morph
. Sections between two traced coordinates with two different radii can be represented by a cone frustum. Additionally, the ends of the frustum can be capped with hemispheres to ensure that two neighbouring frustums are connected smoothly (like ball joints).
Parameters:
Name Type Description Defaultlength
float
The length of the frustum.
requiredradius_bottom
float
The radius of the bottom of the frustum.
requiredradius_top
float
The radius of the top of the frustum.
requiredbottom_dome
bool
If True, a dome is added to the bottom of the frustum. The dome is a hemisphere with radius radius_bottom
.
False
top_dome
bool
If True, a dome is added to the top of the frustum. The dome is a hemisphere with radius radius_top
.
False
resolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type Descriptionndarray
An array of mesh points.
Source code injaxley/utils/plot_utils.py
def create_cone_frustum_mesh(\n length: float,\n radius_bottom: float,\n radius_top: float,\n bottom_dome: bool = False,\n top_dome: bool = False,\n resolution: int = 100,\n) -> ndarray:\n \"\"\"Generates mesh points for a cone frustum, with optional domes at either end.\n\n This is used to render the traced morphology in 3D (and to project it to 2D)\n as part of `plot_morph`. Sections between two traced coordinates with two\n different radii can be represented by a cone frustum. Additionally, the ends\n of the frustum can be capped with hemispheres to ensure that two neighbouring\n frustums are connected smoothly (like ball joints).\n\n Args:\n length: The length of the frustum.\n radius_bottom: The radius of the bottom of the frustum.\n radius_top: The radius of the top of the frustum.\n bottom_dome: If True, a dome is added to the bottom of the frustum.\n The dome is a hemisphere with radius `radius_bottom`.\n top_dome: If True, a dome is added to the top of the frustum.\n The dome is a hemisphere with radius `radius_top`.\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n An array of mesh points.\n \"\"\"\n\n t = np.linspace(0, 2 * np.pi, resolution)\n\n # Determine the total height including domes\n total_height = length\n total_height += radius_bottom if bottom_dome else 0\n total_height += radius_top if top_dome else 0\n\n z = np.linspace(0, total_height, resolution)\n t_grid, z_coords = np.meshgrid(t, z)\n\n # Initialize arrays\n x_coords = np.zeros_like(t_grid)\n y_coords = np.zeros_like(t_grid)\n r_coords = np.zeros_like(t_grid)\n\n # Bottom hemisphere\n if bottom_dome:\n dome_mask = z_coords < radius_bottom\n arg = 1 - z_coords[dome_mask] / radius_bottom\n arg[np.isclose(arg, 1, atol=1e-6, rtol=1e-6)] = 1\n arg[np.isclose(arg, -1, atol=1e-6, rtol=1e-6)] = -1\n phi = np.arccos(1 - z_coords[dome_mask] / radius_bottom)\n r_coords[dome_mask] = radius_bottom * np.sin(phi)\n z_coords[dome_mask] = z_coords[dome_mask]\n\n # Frustum\n frustum_start = radius_bottom if bottom_dome else 0\n frustum_end = total_height - (radius_top if top_dome else 0)\n frustum_mask = (z_coords >= frustum_start) & (z_coords <= frustum_end)\n z_frustum = z_coords[frustum_mask] - frustum_start\n r_coords[frustum_mask] = radius_bottom + (radius_top - radius_bottom) * (\n z_frustum / length\n )\n\n # Top hemisphere\n if top_dome:\n dome_mask = z_coords > (total_height - radius_top)\n arg = (z_coords[dome_mask] - (total_height - radius_top)) / radius_top\n arg[np.isclose(arg, 1, atol=1e-6, rtol=1e-6)] = 1\n arg[np.isclose(arg, -1, atol=1e-6, rtol=1e-6)] = -1\n phi = np.arccos(arg)\n r_coords[dome_mask] = radius_top * np.sin(phi)\n\n x_coords = r_coords * np.cos(t_grid)\n y_coords = r_coords * np.sin(t_grid)\n\n return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_cylinder_mesh","title":"create_cylinder_mesh(length, radius, resolution=100)
","text":"Generates mesh points for a cylinder.
This is used to render cylindrical compartments in 3D (and to project it to 2D) as part of plot_comps
.
Parameters:
Name Type Description Defaultlength
float
The length of the cylinder.
requiredradius
float
The radius of the cylinder.
requiredresolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type Descriptionndarray
An array of mesh points.
Source code injaxley/utils/plot_utils.py
def create_cylinder_mesh(\n length: float, radius: float, resolution: int = 100\n) -> ndarray:\n \"\"\"Generates mesh points for a cylinder.\n\n This is used to render cylindrical compartments in 3D (and to project it to 2D)\n as part of `plot_comps`.\n\n Args:\n length: The length of the cylinder.\n radius: The radius of the cylinder.\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n An array of mesh points.\n \"\"\"\n # Define cylinder\n t = np.linspace(0, 2 * np.pi, resolution)\n z_coords = np.linspace(-length / 2, length / 2, resolution)\n t_grid, z_coords = np.meshgrid(t, z_coords)\n\n x_coords = radius * np.cos(t_grid)\n y_coords = radius * np.sin(t_grid)\n return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.create_sphere_mesh","title":"create_sphere_mesh(radius, resolution=100)
","text":"Generates mesh points for a sphere.
This is used to render spherical compartments in 3D (and to project it to 2D) as part of plot_comps
.
Parameters:
Name Type Description Defaultradius
float
The radius of the sphere.
requiredresolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type Descriptionndarray
An array of mesh points.
Source code injaxley/utils/plot_utils.py
def create_sphere_mesh(radius: float, resolution: int = 100) -> np.ndarray:\n \"\"\"Generates mesh points for a sphere.\n\n This is used to render spherical compartments in 3D (and to project it to 2D)\n as part of `plot_comps`.\n\n Args:\n radius: The radius of the sphere.\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n An array of mesh points.\n \"\"\"\n phi = np.linspace(0, np.pi, resolution)\n theta = np.linspace(0, 2 * np.pi, resolution)\n\n # Create a 2D meshgrid for phi and theta\n phi_coords, theta_coords = np.meshgrid(phi, theta)\n\n # Convert spherical coordinates to Cartesian coordinates\n x_coords = radius * np.sin(phi_coords) * np.cos(theta_coords)\n y_coords = radius * np.sin(phi_coords) * np.sin(theta_coords)\n z_coords = radius * np.cos(phi_coords)\n\n return np.stack([x_coords, y_coords, z_coords])\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.extract_outline","title":"extract_outline(points)
","text":"Get the outline of a 2D/3D shape.
Extracts the subset of points which form the convex hull, i.e. the outline of the input points.
Parameters:
Name Type Description Defaultpoints
ndarray
An array of points / corrdinates.
requiredReturns:
Type Descriptionndarray
An array of points which form the convex hull.
Source code injaxley/utils/plot_utils.py
def extract_outline(points: ndarray) -> ndarray:\n \"\"\"Get the outline of a 2D/3D shape.\n\n Extracts the subset of points which form the convex hull, i.e. the outline of\n the input points.\n\n Args:\n points: An array of points / corrdinates.\n\n Returns:\n An array of points which form the convex hull.\n \"\"\"\n hull = ConvexHull(points)\n hull_points = points[hull.vertices]\n return hull_points\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_comps","title":"plot_comps(module_or_view, dims=(0, 1), col='k', ax=None, comp_plot_kwargs={}, true_comp_length=True, resolution=100)
","text":"Plot compartmentalized neural morphology.
Plots the projection of the cylindrical compartments.
Parameters:
Name Type Description Defaultmodule_or_view
Union[Module, View]
The module or view to plot.
requireddims
Tuple[int]
The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.
(0, 1)
col
str
The color for all compartments
'k'
ax
Optional[Axes]
The matplotlib axis to plot on.
None
comp_plot_kwargs
Dict
The plot kwargs for plt.fill.
{}
true_comp_length
bool
If True, the length of the compartment is used, i.e. the length of the traced neurite. This means for zig-zagging neurites the cylinders will be longer than the straight-line distance between the start and end point of the neurite. This can lead to overlapping and miss-aligned cylinders. Setting this False will use the straight-line distance instead for nicer plots.
True
resolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type DescriptionAxes
Plot of the compartmentalized morphology.
Source code injaxley/utils/plot_utils.py
def plot_comps(\n module_or_view: Union[\"jx.Module\", \"jx.View\"],\n dims: Tuple[int] = (0, 1),\n col: str = \"k\",\n ax: Optional[Axes] = None,\n comp_plot_kwargs: Dict = {},\n true_comp_length: bool = True,\n resolution: int = 100,\n) -> Axes:\n \"\"\"Plot compartmentalized neural morphology.\n\n Plots the projection of the cylindrical compartments.\n\n Args:\n module_or_view: The module or view to plot.\n dims: The dimensions to plot / to project the cylinder onto,\n i.e. [0,1] xy-plane or [0,1,2] for 3D.\n col: The color for all compartments\n ax: The matplotlib axis to plot on.\n comp_plot_kwargs: The plot kwargs for plt.fill.\n true_comp_length: If True, the length of the compartment is used, i.e. the\n length of the traced neurite. This means for zig-zagging neurites the\n cylinders will be longer than the straight-line distance between the\n start and end point of the neurite. This can lead to overlapping and\n miss-aligned cylinders. Setting this False will use the straight-line\n distance instead for nicer plots.\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n Plot of the compartmentalized morphology.\n \"\"\"\n if ax is None:\n fig = plt.figure(figsize=(3, 3))\n ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n assert not np.any(\n np.isnan(module_or_view.xyzr[0][:, :3])\n ), \"missing xyz coordinates.\"\n if \"x\" not in module_or_view.nodes.columns:\n module_or_view.compute_compartment_centers()\n\n for idx, xyzr in zip(module_or_view._branches_in_view, module_or_view.xyzr):\n locs = xyzr[:, :3]\n if locs.shape[0] == 1: # assume spherical comp\n radius = xyzr[:, -1]\n center = xyzr[0, :3]\n if len(dims) == 3:\n xyz = create_sphere_mesh(radius, resolution)\n ax = plot_mesh(\n xyz,\n np.array([0, 0, 1]),\n center,\n np.array(dims),\n ax,\n color=col,\n **comp_plot_kwargs,\n )\n else:\n ax.add_artist(plt.Circle(locs[0, dims], radius, color=col))\n else:\n lens = np.sqrt(np.nansum(np.diff(locs, axis=0) ** 2, axis=1))\n lens = np.cumsum([0] + lens.tolist())\n comp_ends = v_interp(\n np.linspace(0, lens[-1], module_or_view.ncomp + 1), lens, locs\n ).T\n axes = np.diff(comp_ends, axis=0)\n cylinder_lens = np.sqrt(np.sum(axes**2, axis=1))\n\n branch_df = module_or_view.nodes[\n module_or_view.nodes[\"global_branch_index\"] == idx\n ]\n for l, axis, (i, comp) in zip(cylinder_lens, axes, branch_df.iterrows()):\n center = comp[[\"x\", \"y\", \"z\"]]\n radius = comp[\"radius\"]\n length = comp[\"length\"] if true_comp_length else l\n xyz = create_cylinder_mesh(length, radius, resolution)\n ax = plot_mesh(\n xyz,\n axis,\n center,\n np.array(dims),\n ax,\n color=col,\n **comp_plot_kwargs,\n )\n return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_graph","title":"plot_graph(xyzr, dims=(0, 1), col='k', ax=None, type='line', morph_plot_kwargs={})
","text":"Plot morphology.
Parameters:
Name Type Description Defaultxyzr
ndarray
The coordinates of the morphology.
requireddims
Tuple[int]
Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of two or three of them.
(0, 1)
col
str
The color for all branches.
'k'
ax
Optional[Axes]
The matplotlib axis to plot on.
None
type
str
Either line
or scatter
.
'line'
morph_plot_kwargs
Dict
The plot kwargs for plt.plot or plt.scatter.
{}
Source code in jaxley/utils/plot_utils.py
def plot_graph(\n xyzr: ndarray,\n dims: Tuple[int] = (0, 1),\n col: str = \"k\",\n ax: Optional[Axes] = None,\n type: str = \"line\",\n morph_plot_kwargs: Dict = {},\n) -> Axes:\n \"\"\"Plot morphology.\n\n Args:\n xyzr: The coordinates of the morphology.\n dims: Which dimensions to plot. 1=x, 2=y, 3=z coordinate. Must be a tuple of\n two or three of them.\n col: The color for all branches.\n ax: The matplotlib axis to plot on.\n type: Either `line` or `scatter`.\n morph_plot_kwargs: The plot kwargs for plt.plot or plt.scatter.\n \"\"\"\n\n if ax is None:\n fig = plt.figure(figsize=(3, 3))\n ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n for coords_of_branch in xyzr:\n points = coords_of_branch[:, dims].T\n\n if \"line\" in type.lower():\n _ = ax.plot(*points, color=col, **morph_plot_kwargs)\n elif \"scatter\" in type.lower():\n _ = ax.scatter(*points, color=col, **morph_plot_kwargs)\n else:\n raise NotImplementedError\n\n return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_mesh","title":"plot_mesh(mesh_points, orientation, center, dims, ax=None, **kwargs)
","text":"Plot the 2D projection of a volume mesh on a cardinal plane.
Project the projection of a cylinder that is oriented in 3D space. - Create cylinder mesh - rotate cylinder mesh to orient it lengthwise along a given orientation vector. - move its center - project onto plane - compute outline of projected mesh. - fill area inside the outline
Parameters:
Name Type Description Defaultmesh_points
ndarray
coordinates of the xyz mesh that define the volume
requiredorientation
ndarray
orientation vector. The cylinder will be oriented along this vector.
requiredcenter
ndarray
The x,y,z coordinates of the center of the cylinder.
requireddims
Tuple[int]
The dimensions to plot / to project the cylinder onto,
requiredax
Axes
The matplotlib axis to plot on.
None
Returns:
Type DescriptionAxes
Plot of the cylinder projection.
Source code injaxley/utils/plot_utils.py
def plot_mesh(\n mesh_points: ndarray,\n orientation: ndarray,\n center: ndarray,\n dims: Tuple[int],\n ax: Axes = None,\n **kwargs,\n) -> Axes:\n \"\"\"Plot the 2D projection of a volume mesh on a cardinal plane.\n\n Project the projection of a cylinder that is oriented in 3D space.\n - Create cylinder mesh\n - rotate cylinder mesh to orient it lengthwise along a given orientation vector.\n - move its center\n - project onto plane\n - compute outline of projected mesh.\n - fill area inside the outline\n\n Args:\n mesh_points: coordinates of the xyz mesh that define the volume\n orientation: orientation vector. The cylinder will be oriented along this vector.\n center: The x,y,z coordinates of the center of the cylinder.\n dims: The dimensions to plot / to project the cylinder onto,\n i.e. [0,1] xy-plane or [0,1,2] for 3D.\n ax: The matplotlib axis to plot on.\n\n Returns:\n Plot of the cylinder projection.\n \"\"\"\n if ax is None:\n fig = plt.figure(figsize=(3, 3))\n ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n\n # Normalize axis vector\n orientation = np.array(orientation)\n orientation = orientation / np.linalg.norm(orientation)\n\n # Create a rotation matrix to align the cylinder with the given axis\n z_axis = np.array([0, 0, 1])\n rotation_axis = np.cross(z_axis, orientation)\n rotation_angle = np.arccos(np.dot(z_axis, orientation))\n\n if np.allclose(rotation_axis, 0):\n rotation_matrix = np.eye(3)\n else:\n rotation_matrix = compute_rotation_matrix(rotation_axis, rotation_angle)\n\n # Rotate mesh\n x_mesh, y_mesh, z_mesh = mesh_points\n rotated_mesh_points = np.dot(\n rotation_matrix,\n np.array([x_mesh.flatten(), y_mesh.flatten(), z_mesh.flatten()]),\n )\n rotated_mesh_points = rotated_mesh_points.reshape(3, -1)\n\n # project onto plane and move\n rotated_mesh_points = rotated_mesh_points[dims]\n rotated_mesh_points += np.array(center)[dims, np.newaxis]\n\n if len(dims) < 3:\n # get outline of cylinder mesh\n mesh_outline = extract_outline(rotated_mesh_points.T).T\n ax.fill(*mesh_outline.reshape(mesh_outline.shape[0], -1), **kwargs)\n else:\n # plot 3d mesh\n ax.plot_surface(*rotated_mesh_points.reshape(*mesh_points.shape), **kwargs)\n return ax\n
"},{"location":"reference/utils/#jaxley.utils.plot_utils.plot_morph","title":"plot_morph(module_or_view, dims=(0, 1), col='k', ax=None, resolution=100, morph_plot_kwargs={})
","text":"Plot the detailed morphology.
Plots the traced morphology it was traced. That means at every point that was traced a disc of radius r
is plotted. The outline of the discs are then connected to form the morphology. This means every trace segement can be represented by a cone frustum. To prevent breaks in the morphology, each segement is connected with a ball joint.
Parameters:
Name Type Description Defaultmodule_or_view
Union[Module, View]
The module or view to plot.
requireddims
Tuple[int]
The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.
(0, 1)
col
str
The color for all branches
'k'
ax
Optional[Axes]
The matplotlib axis to plot on.
None
morph_plot_kwargs
Dict
The plot kwargs for plt.fill.
{}
resolution
int
defines the resolution of the mesh. If too low (typically <10), can result in errors. Useful too have a simpler mesh for plotting.
100
Returns:
Type DescriptionAxes
Plot of the detailed morphology.
Source code injaxley/utils/plot_utils.py
def plot_morph(\n module_or_view: Union[\"jx.Module\", \"jx.View\"],\n dims: Tuple[int] = (0, 1),\n col: str = \"k\",\n ax: Optional[Axes] = None,\n resolution: int = 100,\n morph_plot_kwargs: Dict = {},\n) -> Axes:\n \"\"\"Plot the detailed morphology.\n\n Plots the traced morphology it was traced. That means at every point that was\n traced a disc of radius `r` is plotted. The outline of the discs are then\n connected to form the morphology. This means every trace segement can be\n represented by a cone frustum. To prevent breaks in the morphology, each\n segement is connected with a ball joint.\n\n Args:\n module_or_view: The module or view to plot.\n dims: The dimensions to plot / to project the cylinder onto,\n i.e. [0,1] xy-plane or [0,1,2] for 3D.\n col: The color for all branches\n ax: The matplotlib axis to plot on.\n morph_plot_kwargs: The plot kwargs for plt.fill.\n\n resolution: defines the resolution of the mesh.\n If too low (typically <10), can result in errors.\n Useful too have a simpler mesh for plotting.\n\n Returns:\n Plot of the detailed morphology.\"\"\"\n if ax is None:\n fig = plt.figure(figsize=(3, 3))\n ax = fig.add_subplot(111) if len(dims) < 3 else plt.axes(projection=\"3d\")\n if len(dims) == 3:\n warn(\n \"rendering large morphologies in 3D can take a while. Consider projecting to 2D instead.\"\n )\n\n assert not np.any(\n np.isnan(module_or_view.xyzr[0][:, :3])\n ), \"missing xyz coordinates.\"\n\n for xyzr in module_or_view.xyzr:\n if len(xyzr) > 1:\n for xyzr1, xyzr2 in zip(xyzr[1:, :], xyzr[:-1, :]):\n dxyz = xyzr2[:3] - xyzr1[:3]\n length = np.sqrt(np.sum(dxyz**2))\n points = create_cone_frustum_mesh(\n length,\n xyzr1[-1],\n xyzr2[-1],\n bottom_dome=True,\n top_dome=True,\n resolution=resolution,\n )\n plot_mesh(\n points,\n dxyz,\n xyzr1[:3],\n np.array(dims),\n color=col,\n ax=ax,\n **morph_plot_kwargs,\n )\n else:\n points = create_cone_frustum_mesh(\n 0,\n xyzr[:, -1],\n xyzr[:, -1],\n bottom_dome=True,\n top_dome=True,\n resolution=resolution,\n )\n plot_mesh(\n points,\n np.ones(3),\n xyzr[0, :3],\n dims=np.array(dims),\n color=col,\n ax=ax,\n **morph_plot_kwargs,\n )\n\n return ax\n
"},{"location":"reference/utils/#jaxley.utils.jax_utils.nested_checkpoint_scan","title":"nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=jax.lax.scan, checkpoint_fn=jax.checkpoint)
","text":"A version of lax.scan that supports recursive gradient checkpointing.
Code taken from: https://github.com/google/jax/issues/2139
The interface of nested_checkpoint_scan
exactly matches lax.scan, except for the required nested_lengths
argument.
The key feature of nested_checkpoint_scan
is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1
times.
nested_checkpoint_scan
reduces to lax.scan
when nested_lengths
has a single element.
Parameters:
Name Type Description Defaultf
Callable[[Carry, Dict[str, ndarray]], Tuple[Carry, Output]]
function to scan over.
requiredinit
Carry
initial value.
requiredxs
Dict[str, ndarray]
scanned over values.
requiredlength
Optional[int]
leading length of all dimensions
None
nested_lengths
Sequence[int]
required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs
.
scan_fn
function matching the API of lax.scan
scan
checkpoint_fn
Callable[[Func], Func]
function matching the API of jax.checkpoint.
checkpoint
Source code in jaxley/utils/jax_utils.py
def nested_checkpoint_scan(\n f: Callable[[Carry, Dict[str, jnp.ndarray]], Tuple[Carry, Output]],\n init: Carry,\n xs: Dict[str, jnp.ndarray],\n length: Optional[int] = None,\n *,\n nested_lengths: Sequence[int],\n scan_fn=jax.lax.scan,\n checkpoint_fn: Callable[[Func], Func] = jax.checkpoint,\n):\n \"\"\"A version of lax.scan that supports recursive gradient checkpointing.\n\n Code taken from: https://github.com/google/jax/issues/2139\n\n The interface of `nested_checkpoint_scan` exactly matches lax.scan, except for\n the required `nested_lengths` argument.\n\n The key feature of `nested_checkpoint_scan` is that gradient calculations\n require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested\n scans, which it achieves by re-evaluating the forward pass\n `len(nested_lengths) - 1` times.\n\n `nested_checkpoint_scan` reduces to `lax.scan` when `nested_lengths` has a\n single element.\n\n Args:\n f: function to scan over.\n init: initial value.\n xs: scanned over values.\n length: leading length of all dimensions\n nested_lengths: required list of lengths to scan over for each level of\n checkpointing. The product of nested_lengths must match length (if\n provided) and the size of the leading axis for all arrays in ``xs``.\n scan_fn: function matching the API of lax.scan\n checkpoint_fn: function matching the API of jax.checkpoint.\n \"\"\"\n if length is not None and length != math.prod(nested_lengths):\n raise ValueError(f\"inconsistent {length=} and {nested_lengths=}\")\n\n def nested_reshape(x):\n x = jnp.asarray(x)\n new_shape = tuple(nested_lengths) + x.shape[1:]\n return x.reshape(new_shape)\n\n sub_xs = jax.tree_util.tree_map(nested_reshape, xs)\n return _inner_nested_scan(f, init, sub_xs, nested_lengths, scan_fn, checkpoint_fn)\n
"},{"location":"reference/utils/#jaxley.utils.syn_utils.gather_synapes","title":"gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)
","text":"Compute current at the post synapse.
All this does it that it sums the synaptic currents that come into a particular compartment. It returns an array of as many elements as there are compartments.
Source code injaxley/utils/syn_utils.py
def gather_synapes(\n number_of_compartments: jnp.ndarray,\n post_syn_comp_inds: np.ndarray,\n current_each_synapse_voltage_term: jnp.ndarray,\n current_each_synapse_constant_term: jnp.ndarray,\n) -> Tuple[jnp.ndarray, jnp.ndarray]:\n \"\"\"Compute current at the post synapse.\n\n All this does it that it sums the synaptic currents that come into a particular\n compartment. It returns an array of as many elements as there are compartments.\n \"\"\"\n incoming_currents_voltages = jnp.zeros((number_of_compartments,))\n incoming_currents_contant = jnp.zeros((number_of_compartments,))\n\n dnums = ScatterDimensionNumbers(\n update_window_dims=(),\n inserted_window_dims=(0,),\n scatter_dims_to_operand_dims=(0,),\n )\n incoming_currents_voltages = scatter_add(\n incoming_currents_voltages,\n post_syn_comp_inds[:, None],\n current_each_synapse_voltage_term,\n dnums,\n )\n incoming_currents_contant = scatter_add(\n incoming_currents_contant,\n post_syn_comp_inds[:, None],\n current_each_synapse_constant_term,\n dnums,\n )\n return incoming_currents_voltages, incoming_currents_contant\n
"},{"location":"tutorial/00_jaxley_api/","title":"Key concepts in Jaxley","text":"In this tutorial, we will introduce you to the basic concepts of Jaxley. You will learn about:
Here is a code snippet which you will learn to understand in this tutorial:
import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n\n# Assembling different Modules into a Network\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=1)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell]*3)\n\n# Navigating and inspecting the Modules using Views\ncell0 = net.cell(0)\ncell0.nodes\n\n# How to group together parts of Modules\nnet.cell(1).add_to_group(\"cell1\")\n\n# inserting channels in the membrane\nwith net.cell(0) as cell0:\n cell0.insert(Na())\n cell0.insert(K())\n\n# connecting two cells using a Synapse\npre_comp = cell0.branch(1).comp(0)\npost_comp = net.cell1.branch(0).comp(0)\n\nconnect(pre_comp, post_comp)\n
First, we import the relevant libraries:
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\nimport matplotlib.pyplot as plt\nimport numpy as np\n
"},{"location":"tutorial/00_jaxley_api/#modules","title":"Modules","text":"In Jaxley, we heavily rely on the concept of Modules to build biophyiscal models of neural systems at various scales. Jaxley implements four types of Modules: - Compartment
- Branch
- Cell
- Network
Modules can be connected together to build increasingly detailed and complex models. Compartment
-> Branch
-> Cell
-> Network
.
Compartment
s are the atoms of biophysical models in Jaxley. All mechanisms and synaptic connections live on the level of Compartment
s and can already be simulated using jx.integrate
on their own. Everything you do in Jaxley starts with a Compartment
.
comp = jx.Compartment() # single compartment model.\n
Mutliple Compartments
can be connected together to form longer, linear cables, which we call Branch
es and are equivalent to sections in NEURON
.
ncomp = 4\nbranch = jx.Branch([comp] * ncomp)\n
In order to construct cell morphologies in Jaxley, multiple Branches
can to be connected together as a Cell
:
# -1 indicates that the first branch has no parent branch.\n# The other two branches both have the 0-eth branch as their parent.\nparents = [-1, 0, 0]\ncell = jx.Cell([branch] * len(parents), parents)\n
Finally, several Cell
s can be grouped together to form a Network
, which can than be connected together using Synpase
s.
ncells = 2\nnet = jx.Network([cell]*ncells)\n\nnet.shape # shows you the num_cells, num_branches, num_comps\n
(2, 6, 24)\n
Every module tracks information about its current state and parameters in two Dataframes called nodes
and edges
. nodes
contains all the information that we associate with compartments in the model (each row corresponds to one compartment) and edges
tracks all the information relevant to synapses.
This means that you can easily keep track of the current state of your Module
and how it changes at all times.
net.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0 12 1 0 0 10.0 1.0 5000.0 1.0 -70.0 1 3 12 0 13 1 0 1 10.0 1.0 5000.0 1.0 -70.0 1 3 13 0 14 1 0 2 10.0 1.0 5000.0 1.0 -70.0 1 3 14 0 15 1 0 3 10.0 1.0 5000.0 1.0 -70.0 1 3 15 0 16 1 1 0 10.0 1.0 5000.0 1.0 -70.0 1 4 16 0 17 1 1 1 10.0 1.0 5000.0 1.0 -70.0 1 4 17 0 18 1 1 2 10.0 1.0 5000.0 1.0 -70.0 1 4 18 0 19 1 1 3 10.0 1.0 5000.0 1.0 -70.0 1 4 19 0 20 1 2 0 10.0 1.0 5000.0 1.0 -70.0 1 5 20 0 21 1 2 1 10.0 1.0 5000.0 1.0 -70.0 1 5 21 0 22 1 2 2 10.0 1.0 5000.0 1.0 -70.0 1 5 22 0 23 1 2 3 10.0 1.0 5000.0 1.0 -70.0 1 5 23 0 net.edges.head() # this is currently empty since we have not made any connections yet\n
global_edge_index global_pre_comp_index global_post_comp_index pre_locs post_locs type type_ind"},{"location":"tutorial/00_jaxley_api/#views","title":"Views","text":"Since these Module
s can become very complex, Jaxley utilizes so called View
s to make working with Module
s easy and intuitive.
The simplest way to navigate Modules is by navigating them via the hierachy that we introduced above. A View
is what you get when you index into the module. For example, for a Network
:
net.cell(0)\n
View with 0 different channels. Use `.nodes` for details.\n
Views behave very similarly to Module
s, i.e. the cell(0)
(the 0th cell of the network) behaves like the cell
we instantiated earlier. As such, cell(0)
also has a nodes
attribute, which keeps track of it\u2019s part of the network:
net.cell(0).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 0 2 11 0 Let\u2019s use View
s to visualize only parts of the Network
. Before we do that, we create x, y, and z coordinates for the Network
:
# Compute xyz coordinates of the cells.\nnet.compute_xyz()\n\n# Move cells (since they are placed on top of each other by default).\nnet.cell(0).move(y=30)\n
We can now visualize the entire net
(i.e., the entire Module
) with the .vis()
method\u2026
# We can use the vis function to visualize Modules.\nfig, ax = plt.subplots(1, 1, figsize=(3,3))\nnet.vis(ax=ax)\n
<Axes: >\n
\u2026but we can also create a View
to visualize only parts of the net
:
# ... and Views\nfig, ax = plt.subplots(1,1, figsize=(3,3))\nnet.cell(0).vis(ax=ax, col=\"blue\") # View of the 0th cell of the network\nnet.cell(1).vis(ax=ax, col=\"red\") # View of the 1st cell of the network\n\nnet.cell(0).branch(0).vis(ax=ax, col=\"green\") # View of the 1st branch of the 0th cell of the network\nnet.cell(1).branch(1).comp(1).vis(ax=ax, col=\"black\", type=\"scatter\") # View of the 0th comp of the 1st branch of the 0th cell of the network\n
<Axes: >\n
"},{"location":"tutorial/00_jaxley_api/#how-to-create-views","title":"How to create View
s","text":"Above, we used net.cell(0)
to generate a View
of the 0-eth cell. Jaxley
supports many ways of performing such indexing:
# several types of indices are supported (lists, ranges, ...)\nnet.cell([0,1]).branch(\"all\").comp(0) # View of all 0th comps of all branches of cell 0 and 1\n\nbranch.loc(0.1) # Equivalent to `NEURON`s `loc`. Assumes branches are continous from 0-1.\n\nnet[0,0,0] # Modules/Views can also be lazily indexed\n\ncell0 = net.cell(0) # Views can be assigned to variables and only track the parts of the Module they belong to\ncell0.branch(1).comp(0) # Views can be continuely indexed\n
View with 0 different channels. Use `.nodes` for details.\n
cell0.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v x y z global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 5.000000 30.000000 0.0 0 0 0 0 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 15.000000 30.000000 0.0 0 0 1 0 2 0 0 2 10.0 1.0 5000.0 1.0 -70.0 25.000000 30.000000 0.0 0 0 2 0 3 0 0 3 10.0 1.0 5000.0 1.0 -70.0 35.000000 30.000000 0.0 0 0 3 0 4 0 1 0 10.0 1.0 5000.0 1.0 -70.0 44.850713 28.787322 0.0 0 1 4 0 5 0 1 1 10.0 1.0 5000.0 1.0 -70.0 54.552138 26.361966 0.0 0 1 5 0 6 0 1 2 10.0 1.0 5000.0 1.0 -70.0 64.253563 23.936609 0.0 0 1 6 0 7 0 1 3 10.0 1.0 5000.0 1.0 -70.0 73.954988 21.511253 0.0 0 1 7 0 8 0 2 0 10.0 1.0 5000.0 1.0 -70.0 44.850713 31.212678 0.0 0 2 8 0 9 0 2 1 10.0 1.0 5000.0 1.0 -70.0 54.552138 33.638034 0.0 0 2 9 0 10 0 2 2 10.0 1.0 5000.0 1.0 -70.0 64.253563 36.063391 0.0 0 2 10 0 11 0 2 3 10.0 1.0 5000.0 1.0 -70.0 73.954988 38.488747 0.0 0 2 11 0 net.shape\n
(2, 6, 24)\n
Note: In case you need even more flexibility in how you select parts of a Module, Jaxley provides a select
method, to give full control over the exact parts of the nodes
and edges
that are part of a View
. On examples of how this can be used, see the tutorial on advanced indexing.
You can also iterate over networks, cells, and branches:
# We set the radiuses to random values...\nradiuses = np.random.rand((24))\nnet.set(\"radius\", radiuses)\n\n# ...and then we set the length to 100.0 um if the radius is >0.5.\nfor cell in net:\n for branch in cell:\n for comp in branch:\n if comp.nodes.iloc[0][\"radius\"] > 0.5:\n comp.set(\"length\", 100.0)\n\n# Show the first five compartments:\nnet.nodes[[\"radius\", \"length\"]][:5]\n
radius length 0 0.763057 100.0 1 0.334882 10.0 2 0.805696 100.0 3 0.717921 100.0 4 0.079569 10.0 Finally, you can also use View
s in a context manager:
with net.cell(0).branch(0) as branch0:\n branch0.set(\"radius\", 2.0)\n branch0.set(\"length\", 2.5)\n\n# Show the first five compartments.\nnet.nodes[[\"radius\", \"length\"]][:5]\n
radius length 0 2.000000 2.5 1 2.000000 2.5 2 2.000000 2.5 3 2.000000 2.5 4 0.079569 10.0"},{"location":"tutorial/00_jaxley_api/#channels","title":"Channels","text":"The Module
s that we have created above will not do anything interesting, since by default Jaxley initializes them without any mechanisms in the membrane. To change this, we have to insert channels into the membrane. For this purpose Jaxley
implements Channel
s that can be inserted into any compartment using the insert
method of a Module
or a View
:
# insert a Leak channel into all compartments in the Module.\nnet.insert(Leak())\nnet.nodes.head() # Channel parameters are now also added to `nodes`.\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param x y z Leak Leak_gLeak Leak_eLeak 0 0 0 0 2.5 2.000000 5000.0 1.0 -70.0 0 0 0 0 5.000000 30.000000 0.0 True 0.0001 -70.0 1 0 0 1 2.5 2.000000 5000.0 1.0 -70.0 0 0 1 0 15.000000 30.000000 0.0 True 0.0001 -70.0 2 0 0 2 2.5 2.000000 5000.0 1.0 -70.0 0 0 2 0 25.000000 30.000000 0.0 True 0.0001 -70.0 3 0 0 3 2.5 2.000000 5000.0 1.0 -70.0 0 0 3 0 35.000000 30.000000 0.0 True 0.0001 -70.0 4 0 1 0 10.0 0.079569 5000.0 1.0 -70.0 0 1 4 0 44.850713 28.787322 0.0 True 0.0001 -70.0 This is also were View
s come in handy, as it allows to easily target the insertion of channels to specific compartments.
# inserting several channels into parts of the network\nwith net.cell(0) as cell0:\n cell0.insert(Na())\n cell0.insert(K())\n\n# # The above is equivalent to:\n# net.cell(0).insert(Na())\n# net.cell(0).insert(K())\n\n# K and Na channels were only insert into cell 0\nnet.cell(\"all\").branch(0).comp(0).nodes[[\"global_cell_index\", \"Na\", \"K\", \"Leak\"]]\n
global_cell_index Na K Leak 0 0 True True True 12 1 False False True"},{"location":"tutorial/00_jaxley_api/#synapses","title":"Synapses","text":"To connect different cells together, Jaxley implements a connect
method, that can be used to couple 2 compartments together using a Synapse
. Synapses in Jaxley work only on the compartment level, that means to be able to connect two cells, you need to specify the exact compartments on a given cell to make the connections between. Below is an example of this:
# connecting two cells using a Synapse\npre_comp = cell0.branch(1).comp(0)\npost_comp = net.cell(1).branch(0).comp(0)\n\nconnect(pre_comp, post_comp, IonotropicSynapse())\n\nnet.edges\n
global_edge_index global_pre_comp_index global_post_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 4 12 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0 As you can see above, now the edges
dataframe is also updated with the information of the newly added synapse.
Congrats! You should now have an intuitive understand of how to use Jaxley\u2019s API to construct, navigate and manipulate neuron models.
"},{"location":"tutorial/01_morph_neurons/","title":"Basics of Jaxley","text":"In this tutorial, you will learn how to:
Here is a code snippet which you will learn to understand in this tutorial:
import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nimport matplotlib.pyplot as plt\n\n\n# Build the cell.\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1])\n\n# Insert channels.\ncell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n\n# Change parameters.\ncell.set(\"axial_resistivity\", 200.0)\n\n# Visualize the morphology.\ncell.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\ncell.vis(ax=ax)\n\n# Stimulate.\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=0.025, t_max=10.0)\ncell.branch(0).loc(0.0).stimulate(current)\n\n# Record.\ncell.branch(0).loc(0.0).record(\"v\")\n\n# Simulate and plot.\nv = jx.integrate(cell, delta_t=0.025)\nplt.plot(v.T)\n
First, we import the relevant libraries:
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n
We will now build our first cell in Jaxley
. You have two options to do this: you can either build a cell bottom-up by defining the morphology yourselve, or you can load cells from SWC files.
To define a cell from scratch you first have to define a single compartment and then assemble those compartments into a branch:
comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\n
Next, we can assemble branches into a cell. To do so, we have to define for each branch what its parent branch is. A -1
entry means that this branch does not have a parent.
parents = jnp.asarray([-1, 0, 0, 1, 1])\ncell = jx.Cell(branch, parents=parents)\n
To learn more about Compartment
s, Branch
es, and Cell
s, see this tutorial.
Alternatively, you could also load cells from SWC with
cell = jx.read_swc(fname, ncomp=4)
Details on handling SWC files can be found in this tutorial.
"},{"location":"tutorial/01_morph_neurons/#visualize-the-cells","title":"Visualize the cells","text":"Cells can be visualized as follows:
cell.compute_xyz() # Only needed for visualization.\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, col=\"k\")\n
"},{"location":"tutorial/01_morph_neurons/#insert-mechanisms","title":"Insert mechanisms","text":"Currently, the cell does not contain any kind of ion channel (not even a leak
). We can fix this by inserting a leak channel into the entire cell, and by inserting sodium and potassium into the zero-eth branch.
cell.insert(Leak())\ncell.branch(0).insert(Na())\ncell.branch(0).insert(K())\n
Once the cell is created, we can inspect its .nodes
attribute which lists all properties of the cell:
cell.nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index ... Na Na_gNa eNa vt Na_m Na_h K K_gK eK K_n 0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 0 0 ... True 0.05 50.0 -60.0 0.2 0.2 True 0.005 -90.0 0.2 1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 0 0 ... True 0.05 50.0 -60.0 0.2 0.2 True 0.005 -90.0 0.2 2 0 1 0 10.0 1.0 5000.0 1.0 -70.0 0 1 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 3 0 1 1 10.0 1.0 5000.0 1.0 -70.0 0 1 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 4 0 2 0 10.0 1.0 5000.0 1.0 -70.0 0 2 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 5 0 2 1 10.0 1.0 5000.0 1.0 -70.0 0 2 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 6 0 3 0 10.0 1.0 5000.0 1.0 -70.0 0 3 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 7 0 3 1 10.0 1.0 5000.0 1.0 -70.0 0 3 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 8 0 4 0 10.0 1.0 5000.0 1.0 -70.0 0 4 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 9 0 4 1 10.0 1.0 5000.0 1.0 -70.0 0 4 ... False NaN NaN NaN NaN NaN False NaN NaN NaN 10 rows \u00d7 25 columns
Note that Jaxley
uses the same units as the NEURON
simulator, which are listed here.
You can also inspect just parts of the cell
, for example its 1st branch:
cell.branch(1).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Leak Leak_gLeak ... Na_m Na_h K K_gK eK K_n global_cell_index global_branch_index global_comp_index controlled_by_param 2 0 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 2 1 3 0 0 1 10.0 1.0 5000.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 3 1 2 rows \u00d7 25 columns
The easiest way to know which branch is the 1st branch (or, e.g., the zero-eth compartment of the 1st branch) is to plot it in a different color:
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax, col=\"k\")\n_ = cell.branch(1).vis(ax=ax, col=\"r\")\n_ = cell.branch(1).comp(1).vis(ax=ax, col=\"b\")\n
More background and features on indexing as cell.branch(0)
is in this tutorial.
You can change properties of the cell with the .set()
method:
cell.branch(1).set(\"axial_resistivity\", 200.0)\n
And we can again inspect the .nodes
to make sure that the axial resistivity indeed changed:
cell.branch(1).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Leak Leak_gLeak ... Na_m Na_h K K_gK eK K_n global_cell_index global_branch_index global_comp_index controlled_by_param 2 0 0 0 10.0 1.0 200.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 2 1 3 0 0 1 10.0 1.0 200.0 1.0 -70.0 True 0.0001 ... NaN NaN False NaN NaN NaN 0 1 3 1 2 rows \u00d7 25 columns
In a similar way, you can modify channel properties or initial states (units are again here):
cell.branch(0).set(\"K_gK\", 0.01) # modify potassium conductance.\ncell.set(\"v\", -65.0) # modify initial voltage.\n
"},{"location":"tutorial/01_morph_neurons/#stimulate-the-cell","title":"Stimulate the cell","text":"We next stimulate one of the compartments with a step current. For this, we first define the step current (units are again here):
dt = 0.025\nt_max = 10.0\ntime_vec = np.arange(0, t_max+dt, dt)\ncurrent = jx.step_current(i_delay=1.0, i_dur=2.0, i_amp=0.08, delta_t=dt, t_max=t_max)\n\nfig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = plt.plot(time_vec, current)\n
We then stimulate one of the compartments of the cell with this step current:
cell.delete_stimuli()\ncell.branch(0).loc(0.0).stimulate(current)\n
Added 1 external_states. See `.externals` for details.\n
"},{"location":"tutorial/01_morph_neurons/#define-recordings","title":"Define recordings","text":"Next, you have to define where to record the voltage. In this case, we will record the voltage at two locations:
cell.delete_recordings()\ncell.branch(0).loc(0.0).record(\"v\")\ncell.branch(3).loc(1.0).record(\"v\")\n
Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
We can again visualize these locations to understand where we inserted recordings:
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = cell.vis(ax=ax)\n_ = cell.branch(0).loc(0.0).vis(ax=ax, col=\"b\")\n_ = cell.branch(3).loc(1.0).vis(ax=ax, col=\"g\")\n
"},{"location":"tutorial/01_morph_neurons/#simulate-the-cell-response","title":"Simulate the cell response","text":"Having set up the cell, inserted stimuli and recordings, we are now ready to run a simulation with jx.integrate
:
voltages = jx.integrate(cell, delta_t=dt)\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (2, 402)\n
The jx.integrate
function returns an array of shape (num_recordings, num_timepoints)
. In our case, we inserted 2
recordings and we simulated for 10ms at a 0.025 time step, which leads to 402 time steps.
We can now visualize the voltage response:
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(voltages[0], c=\"b\")\n_ = ax.plot(voltages[1], c=\"orange\")\n
At the location of the first recording (in blue) the cell spiked, whereas at the second recording, it did not. This makes sense because we only inserted sodium and potassium channels into the first branch, but not in the entire cell.
Congrats! You have just run your first morphologically detailed neuron simulation in Jaxley
. We suggest to continue by learning how to build networks. If you are only interested in single cell simulations, you can directly jump to learning how to speed up simulations. If you want to simulate detailed morphologies from SWC files, checkout our tutorial on working with detailed morphologies.
In this tutorial, you will learn how to:
.edges
attribute to inspect and change synaptic parametersHere is a code snippet which you will learn to understand in this tutorial:
import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import connect\n\n\n# Define a network. `cell` is defined as in previous tutorial.\nnet = jx.Network([cell for _ in range(11)])\n\n# Define synapses.\nfully_connect(\n net.cell(range(10)),\n net.cell(10),\n IonotropicSynapse(),\n)\n\n# Change synaptic parameters.\nnet.select(edges=[0, 1]).set(\"IonotropicSynapse_gS\", 0.1) # nS\n\n# Visualize the network.\nnet.compute_xyz()\nfig, ax = plt.subplots(1, 1, figsize=(4, 4))\nnet.vis(ax=ax, detail=\"full\", layers=[10, 1]) # or `detail=\"point\"`.\n
In the previous tutorial, you learned how to build single cells with morphological detail, how to insert stimuli and recordings, and how to run a first simulation. In this tutorial, we will define networks of multiple cells and connect them with synapses. Let\u2019s get started:
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect, connect\n
"},{"location":"tutorial/02_small_network/#define-the-network","title":"Define the network","text":"First, we define a cell as you saw in the previous tutorial.
comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n
We can assemble multiple cells into a network by using jx.Network
, which takes a list of jx.Cell
s. Here, we assemble 11 cells into a network:
num_cells = 11\nnet = jx.Network([cell for _ in range(num_cells)])\n
At this point, we can already visualize this network:
net.compute_xyz()\nnet.rotate(180)\nfig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n
Note: you can use move_to
to have more control over the location of cells, e.g.: network.cell(i).move_to(x=0, y=200)
.
As you can see, the neurons are not connected yet. Let\u2019s fix this by connecting neurons with synapses. We will build a network consisting of two layers: 10 neurons in the input layer and 1 neuron in the output layer.
We can use Jaxley
\u2019s fully_connect
method to connect these layers:
pre = net.cell(range(10))\npost = net.cell(10)\nfully_connect(pre, post, IonotropicSynapse())\n
Let\u2019s visualize this again:
fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n
As you can see, the full_connect
method inserted one synapse (in blue) from every neuron in the first layer to the output neuron. The fully_connect
method builds this synapse from the zero-eth compartment and zero-eth branch of the presynaptic neuron onto a random branch of the postsynaptic neuron. If you want more control over the pre- and post-synaptic branches, you can use the connect
method:
pre = net.cell(0).branch(5).loc(1.0)\npost = net.cell(10).branch(0).loc(0.0)\nconnect(pre, post, IonotropicSynapse())\n
fig, ax = plt.subplots(1, 1, figsize=(3, 6))\n_ = net.vis(ax=ax, detail=\"full\", layers=[10, 1], layer_kwargs={\"within_layer_offset\": 150, \"between_layer_offset\": 200})\n
"},{"location":"tutorial/02_small_network/#inspecting-and-changing-synaptic-parameters","title":"Inspecting and changing synaptic parameters","text":"You can inspect synaptic parameters via the .edges
attribute:
net.edges\n
global_edge_index global_pre_comp_index global_post_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 0 286 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 1 1 28 298 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 2 2 56 286 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 3 3 84 295 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 4 4 112 302 IonotropicSynapse 0 0.125 0.625 0.0001 0.0 0.025 0.2 0 5 5 140 288 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0 6 6 168 287 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 7 7 196 305 IonotropicSynapse 0 0.125 0.375 0.0001 0.0 0.025 0.2 0 8 8 224 299 IonotropicSynapse 0 0.125 0.875 0.0001 0.0 0.025 0.2 0 9 9 252 284 IonotropicSynapse 0 0.125 0.125 0.0001 0.0 0.025 0.2 0 10 10 23 280 IonotropicSynapse 0 0.875 0.125 0.0001 0.0 0.025 0.2 0 To modify a parameter of all synapses you can again use .set()
:
net.set(\"IonotropicSynapse_gS\", 0.0003) # nS\n
To modify individual syanptic parameters, use the .select()
method. Below, we change the values of the first two synapses:
net.select(edges=[0, 1]).set(\"IonotropicSynapse_gS\", 0.0004) # nS\n
For more details on how to flexibly set synaptic parameters (e.g., by cell type, or by pre-synaptic cell index,\u2026), see this tutorial.
"},{"location":"tutorial/02_small_network/#stimulating-recording-and-simulating-the-network","title":"Stimulating, recording, and simulating the network","text":"We will now set up a simulation of the network. This works exactly as it does for single neurons:
# Stimulus.\ni_delay = 3.0 # ms\ni_amp = 0.05 # nA\ni_dur = 2.0 # ms\n\n# Duration and step size.\ndt = 0.025 # ms\nt_max = 50.0 # ms\n
time_vec = jnp.arange(0.0, t_max + dt, dt)\n
As a simple example, we insert sodium, potassium, and leak into every compartment of every cell of the network.
net.insert(Na())\nnet.insert(K())\nnet.insert(Leak())\n
We stimulate every neuron in the input layer and record the voltage from the output neuron:
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)\nnet.delete_stimuli()\nfor stim_ind in range(10):\n net.cell(stim_ind).branch(0).loc(0.0).stimulate(current)\n\nnet.delete_recordings()\nnet.cell(10).branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
Finally, we can again run the network simulation and plot the result:
s = jx.integrate(net, delta_t=dt)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T)\n
That\u2019s it! You now know how to simulate networks of morphologically detailed neurons. We recommend that you now have a look at how you can speed up your simulation. To learn more about handling synaptic parameters, we recommend to check out this tutorial.
"},{"location":"tutorial/04_jit_and_vmap/","title":"Speeding up simulations","text":"In this tutorial, you will learn how to:
Jaxley
jit
to compile your simulations and make them faster vmap
to parallelize simulations on GPUs Here is a code snippet which you will learn to understand in this tutorial:
from jax import jit, vmap\n\n\ncell = ... # See tutorial on Basics of Jaxley.\n\ndef simulate(params):\n param_state = None\n param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n param_state = cell.data_set(\"K_gK\", params[1], param_state)\n return jx.integrate(cell, param_state=param_state, delta_t=0.025)\n\n# Define 100 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(100, 2))\n\n# Fast for-loops with jit compilation.\njitted_simulate = jit(simulate)\nvoltages = [jitted_simulate(params) for params in all_params]\n\n# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate, in_axes=(0,))\nvoltages = vmapped_simulate(all_params)\n
In the previous tutorials, you learned how to build single cells or networks and how to change their parameters. In this tutorial, you will learn how to speed up such simulations by many orders of magnitude. This can be achieved in to ways:
Let\u2019s get started!
"},{"location":"tutorial/04_jit_and_vmap/#using-gpu-or-cpu","title":"Using GPU or CPU","text":"In Jaxley
you can set whether you want to use gpu
or cpu
with the following lines at the beginning of your script:
from jax import config\nconfig.update(\"jax_platform_name\", \"cpu\")\n
JAX
(and Jaxley
) also allow to choose between float32
and float64
. Especially on GPUs, float32
will be faster, but we have experienced stability issues when simulating morphologically detailed neurons with float32
.
config.update(\"jax_enable_x64\", True) # Set to false to use `float32`.\n
Next, we will import relevant libraries:
import matplotlib.pyplot as plt\nimport numpy as np\nimport jax.numpy as jnp\nfrom jax import jit, vmap\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\n
"},{"location":"tutorial/04_jit_and_vmap/#building-the-cell-or-network","title":"Building the cell or network","text":"We first build a cell (or network) in the same way as we showed in the previous tutorials:
dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n\ncell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n\ncell.delete_stimuli()\ncurrent = jx.step_current(i_delay=1.0, i_dur=1.0, i_amp=0.1, delta_t=dt, t_max=t_max)\ncell.branch(0).loc(0.0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
"},{"location":"tutorial/04_jit_and_vmap/#parameter-sweeps","title":"Parameter sweeps","text":"Assume you want to run the same cell with many different values for the sodium and potassium conductance, for example for genetic algorithms or for parameter sweeps. To do this efficiently in Jaxley
, you have to use the data_set()
method (in combination with jit
and vmap
, as shown later):
def simulate(params):\n param_state = None\n param_state = cell.data_set(\"Na_gNa\", params[0], param_state)\n param_state = cell.data_set(\"K_gK\", params[1], param_state)\n return jx.integrate(cell, param_state=param_state, delta_t=dt)\n
The .data_set()
method takes three arguments:
1) the name of the parameter you want to set. Jaxley
allows to set the following parameters: \u201cradius\u201d, \u201clength\u201d, \u201caxial_resistivity\u201d, as well as all parameters of channels and synapses. 2) the value of the parameter. 3) a param_state
which is initialized as None
and is modified by .data_set()
. This has to be passed to jx.integrate()
.
Having done this, the simplest (but least efficient) way to perform the parameter sweep is to run a for-loop over many parameter sets:
# Define 5 sets of sodium and potassium conductances.\nall_params = jnp.asarray(np.random.rand(5, 2))\n\nvoltages = jnp.asarray([simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (5, 1, 402)\n
The resulting voltages have shape (num_simulations, num_recordings, num_timesteps)
.
In addition to running sweeps across multiple parameters, you can also run sweeeps across multiple stimuli (e.g. step current stimuli of different amplitudes. You can achieve this with the data_stimulate()
method:
def simulate(i_amp):\n current = jx.step_current(1.0, 1.0, i_amp, 0.025, 10.0)\n\n data_stimuli = None\n data_stimuli = cell.branch(0).comp(0).data_stimulate(current, data_stimuli)\n return jx.integrate(cell, data_stimuli=data_stimuli)\n
"},{"location":"tutorial/04_jit_and_vmap/#speeding-up-for-loops-via-jit-compilation","title":"Speeding up for loops via jit
compilation","text":"We can speed up such parameter sweeps (or stimulus sweeps) with jit
compilation. jit
compilation will compile the simulation when it is run for the first time, such that every other simulation will be must faster. This can be achieved by defining a new function which uses JAX
\u2019s jit()
:
jitted_simulate = jit(simulate)\n
# First run, will be slow.\nvoltages = jitted_simulate(all_params[0])\n
# More runs, will be much faster.\nvoltages = jnp.asarray([jitted_simulate(params) for params in all_params])\nprint(\"voltages.shape\", voltages.shape)\n
voltages.shape (5, 1, 402)\n
jit
compilation can be up to 10k times faster, especially for small simulations with few compartments. For very large models, the gain obtained with jit
will be much smaller (jit
may even provide no speed up at all).
vmap
","text":"Another way to speed up parameter sweeps is with GPU parallelization. Parallelization in Jaxley
can be achieved by using vmap
of JAX
. To do this, we first create a new function that handles multiple parameter sets directly:
# Using vmap for parallelization.\nvmapped_simulate = vmap(jitted_simulate)\n
We can then run this method on all parameter sets (all_params.shape == (100, 2)
), and Jaxley
will automatically parallelize across them. Of course, you will only get a speed-up if you have a GPU available and you specified gpu
as device in the beginning of this tutorial.
voltages = vmapped_simulate(all_params)\n
GPU parallelization with vmap
can give a large speed-up, which can easily be 2-3 orders of magnitude.
jit
and vmap
","text":"Finally, you can also combine using jit
and vmap
. For example, you can run multiple batches of many parallel simulations. Each batch can be parallelized with vmap
and simulating each batch can be compiled with jit
:
jitted_vmapped_simulate = jit(vmap(simulate))\n
for batch in range(10):\n all_params = jnp.asarray(np.random.rand(5, 2))\n voltages_batch = jitted_vmapped_simulate(all_params)\n
That\u2019s all you have to know about jit
and vmap
! If you have worked through this and the previous tutorials, you should be ready to set up your first network simulations.
If you want to learn more, we recommend you to read the tutorial on building channel and synapse models.
Alternatively, you can also directly jump ahead to the tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.
Finally, if you want to learn more about JAX, check out their tutorial on jit or their tutorial on vmap.
"},{"location":"tutorial/05_channel_and_synapse_models/","title":"Building ion channel models","text":"In this tutorial, you will learn how to:
Jaxley
This tutorial assumes that you have already learned how to build basic simulations.
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\n
First, we define a cell as you saw in the previous tutorial:
comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=4)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1, 1, 2, 2])\n
You have also already learned how to insert preconfigured channels into Jaxley
models:
cell.insert(Na())\ncell.insert(K())\ncell.insert(Leak())\n
In this tutorial, we will show you how to build your own channel and synapse models.
"},{"location":"tutorial/05_channel_and_synapse_models/#your-own-channel","title":"Your own channel","text":"Below is how you can define your own channel. We will go into detail about individual parts of the code in the next couple of cells.
import jax.numpy as jnp\nfrom jaxley.channels import Channel\nfrom jaxley.solver_gate import solve_gate_exponential\n\n\ndef exp_update_alpha(x, y):\n return x / (jnp.exp(x / y) - 1.0)\n\nclass Potassium(Channel):\n \"\"\"Potassium channel.\"\"\"\n\n def __init__(self, name = None):\n self.current_is_in_mA_per_cm2 = True\n super().__init__(name)\n self.channel_params = {\"gK_new\": 1e-4}\n self.channel_states = {\"n_new\": 0.0}\n self.current_name = \"i_K\"\n\n def update_states(self, states, dt, v, params):\n \"\"\"Update state.\"\"\"\n ns = states[\"n_new\"]\n alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n beta = 0.125 * jnp.exp(-(v + 65) / 80)\n new_n = solve_gate_exponential(ns, dt, alpha, beta)\n return {\"n_new\": new_n}\n\n def compute_current(self, states, v, params):\n \"\"\"Return current.\"\"\"\n ns = states[\"n_new\"]\n kd_conds = params[\"gK_new\"] * ns**4 # S/cm^2\n\n e_kd = -77.0 \n return kd_conds * (v - e_kd)\n\n def init_state(self, states, v, params, delta_t):\n alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\n beta = 0.125 * jnp.exp(-(v + 65) / 80)\n return {\"n_new\": alpha / (alpha + beta)}\n
Let\u2019s look at each part of this in detail.
The below is simply a helper function for the solver of the gate variables:
def exp_update_alpha(x, y):\n return x / (jnp.exp(x / y) - 1.0)\n
Next, we define our channel as a class. It should inherit from the Channel
class and define channel_params
, channel_states
, and current_name
. You also need to set self.current_is_in_mA_per_cm2=True
as the first line on your __init__()
method. This is to acknowledge that your current is returned in mA/cm2
(not in uA/cm2
, as would have been required in Jaxley versions 0.4.0 or older).
class Potassium(Channel):\n \"\"\"Potassium channel.\"\"\"\n\n def __init__(self, name=None):\n self.current_is_in_mA_per_cm2 = True\n super().__init__(name)\n self.channel_params = {\"gK_new\": 1e-4}\n self.channel_states = {\"n_new\": 0.0}\n self.current_name = \"i_K\"\n
Next, we have the update_states()
method, which updates the gating variables:
def update_states(self, states, dt, v, params):\n
Every channel you define must have an update_states()
method which takes exactly these five arguments (self, states, dt, v, params). The inputs states
to the update_states
method is a dictionary which contains all states that are updated (including states of other channels). v
is a jnp.ndarray
which contains the voltage of a single compartment (shape ()
). Let\u2019s get the state of the potassium channel which we are building here:
ns = states[\"n_new\"]\n
Next, we update the state of the channel. In this example, we do this with exponential Euler, but you can implement any solver yourself:
alpha = 0.01 * exp_update_alpha(-(v + 55), 10)\nbeta = 0.125 * jnp.exp(-(v + 65) / 80)\nnew_n = solve_gate_exponential(ns, dt, alpha, beta)\nreturn {\"n_new\": new_n}\n
A channel also needs a compute_current()
method which returns the current throught the channel:
def compute_current(self, states, v, params):\n ns = states[\"n_new\"]\n\n # Multiply with 1000 to convert Siemens to milli Siemens.\n kd_conds = params[\"gK_new\"] * ns**4 # S/cm^2\n\n e_kd = -77.0 \n current = kd_conds * (v - e_kd)\n return current\n
Finally, the init_state()
method can be implemented optionally. It can be used to automatically compute the initial state based on the voltage when cell.init_states()
is run.
Alright, done! We can now insert this channel into any jx.Module
such as our cell:
cell.insert(Potassium())\n
cell.delete_stimuli()\ncurrent = jx.step_current(1.0, 1.0, 0.1, 0.025, 10.0)\ncell.branch(0).comp(0).stimulate(current)\n\ncell.delete_recordings()\ncell.branch(0).comp(0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\n
s = jx.integrate(cell)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n
"},{"location":"tutorial/05_channel_and_synapse_models/#your-own-synapse","title":"Your own synapse","text":"The parts below assume that you have already learned how to build network simulations in Jaxley
.
Note that again, a synapse needs to have the two functions update_states
and compute_current
with all input arguments shown below.
The below is an example of how to define your own synapse model in Jaxley
:
import jax.numpy as jnp\nfrom jaxley.synapses.synapse import Synapse\n\n\nclass TestSynapse(Synapse):\n \"\"\"\n Compute syanptic current and update syanpse state.\n \"\"\"\n def __init__(self, name = None):\n super().__init__(name)\n self.synapse_params = {\"gChol\": 0.001, \"eChol\": 0.0}\n self.synapse_states = {\"s_chol\": 0.1}\n\n def update_states(self, states, delta_t, pre_voltage, post_voltage, params):\n \"\"\"Return updated synapse state and current.\"\"\"\n s_inf = 1.0 / (1.0 + jnp.exp((-35.0 - pre_voltage) / 10.0))\n exp_term = jnp.exp(-delta_t)\n new_s = states[\"s_chol\"] * exp_term + s_inf * (1.0 - exp_term)\n return {\"s_chol\": new_s}\n\n def compute_current(self, states, pre_voltage, post_voltage, params):\n g_syn = params[\"gChol\"] * states[\"s_chol\"]\n return g_syn * (post_voltage - params[\"eChol\"])\n
As you can see above, synapses follow closely how channels are defined. The main difference is that the compute_current
method takes two voltages: the pre-synaptic voltage (a jnp.ndarray
of shape ()
) and the post-synaptic voltage (a jnp.ndarray
of shape ()
).
net = jx.Network([cell for _ in range(3)])\n
from jaxley.connect import connect\n\npre = net.cell(0).branch(0).loc(0.0)\npost = net.cell(1).branch(0).loc(0.0)\nconnect(pre, post, TestSynapse())\n
net.cell(0).branch(0).loc(0.0).stimulate(jx.step_current(1.0, 2.0, 0.1, 0.025, 10.0))\nfor i in range(3):\n net.cell(i).branch(0).loc(0.0).record()\n
Added 1 external_states. See `.externals` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
s = jx.integrate(net)\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(s.T[:-1])\n_ = ax.set_ylim([-80, 50])\n_ = ax.set_xlabel(\"Time (ms)\")\n_ = ax.set_ylabel(\"Voltage (mV)\")\n
That\u2019s it! You are now ready to build your own custom simulations and equip them with channel and synapse models!
This tutorial does not have an immediate follow-up tutorial. If you have not done so already, you can check out our tutorial on training biophysical networks which will teach you how you can optimize parameters of biophysical models with gradient descent.
"},{"location":"tutorial/06_groups/","title":"Defining groups","text":"In this tutorial, you will learn how to:
Jaxley
Here is a code snippet which you will learn to understand in this tutorial:
from jax import jit, vmap\n\n\nnet = ... # See tutorial on Basics of Jaxley.\n\nnet.cell(0).add_to_group(\"fast_spiking\")\nnet.cell(1).add_to_group(\"slow_spiking\")\n\ndef simulate(params):\n param_state = None\n param_state = net.fast_spiking.data_set(\"HH_gNa\", params[0], param_state)\n param_state = net.slow_spiking.data_set(\"HH_gNa\", params[1], param_state)\n return jx.integrate(net, param_state=param_state)\n\n# Define sodium for fast and slow spiking neurons.\nparams = jnp.asarray([1.0, 0.1])\n\n# Run simulation.\nvoltages = simulate(params)\n
In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. We describe how you can do this in this tutorial.
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport time\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.synapses import IonotropicSynapse\nfrom jaxley.connect import fully_connect\n
First, we define a network as you saw in the previous tutorial:
comp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0, 1])\nnetwork = jx.Network([cell for _ in range(3)])\n\npre = network.cell([0, 1])\npost = network.cell([2])\nfully_connect(pre, post, IonotropicSynapse())\n\nnetwork.insert(Na())\nnetwork.insert(K())\nnetwork.insert(Leak())\n
"},{"location":"tutorial/06_groups/#group-apical-dendrites","title":"Group: apical dendrites","text":"Assume that, in each of the five neurons in this network, the second and forth branch are apical dendrites. We can define this as:
for cell_ind in range(3):\n network.cell(cell_ind).branch(1).add_to_group(\"apical\")\n network.cell(cell_ind).branch(3).add_to_group(\"apical\")\n
After this, we can access network.apical
as we previously accesses anything else:
network.apical.set(\"radius\", 0.3)\n
network.apical.view\n
View with 3 different channels. Use `.nodes` for details.\n
"},{"location":"tutorial/06_groups/#group-fast-spiking","title":"Group: fast spiking","text":"Similarly, you could define a group of fast-spiking cells. Assume that the first and second cell are fast-spiking:
network.cell(0).add_to_group(\"fast_spiking\")\nnetwork.cell(1).add_to_group(\"fast_spiking\")\n
network.fast_spiking.set(\"Na_gNa\", 0.4)\n
network.fast_spiking.view\n
View with 3 different channels. Use `.nodes` for details.\n
"},{"location":"tutorial/06_groups/#groups-from-swc-files","title":"Groups from SWC files","text":"If you are reading .swc
morphologigies, you can automatically assign groups with
jx.read_swc(file_name, nseg=n, assign_groups=True).\n
After that, you can directly use cell.soma
, cell.apical
, cell.basal
, or cell.axon
."},{"location":"tutorial/06_groups/#how-groups-are-interpreted-by-make_trainable","title":"How groups are interpreted by .make_trainable()
","text":"If you make a parameter of a group
trainable, then it will be treated as a single shared parameter for a given property:
network.fast_spiking.make_trainable(\"Na_gNa\")\n
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n
As such, get_parameters()
returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:
network.get_parameters()\n
[{'Na_gNa': Array([0.4], dtype=float64)}]\n
If, instead, you would want a separate parameter for every fast-spiking cell, you should not use the group, but instead do the following (remember that fast-spiking neurons had indices [0,1]):
network.cell([0,1]).make_trainable(\"axial_resistivity\")\n
Number of newly added trainable parameters: 2. Total number of trainable parameters: 3\n
network.get_parameters()\n
[{'Na_gNa': Array([0.4], dtype=float64)},\n {'axial_resistivity': Array([5000., 5000.], dtype=float64)}]\n
This generated two parameters for the axial resistivitiy, each corresponding to one cell.
"},{"location":"tutorial/06_groups/#summary","title":"Summary","text":"Groups allow you to organize your simulation in a more intuitive way, and they allow to perform parameter sharing with make_trainable()
.
In this tutorial, you will learn how to train biophysical models in Jaxley
. This includes the following:
Here is a code snippet which you will learn to understand in this tutorial:
from jax import jit, vmap, value_and_grad\nimport jaxley as jx\nimport jaxley.optimize.transforms as jt\n\nnet = ... # See tutorial on the basics of `Jaxley`.\n\n# Define which parameters to train.\nnet.cell(\"all\").make_trainable(\"HH_gNa\")\nnet.IonotropicSynapse.make_trainable(\"IonotropicSynapse_gS\")\nparameters = net.get_parameters()\n\n# Define parameter transform and apply it to the parameters.\ntransform = jx.ParamTransform([\n {\"IonotropicSynapse_gS\": jt.SigmoidTransform(0.0, 1.0)},\n {\"HH_gNa\":jt.SigmoidTransform(0.0, 1, 0)}\n])\n\nopt_params = transform.inverse(parameters)\n\n# Define simulation and batch it across stimuli.\ndef simulate(params, datapoint):\n current = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amps=datapoint, dt=0.025, t_max=5.0)\n data_stimuli = net.cell(0).branch(0).comp(0).data_stimulate(current, None)\n return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_inds=[20, 20], delta_t=0.025)\n\nbatch_simulate = vmap(simulate, in_axes=(None, 0))\n\n# Define loss function and its gradient.\ndef loss_fn(opt_params, datapoints, label):\n params = transform.forward(opt_params)\n voltages = batch_simulate(params, datapoints)\n return jnp.abs(jnp.mean(voltages) - label)\n\ngrad_fn = jit(value_and_grad(loss_fn, argnums=0))\n\n# Define data and dataloader.\ndata = jnp.asarray(np.random.randn(100, 3))\ndataloader = Dataset.from_tensor_slices((inputs, labels))\ndataloader = dataloader.shuffle(dataloader.cardinality()).batch(4)\n\n# Define the optimizer.\noptimizer = optax.Adam(lr=0.01)\nopt_state = optimizer.init_state(opt_params)\n\nfor epoch in range(10):\n for batch in dataloader:\n stimuli = batch[0].numpy()\n labels = batch[1].numpy()\n loss, gradient = grad_fn(opt_params, stimuli, labels)\n\n # Optimizer step.\n updates, opt_state = optimizer.update(gradient, opt_state)\n opt_params = optax.apply_updates(opt_params, updates)\n
from jax import config\nconfig.update(\"jax_enable_x64\", True)\nconfig.update(\"jax_platform_name\", \"cpu\")\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax import jit, vmap, value_and_grad\n\nimport jaxley as jx\nfrom jaxley.channels import Leak\nfrom jaxley.synapses import TanhRateSynapse\nfrom jaxley.connect import fully_connect\n
First, we define a network as you saw in the previous tutorial:
_ = np.random.seed(0) # For synaptic locations.\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0, 0])\nnet = jx.Network([cell for _ in range(3)])\n\npre = net.cell([0, 1])\npost = net.cell([2])\nfully_connect(pre, post, TanhRateSynapse())\n\n# Change some default values of the tanh synapse.\nnet.TanhRateSynapse.set(\"TanhRateSynapse_x_offset\", -60.0)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_gS\", 1e-3)\nnet.TanhRateSynapse.set(\"TanhRateSynapse_slope\", 0.1)\n\nnet.insert(Leak())\n
This network consists of three neurons arranged in two layers:
net.compute_xyz()\nnet.rotate(180)\nfig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = net.vis(ax=ax, detail=\"full\", layers=[2, 1], layer_kwargs={\"within_layer_offset\": 100.0, \"between_layer_offset\": 100.0}) \n
We consider the last neuron as the output neuron and record the voltage from there:
net.delete_recordings()\nnet.cell(0).branch(0).loc(0.0).record()\nnet.cell(1).branch(0).loc(0.0).record()\nnet.cell(2).branch(0).loc(0.0).record()\n
Added 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\nAdded 1 recordings. See `.recordings` for details.\n
"},{"location":"tutorial/07_gradient_descent/#defining-a-dataset","title":"Defining a dataset","text":"We will train this biophysical network on a classification task. The inputs will be values and the label is binary:
inputs = jnp.asarray(np.random.rand(100, 2))\nlabels = jnp.asarray((inputs[:, 0] + inputs[:, 1]) > 1.0)\n
fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(inputs[labels, 0], inputs[labels, 1])\n_ = ax.scatter(inputs[~labels, 0], inputs[~labels, 1])\n
labels = labels.astype(float)\n
"},{"location":"tutorial/07_gradient_descent/#defining-trainable-parameters","title":"Defining trainable parameters","text":"net.delete_trainables()\n
This follows the same API as .set()
seen in the previous tutorial. If you want to use a single parameter for all radius
es in the entire network, do:
net.make_trainable(\"radius\")\n
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1\n
We can also define parameters for individual compartments. To do this, use the \"all\"
key. The following defines a separate parameter the sodium conductance for every compartment in the entire network:
net.cell(\"all\").branch(\"all\").loc(\"all\").make_trainable(\"Leak_gLeak\")\n
Number of newly added trainable parameters: 18. Total number of trainable parameters: 19\n
"},{"location":"tutorial/07_gradient_descent/#making-synaptic-parameters-trainable","title":"Making synaptic parameters trainable","text":"Synaptic parameters can be made trainable in the exact same way. To use a single parameter for all syanptic conductances in the entire network, do
net.TanhRateSynapse.make_trainable(\"TanhRateSynapse_gS\")\n
Here, we use a different syanptic conductance for all syanpses. This can be done as follows:
net.TanhRateSynapse.edge(\"all\").make_trainable(\"TanhRateSynapse_gS\")\n
Number of newly added trainable parameters: 2. Total number of trainable parameters: 21\n
"},{"location":"tutorial/07_gradient_descent/#running-the-simulation","title":"Running the simulation","text":"Once all parameters are defined, you have to use .get_parameters()
to obtain all trainable parameters. This is also the time to check how many trainable parameters your network has:
params = net.get_parameters()\n
You can now run the simulation with the trainable parameters by passing them to the jx.integrate
function.
s = jx.integrate(net, params=params, t_max=10.0)\n
"},{"location":"tutorial/07_gradient_descent/#stimulating-the-network","title":"Stimulating the network","text":"The network above does not yet get any stimuli. We will use the 2D inputs from the dataset to stimulate the two input neurons. The amplitude of the step current corresponds to the input value. Below is the simulator that defines this:
def simulate(params, inputs):\n currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10, delta_t=0.025, t_max=10.0)\n\n data_stimuli = None\n data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n return jx.integrate(net, params=params, data_stimuli=data_stimuli, delta_t=0.025)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n
We can also inspect some traces:
traces = batched_simulate(params, inputs[:4])\n
fig, ax = plt.subplots(1, 1, figsize=(4, 2))\n_ = ax.plot(traces[:, 2, :].T)\n
"},{"location":"tutorial/07_gradient_descent/#defining-a-loss-function","title":"Defining a loss function","text":"Let us define a loss function to be optimized:
def loss(params, inputs, labels):\n traces = batched_simulate(params, inputs) # Shape `(batchsize, num_recordings, timepoints)`.\n prediction = jnp.mean(traces[:, 2], axis=1) # Use the average over time of the output neuron (2) as prediction.\n prediction = (prediction + 72.0) / 5 # Such that the prediction is roughly in [0, 1].\n losses = jnp.abs(prediction - labels) # Mean absolute error loss.\n return jnp.mean(losses) # Average across the batch.\n
And we can use JAX
\u2019s inbuilt functions to take the gradient through the entire ODE:
jitted_grad = jit(value_and_grad(loss, argnums=0))\n
value, gradient = jitted_grad(params, inputs[:4], labels[:4])\n
"},{"location":"tutorial/07_gradient_descent/#defining-parameter-transformations","title":"Defining parameter transformations","text":"Before training, however, we will enforce for all parameters to be within a prespecified range (such that, e.g., conductances can not become negative)
import jaxley.optimize.transforms as jt\n
# Define a function to create appropriate transforms for each parameter\ndef create_transform(name):\n if name == \"axial_resistivity\":\n # Must be positive; apply Softplus and scale to match initialization\n return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(5000, 0)])\n elif name == \"length\":\n # Apply Softplus and affine transform for the 'length' parameter\n return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(10, 0)])\n else:\n # Default to a Softplus transform for other parameters\n return jt.SoftplusTransform(0)\n\n# Apply the transforms to the parameters\ntransforms = [{k: create_transform(k) for k in param} for param in params]\ntf = jt.ParamTransform(transforms)\n
transform = jx.ParamTransform([{\"radius\": jt.SigmoidTransform(0.1, 5.0)},\n {\"Leak_gLeak\":jt.SigmoidTransform(1e-5, 1e-3)},\n {\"TanhRateSynapse_gS\" : jt.SigmoidTransform(1e-5, 1e-2)}])\n
With these modify the loss function acocrdingly:
def loss(opt_params, inputs, labels):\n transform.forward(opt_params)\n\n traces = batched_simulate(params, inputs) # Shape `(batchsize, num_recordings, timepoints)`.\n prediction = jnp.mean(traces[:, 2], axis=1) # Use the average over time of the output neuron (2) as prediction.\n prediction = (prediction + 72.0) # Such that the prediction is around 0.\n losses = jnp.abs(prediction - labels) # Mean absolute error loss.\n return jnp.mean(losses) # Average across the batch.\n
"},{"location":"tutorial/07_gradient_descent/#using-checkpointing","title":"Using checkpointing","text":"Checkpointing allows to vastly reduce the memory requirements of training biophysical models (see also JAX\u2019s full tutorial on checkpointing).
t_max = 5.0\ndt = 0.025\n\nlevels = 2\ntime_points = t_max // dt + 2\ncheckpoints = [int(np.ceil(time_points**(1/levels))) for _ in range(levels)]\n
To enable checkpointing, we have to modify the simulate
function appropriately and use
jx.integrate(..., checkpoint_inds=checkpoints)\n
as done below: def simulate(params, inputs):\n currents = jx.datapoint_to_step_currents(i_delay=1.0, i_dur=1.0, i_amp=inputs / 10.0, delta_t=dt, t_max=t_max)\n\n data_stimuli = None\n data_stimuli = net.cell(0).branch(2).loc(1.0).data_stimulate(currents[0], data_stimuli=data_stimuli)\n data_stimuli = net.cell(1).branch(2).loc(1.0).data_stimulate(currents[1], data_stimuli=data_stimuli)\n\n return jx.integrate(net, params=params, data_stimuli=data_stimuli, checkpoint_lengths=checkpoints)\n\nbatched_simulate = vmap(simulate, in_axes=(None, 0))\n\n\ndef predict(params, inputs):\n traces = simulate(params, inputs) # Shape `(batchsize, num_recordings, timepoints)`.\n prediction = jnp.mean(traces[2]) # Use the average over time of the output neuron (2) as prediction.\n return prediction + 72.0 # Such that the prediction is around 0.\n\nbatched_predict = vmap(predict, in_axes=(None, 0))\n\n\ndef loss(opt_params, inputs, labels):\n params = transform.forward(opt_params)\n\n predictions = batched_predict(params, inputs)\n losses = jnp.abs(predictions - labels) # Mean absolute error loss.\n return jnp.mean(losses) # Average across the batch.\n\njitted_grad = jit(value_and_grad(loss, argnums=0))\n
"},{"location":"tutorial/07_gradient_descent/#training","title":"Training","text":"We will use the ADAM optimizer from the optax library to optimize the free parameters (you have to install the package with pip install optax
first):
import optax\n
opt_params = transform.inverse(params)\noptimizer = optax.adam(learning_rate=0.01)\nopt_state = optimizer.init(opt_params)\n
"},{"location":"tutorial/07_gradient_descent/#writing-a-dataloader","title":"Writing a dataloader","text":"Below, we just write our own (very simple) dataloader. Alternatively, you could use the dataloader from any deep learning library such as pytorch or tensorflow:
class Dataset:\n def __init__(self, inputs: np.ndarray, labels: np.ndarray):\n \"\"\"Simple Dataloader.\n\n Args:\n inputs: Array of shape (num_samples, num_dim)\n labels: Array of shape (num_samples,)\n \"\"\"\n assert len(inputs) == len(labels), \"Inputs and labels must have same length\"\n self.inputs = inputs\n self.labels = labels\n self.num_samples = len(inputs)\n self._rng_state = None\n self.batch_size = 1\n\n def shuffle(self, seed=None):\n \"\"\"Shuffle the dataset in-place\"\"\"\n self._rng_state = np.random.get_state()[1][0] if seed is None else seed\n np.random.seed(self._rng_state)\n indices = np.random.permutation(self.num_samples)\n self.inputs = self.inputs[indices]\n self.labels = self.labels[indices]\n return self\n\n def batch(self, batch_size):\n \"\"\"Create batches of the data\"\"\"\n self.batch_size = batch_size\n return self\n\n def __iter__(self):\n self.shuffle(seed=self._rng_state)\n for start in range(0, self.num_samples, self.batch_size):\n end = min(start + self.batch_size, self.num_samples)\n yield self.inputs[start:end], self.labels[start:end]\n self._rng_state += 1\n
"},{"location":"tutorial/07_gradient_descent/#training-loop","title":"Training loop","text":"batch_size = 4\ndataloader = Dataset(inputs, labels)\ndataloader = dataloader.shuffle(seed=0).batch(batch_size)\n\nfor epoch in range(10):\n epoch_loss = 0.0\n\n for batch_ind, batch in enumerate(dataloader):\n current_batch, label_batch = batch\n loss_val, gradient = jitted_grad(opt_params, current_batch, label_batch)\n updates, opt_state = optimizer.update(gradient, opt_state)\n opt_params = optax.apply_updates(opt_params, updates)\n epoch_loss += loss_val\n\n print(f\"epoch {epoch}, loss {epoch_loss}\")\n\nfinal_params = transform.forward(opt_params)\n
epoch 0, loss 25.033223182772293\nepoch 1, loss 21.00894915349165\nepoch 2, loss 15.092242959956026\nepoch 3, loss 9.061544660383163\nepoch 4, loss 6.925509860325612\nepoch 5, loss 6.273630037897756\nepoch 6, loss 6.1757316054693145\nepoch 7, loss 6.135132525725265\nepoch 8, loss 6.145608619185389\nepoch 9, loss 6.135660902068834\n
ntest = 32\npredictions = batched_predict(final_params, inputs[:ntest])\n
fig, ax = plt.subplots(1, 1, figsize=(3, 2))\n_ = ax.scatter(labels[:ntest], predictions)\n_ = ax.set_xlabel(\"Label\")\n_ = ax.set_ylabel(\"Prediction\")\n
Indeed, the loss goes down and the network successfully classifies the patterns.
"},{"location":"tutorial/07_gradient_descent/#summary","title":"Summary","text":"Puh, this was a pretty dense tutorial with a lot of material. You should have learned how to:
This was the last \u201cbasic\u201d tutorial of the Jaxley
toolbox. If you want to learn more, check out our Advanced Tutorials. If anything is still unclear please create a discussion. If you find any bugs, please open an issue. Happy coding!
In this tutorial, you will learn how to:
Jaxley
Here is a code snippet which you will learn to understand in this tutorial:
import jaxley as jx\n\ncell = jx.read_swc(\"my_cell.swc\", ncomp=4)\ncell.branch(2).set_ncomp(2) # Modify the number of compartments of a branch.\n
To work with more complicated morphologies, Jaxley
supports importing morphological reconstructions via .swc
files. .swc
is currently the only supported format. Other formats like .asc
need to be converted to .swc
first, for example using the BlueBrain\u2019s morph-tool. For more information on the exact specifications of .swc
see here.
import jaxley as jx\nfrom jaxley.synapses import IonotropicSynapse\nimport matplotlib.pyplot as plt\n
To work with .swc
files, Jaxley
implements a custom .swc
reader. The reader traces the morphology and identifies all uninterrupted sections. These are then partitioned into branches, each of which will be approximated by a number of equally many compartments that can be simulated fully in parallel.
To demonstrate this, let\u2019s import an example morphology of a Layer 5 pyramidal cell and visualize it.
# import swc file into jx.Cell object\nfname = \"data/morph.swc\"\ncell = jx.read_swc(fname, ncomp=8) # Use four compartments per branch.\n\n# print shape (num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
(157, 1256)\n
local_comp_index global_comp_index local_branch_index global_branch_index local_cell_index global_cell_index 0 0 0 0 0 0 0 1 1 1 0 0 0 0 2 2 2 0 0 0 0 3 3 3 0 0 0 0 4 4 4 0 0 0 0 ... ... ... ... ... ... ... 1251 3 1251 156 156 0 0 1252 4 1252 156 156 0 0 1253 5 1253 156 156 0 0 1254 6 1254 156 156 0 0 1255 7 1255 156 156 0 0 1256 rows \u00d7 6 columns
As we can see, this yields a morphology that is approximated by 1256 compartments. Depending on the amount of detail that you need, you can also change the number of compartments in each branch:
cell = jx.read_swc(fname, ncomp=2)\n\n# print shape (num_branches, num_comps)\nprint(cell.shape)\n\ncell.show()\n
(157, 314)\n
local_comp_index global_comp_index local_branch_index global_branch_index local_cell_index global_cell_index 0 0 0 0 0 0 0 1 1 1 0 0 0 0 2 0 2 1 1 0 0 3 1 3 1 1 0 0 4 0 4 2 2 0 0 ... ... ... ... ... ... ... 309 1 309 154 154 0 0 310 0 310 155 155 0 0 311 1 311 155 155 0 0 312 0 312 156 156 0 0 313 1 313 156 156 0 0 314 rows \u00d7 6 columns
The above assigns the same number of compartments to every branch. To use a different number of compartments in individual branches, you can use .set_ncomp()
:
cell.branch(1).set_ncomp(4)\n
As you can see below, branch 0
has two compartments (because this is what was passed to jx.read_swc(..., ncomp=2)
), but branch 1
has four compartments:
cell.branch([0, 1]).nodes\n
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v global_cell_index global_branch_index global_comp_index controlled_by_param 0 0 0 0 0.050000 8.119000 5000.0 1.0 -70.0 0 0 0 0 1 0 0 1 0.050000 8.119000 5000.0 1.0 -70.0 0 0 1 0 2 0 1 0 3.120779 7.806172 5000.0 1.0 -70.0 0 1 2 1 3 0 1 1 3.120779 7.111231 5000.0 1.0 -70.0 0 1 3 1 4 0 1 2 3.120779 5.652394 5000.0 1.0 -70.0 0 1 4 1 5 0 1 3 3.120779 3.869247 5000.0 1.0 -70.0 0 1 5 1 Once imported the compartmentalized morphology can be viewed using vis
.
# visualize the cell\ncell.vis()\nplt.axis(\"off\")\nplt.title(\"L5PC\")\nplt.show()\n
vis
can be called on any jx.Module
and every View
of the module. This means we can also for example use vis
to highlight each branch. This can be done by iterating over each branch index and calling cell.branch(i).vis()
. Within the loop.
fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n# define colorwheel with 10 colors\ncolors = plt.cm.tab10.colors\nfor i, branch in enumerate(cell.branches):\n branch.vis(ax=ax, col=colors[i % 10])\nplt.axis(\"off\")\nplt.title(\"Branches\")\nplt.show()\n
While we only use two compartments to approximate each branch in this example, we can see the morphology is still plotted in great detail. This is because we always plot the full .swc
reconstruction irrespective of the number of compartments used. The morphology lives seperately in the cell.xyzr
attribute in a per branch fashion.
In addition to plotting the full morphology of the cell using points vis(type=\"scatter\")
or lines vis(type=\"line\")
, Jaxley
also supports plotting a detailed morphological vis(type=\"morph\")
or approximate compartmental reconstruction vis(type=\"comp\")
that correctly considers the thickness of the neurite. Note that \"comp\"
plots the lengths of each compartment which is equal to the length of the traced neurite. While neurites can be zigzaggy, the compartments that approximate them are straight lines. This can lead to miss-aligment of the compartment ends. For details see the documentation of vis
.
The morphologies can either be projected onto 2D or also rendered in 3D.
# visualize the cell\nfig, ax = plt.subplots(1, 4, figsize=(10, 3), layout=\"constrained\", sharex=True, sharey=True)\ncell.vis(ax=ax[0], type=\"morph\", dims=[0,1])\ncell.vis(ax=ax[1], type=\"comp\", dims=[0,1])\ncell.vis(ax=ax[2], type=\"scatter\", dims=[0,1], morph_plot_kwargs={\"s\": 1})\ncell.vis(ax=ax[3], type=\"line\", dims=[0,1])\nfig.suptitle(\"Comparison of plot types\")\nplt.show()\n
# set to interactive mode\n# %matplotlib notebook\n
# plot in 3D\nfig = plt.figure()\nax = fig.add_subplot(111, projection='3d')\ncell.vis(ax=ax, type=\"line\", dims=[2,0,1])\nax.view_init(elev=20, azim=5)\nplt.show()\n
Since Jaxley
supports grouping different branches or compartments together, we can also use the id
labels provided by the .swc
file to assign group labels to the jx.Cell
object.
print(list(cell.groups.keys()))\n\nfig, ax = plt.subplots(1, 1, figsize=(5, 5))\ncolors = plt.cm.tab10.colors\ncell.basal.vis(ax=ax, col=colors[2])\ncell.soma.vis(ax=ax, col=colors[1])\ncell.apical.vis(ax=ax, col=colors[0])\nplt.axis(\"off\")\nplt.title(\"Groups\")\nplt.show()\n
['soma', 'basal', 'apical', 'custom']\n
To build a network of morphologically detailed cells, we can now connect several reconstructed cells together and also visualize the network. However, since all cells are going to have the same center, Jaxley
will naively plot all of them on top of each other. To seperate out the cells, we therefore have to move them to a new location first.
net = jx.Network([cell]*5)\njx.connect(net[0,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[0,0,0], net[4,0,0], IonotropicSynapse())\n\njx.connect(net[1,0,0], net[2,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[3,0,0], IonotropicSynapse())\njx.connect(net[1,0,0], net[4,0,0], IonotropicSynapse())\n\nnet.rotate(-90)\n\nnet.cell(0).move(0, 300)\nnet.cell(1).move(0, 500)\n\nnet.cell(2).move(900, 200)\nnet.cell(3).move(900, 400)\nnet.cell(4).move(900, 600)\n\nnet.vis()\nplt.axis(\"off\")\nplt.show()\n
Congrats! You have now learned how to vizualize and build networks out of very complex morphologies. To simulate this network, you can follow the steps in the tutorial on how to build a network.
"},{"location":"tutorial/09_advanced_indexing/","title":"Customizing synaptic parameters","text":"In this tutorial, you will learn how to:
select()
method to fully customize network simulations with Jaxley
. copy_node_property_to_edges()
method to flexibly modify synapses. Here is a code snippet which you will learn to understand in this tutorial:
net = ... # See tutorial on Basics of Jaxley.\n\n# Set synaptic conductance of the synapse with index 0 and 1.\nnet.select(edges=[0, 1]).set(\"Ionotropic_gS\", 0.1)\n\n# Set synaptic conductance of all synapses that have cells 3 or 4 as presynaptic neuron.\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [3, 4]\")\nnet.select(edges=df.index).set(\"Ionotropic_gS\", 0.2)\n\n# Set synaptic conductance of all synapses that\n# 1) have cells 2 or 3 as presynaptic neuron and\n# 2) has cell 5 as postsynaptic neuron\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [2, 3]\")\ndf = df.query(\"post_global_cell_index == 5\")\nnet.select(edges=df.index).set(\"Ionotropic_gS\", 0.3)\n
In a previous tutorial you learned how to set parameters of a jx.Network
. In that tutorial, we briefly mentioned the select()
method which allowed to set individual synapses to particular values. In this tutorial, we will go into detail in how you can fully customize your Jaxley
simulation.
Let\u2019s go!
import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.connect import fully_connect\nfrom jaxley.synapses import IonotropicSynapse\n
"},{"location":"tutorial/09_advanced_indexing/#preface-building-the-network","title":"Preface: Building the network","text":"We first build a network consisting of six neurons, in the same way as we showed in the previous tutorials:
dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, nseg=2)\ncell = jx.Cell(branch, parents=[-1, 0])\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
"},{"location":"tutorial/09_advanced_indexing/#setting-individual-synapse-parameters","title":"Setting individual synapse parameters","text":"As always, you can use the .edges
table to inspect synaptic parameters of the network:
net.edges\n
global_edge_index pre_global_comp_index post_global_comp_index type type_ind pre_locs post_locs IonotropicSynapse_gS IonotropicSynapse_e_syn IonotropicSynapse_k_minus IonotropicSynapse_s controlled_by_param 0 0 0 13 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 1 1 0 19 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 2 2 0 20 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 3 3 4 12 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 4 4 4 16 IonotropicSynapse 0 0.25 0.25 0.0001 0.0 0.025 0.2 0 5 5 4 21 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 6 6 8 13 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 7 7 8 17 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 8 8 8 21 IonotropicSynapse 0 0.25 0.75 0.0001 0.0 0.025 0.2 0 This table has nine rows, each corresponding to one synapse. This makes sense because we fully connected three neurons (0, 1, 2) to three other neurons (3, 4, 5), giving a total of 3x3=9
synapses.
You can modify parameters of individual synapses as follows:
net.select(edges=[3, 4, 5]).set(\"IonotropicSynapse_gS\", 0.2)\n
Above, we are modifying the synapses with indices [3, 4, 5]
(i.e., the indices of the net.edges
DataFrame). The resulting values are indeed changed:
net.edges.IonotropicSynapse_gS\n
0 0.0001\n1 0.0001\n2 0.0001\n3 0.2000\n4 0.2000\n5 0.2000\n6 0.0001\n7 0.0001\n8 0.0001\nName: IonotropicSynapse_gS, dtype: float64\n
"},{"location":"tutorial/09_advanced_indexing/#example-1-setting-synaptic-parameters-which-connect-particular-neurons","title":"Example 1: Setting synaptic parameters which connect particular neurons","text":"This is great, but setting synaptic parameters just by their index can be exhausting, in particular in very large networks. Instead, we would want to, for example, set the maximal conductance of all synapses that connect from cell 0 or 1 to any other neuron.
In Jaxley
, such customization can be achieved by filtering the .edges
dataframe accordingly, as shown below:
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1]\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.23)\n
net.edges.IonotropicSynapse_gS\n
0 0.2300\n1 0.2300\n2 0.2300\n3 0.2300\n4 0.2300\n5 0.2300\n6 0.0001\n7 0.0001\n8 0.0001\nName: IonotropicSynapse_gS, dtype: float64\n
Indeed, the first six synapses now have the value 0.23
! Let\u2019s look at the individual lines to understand how this worked:
We want to set parameter by cell index. However, by default, the pre- or post-synaptic cell-indices are not listed in net.edges
. We can add the cell index to the .edges
dataframe by calling .copy_node_property_to_edges()
:
net.copy_node_property_to_edges(\"global_cell_index\")\n
After this, the pre- and post-synaptic cell indices are listed in net.edges
as pre_global_cell_index
and post_global_cell_index
.
Next, we take .edges
, which is a pandas DataFrame:
df = net.edges\n
We then modify this DataFrame to only contain those rows where the global cell index is in 0 or 1:
df = df.query(\"pre_global_cell_index in [0, 1]\")\n
For the above step, you use any column of the DataFrame to filter it (you can see all columns with df.columns
). Note that, while we used .query()
here, you can really filter the pandas DataFrame however you want. For example, the query
above is identical to df = df[df[\"pre_global_cell_index\"].isin([0, 1])]
.
Finally, we use the .select()
method, which returns a subset of the Network
at the specified indices. This subset of the network can be modified with .set()
:
net.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.23)\n
"},{"location":"tutorial/09_advanced_indexing/#example-2-setting-parameters-given-pre-and-post-synaptic-cell-indices","title":"Example 2: Setting parameters given pre- and post-synaptic cell indices","text":"Say you want to select all synapses that have cells 1 or 2 as presynaptic neuron and cell 4 or 5 as postsynaptic neuron.
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
Just like before, we can simply use .query()
as already shown above. However, this time, call .query()
to twice to filter by pre- and post-synaptic cell indices:
net.copy_node_property_to_edges(\"global_cell_index\")\n\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [1, 2]\")\ndf = df.query(\"post_global_cell_index in [4, 5]\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.3)\n
net.edges.IonotropicSynapse_gS\n
0 0.0001\n1 0.0001\n2 0.0001\n3 0.0001\n4 0.3000\n5 0.3000\n6 0.0001\n7 0.3000\n8 0.3000\nName: IonotropicSynapse_gS, dtype: float64\n
"},{"location":"tutorial/09_advanced_indexing/#example-3-applying-this-strategy-to-cell-level-parameters","title":"Example 3: Applying this strategy to cell level parameters","text":"You had previously seen that you can modify parameters with, e.g., net.cell(0).set(...)
. However, if you need more flexibility than this, you can also use the above strategy to modify cell-level parameters:
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\ndf = net.nodes\ndf = df.query(\"global_cell_index in [0, 1]\")\nnet.select(nodes=df.index).set(\"radius\", 0.1)\n
"},{"location":"tutorial/09_advanced_indexing/#example-4-flexibly-setting-parameters-based-on-their-groups","title":"Example 4: Flexibly setting parameters based on their groups
","text":"If you are using groups, as shown in this tutorial, then you can also use this for querying synapses. To demonstrate this, let\u2019s create a group of excitatory neurons (e.g., cells 0, 3, 5):
# Redefine network.\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.cell([0, 3, 5]).add_to_group(\"exc\")\n
Now, say we want all synapses that start from these excitatory neurons. You can do this as follows:
# First, we have to identify which cells are in the `exc` group.\nindices_of_excitatory_cells = net.exc.nodes[\"global_cell_index\"].unique().tolist() # [0, 3, 5]\n\n# Then we can proceed as before:\nnet.copy_node_property_to_edges(\"global_cell_index\")\ndf = net.edges\ndf = df.query(f\"pre_global_cell_index in {indices_of_excitatory_cells}\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.4)\n
"},{"location":"tutorial/09_advanced_indexing/#example-5-setting-synaptic-parameters-based-on-properties-of-the-presynaptic-cell","title":"Example 5: Setting synaptic parameters based on properties of the presynaptic cell","text":"Let\u2019s discuss one more example: Imagine we only want to modify those synapses whose presynaptic compartment has a sodium channel. Let\u2019s first add a sodium channel to some of the cells:
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n\nnet.cell(0).branch(0).comp(0).insert(Na())\nnet.cell(2).branch(1).comp(1).insert(Na())\n
Now, let us query which cells have the desired synapses:
df = net.nodes\ndf = df.query(\"Na\")\nindices_of_sodium_compartments = df[\"global_comp_index\"].unique().tolist()\n
indices_of_sodium_compartments
lists all compartments which contained sodium:
print(indices_of_sodium_compartments)\n
[0, 11]\n
Then, we can proceed as always and filter for the global pre-synaptic compartment index:
df = net.edges\ndf = df.query(f\"pre_global_comp_index in {indices_of_sodium_compartments}\")\nnet.select(edges=df.index).set(\"IonotropicSynapse_gS\", 0.6)\n
net.edges.IonotropicSynapse_gS\n
0 0.6000\n1 0.6000\n2 0.6000\n3 0.0001\n4 0.0001\n5 0.0001\n6 0.0001\n7 0.0001\n8 0.0001\nName: IonotropicSynapse_gS, dtype: float64\n
Indeed, only synapses coming from the first neuron were modified (as its presynaptic compartment contained sodium), in contrast to synapses from neuron 2 (whose presynaptic compartment did not).
"},{"location":"tutorial/09_advanced_indexing/#summary","title":"Summary","text":"In this tutorial, you learned how to fully customize your Jaxley
simulation. This works by querying rows from the .edges
DataFrame.
In this tutorial, you will learn how to:
Here is a code snippet which you will learn to understand in this tutorial:
net = ... # See tutorial on Basics of Jaxley.\n\n# The same parameter for all synapses\nnet.make_trainable(\"Ionotropic_gS\")\n\n# An individual parameter for every synapse.\nnet.select(edges=\"all\").make_trainable(\"Ionotropic_gS\")\n\n# Share synaptic conductances emerging from the same neurons.\nnet.copy_node_property_to_edges(\"cell_index\")\nsub_net = net.select(edges=[0, 1, 2])\nsub_net.edges[\"controlled_by_param\"] = sub_net.edges[\"pre_global_cell_index\"]\nsub_net.make_trainable(\"Ionotropic_gS\")\n
In a previous tutorial about training networks, we briefly touched on parameter sharing. In this tutorial, we will show you how you can flexibly share parameters within a network.
import jaxley as jx\nfrom jaxley.channels import Na, K, Leak\nfrom jaxley.connect import fully_connect\nfrom jaxley.synapses import IonotropicSynapse\n
"},{"location":"tutorial/10_advanced_parameter_sharing/#preface-building-the-network","title":"Preface: Building the network","text":"We first build a network consisting of six neurons, in the same way as we showed in the previous tutorials:
dt = 0.025\nt_max = 10.0\n\ncomp = jx.Compartment()\nbranch = jx.Branch(comp, ncomp=2)\ncell = jx.Cell(branch, parents=[-1, 0])\nnet = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell([0, 1, 2]), net.cell([3, 4, 5]), IonotropicSynapse())\n
"},{"location":"tutorial/10_advanced_parameter_sharing/#sharing-parameters-by-modifying-controlled_by_param","title":"Sharing parameters by modifying controlled_by_param
","text":"net.copy_node_property_to_edges(\"global_cell_index\")\n\ndf = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1, 2]\")\nsubnetwork = net.select(edges=df.index)\n\ndf = subnetwork.edges\ndf[\"controlled_by_param\"] = df[\"pre_global_cell_index\"]\nsubnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 3. Total number of trainable parameters: 3\n
Let\u2019s look at this line by line. First, we exactly follow the previous tutorial in selecting the synapses which we are interested in training (i.e., the ones whose presynaptic neuron has index 0, 1, 2):
df = net.edges\ndf = df.query(\"pre_global_cell_index in [0, 1, 2]\")\nsubnetwork = net.select(edges=df.index)\n
As second step, we enable parameter sharing. This is done by setting the controlled_by_param
. Synapses that have the same value in controlled_by_param
will be shared. Let\u2019s inspect controlled_by_param
before we modify it:
subnetwork.edges[[\"pre_global_cell_index\", \"controlled_by_param\"]]\n
pre_global_cell_index controlled_by_param 0 0 0 1 0 1 2 0 2 3 1 3 4 1 4 5 1 5 6 2 6 7 2 7 8 2 8 Every synapse has a different value. Because of this, no synaptic parameters will be shared. To enable parameter sharing we override the controlled_by_param
column with the presynaptic cell index:
df = subnetwork.edges\ndf[\"controlled_by_param\"] = df[\"pre_global_cell_index\"]\n
df[[\"pre_global_cell_index\", \"controlled_by_param\"]]\n
pre_global_cell_index controlled_by_param 0 0 0 1 0 0 2 0 0 3 1 1 4 1 1 5 1 1 6 2 2 7 2 2 8 2 2 Now, all we have to do is to make these synaptic parameters trainable with the make_trainable()
method:
subnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 3. Total number of trainable parameters: 6\n
It correctly says that we added three parameters (because we have three cells, and we share individual synaptic parameters). We now have 6 trainable parameters in total (because we already added 3 trainable parameters above).
"},{"location":"tutorial/10_advanced_parameter_sharing/#a-more-involved-example-sharing-by-pre-and-post-synaptic-cell-type","title":"A more involved example: sharing by pre- and post-synaptic cell type","text":"As an example, consider the following: We have a fully connected network of six cells. Each cell falls into one of three cell types:
from typing import Union, List\n
net = jx.Network([cell for _ in range(6)])\nfully_connect(net.cell(\"all\"), net.cell(\"all\"), IonotropicSynapse())\n\nnet.cell([0, 1]).add_to_group(\"exc\")\nnet.cell([2, 3]).add_to_group(\"inh\")\nnet.cell([4, 5]).add_to_group(\"unknown\")\n
We want to make all synapses that start from excitatory or inhibitory neurons trainable. In addition, we want to use the same parameter for synapses if they have the same pre- and post-synaptic cell type.
To achieve this, we will first want a column in net.nodes
which indicates the cell type.
for group, inds in net.groups.items():\n net.nodes.loc[inds, \"cell_type\"] = group\n
net.nodes[\"cell_type\"]\n
0 exc\n1 exc\n2 exc\n3 exc\n4 exc\n5 exc\n6 exc\n7 exc\n8 inh\n9 inh\n10 inh\n11 inh\n12 inh\n13 inh\n14 inh\n15 inh\n16 unknown\n17 unknown\n18 unknown\n19 unknown\n20 unknown\n21 unknown\n22 unknown\n23 unknown\nName: cell_type, dtype: object\n
The cell_type
is now part of the net.nodes
. However, we would like to do parameter sharing of synapses based on the pre- and post-synaptic node values. To do so, we import the cell_type
column into net.edges
. To do this, we use the .copy_node_property_to_edges()
which the name of the property you are copying from nodes:
net.copy_node_property_to_edges(\"cell_type\")\n
After this, you have columns in the .edges
which indicate the pre- and post-synaptic cell type:
net.edges[[\"pre_cell_type\", \"post_cell_type\"]]\n
pre_cell_type post_cell_type 0 exc exc 1 exc exc 2 exc inh 3 exc inh 4 exc unknown 5 exc unknown 6 exc exc 7 exc exc 8 exc inh 9 exc inh 10 exc unknown 11 exc unknown 12 inh exc 13 inh exc 14 inh inh 15 inh inh 16 inh unknown 17 inh unknown 18 inh exc 19 inh exc 20 inh inh 21 inh inh 22 inh unknown 23 inh unknown 24 unknown exc 25 unknown exc 26 unknown inh 27 unknown inh 28 unknown unknown 29 unknown unknown 30 unknown exc 31 unknown exc 32 unknown inh 33 unknown inh 34 unknown unknown 35 unknown unknown Next, we specify which parts of the network we actually want to change (in this case, all synapses which have excitatory or inhibitory presynaptic neurons):
df = net.edges\ndf = df.query(f\"pre_cell_type in ['exc', 'inh']\")\nprint(f\"There are {len(df)} synapses to be changed.\")\n\nsubnetwork = net.select(edges=df.index)\n
There are 24 synapses to be changed.\n
As the last step, we again have to specify parameter sharing by setting controlled_by_param
. In this case, we want to share parameters that have the same pre- and post-synaptic neuron. We achieve this by grouping the synpases by their pre- and post-synaptic cell type (see pd.DataFrame.groupby for details):
# Step 6: use groupby to specify parameter sharing and make the parameters trainable.\nsubnetwork.edges[\"controlled_by_param\"] = subnetwork.edges.groupby([\"pre_cell_type\", \"post_cell_type\"]).ngroup()\nsubnetwork.make_trainable(\"IonotropicSynapse_gS\")\n
Number of newly added trainable parameters: 6. Total number of trainable parameters: 6\n
This created six trainable parameters, which makes sense as we have two types of pre-synaptic neurons (excitatory and inhibitory) and each has three options for the postsynaptic neuron (pre, post, unknown).
"},{"location":"tutorial/10_advanced_parameter_sharing/#summary","title":"Summary","text":"In this tutorial, you learned how you can flexibly share synaptic parameters. This works by first using select()
to identify which synapses to make trainable, and by then modifying controlled_by_param
to customize parameter sharing.